mirror of
https://github.com/danny-avila/LibreChat.git
synced 2026-01-29 05:36:13 +01:00
Merge branch 'dev' into feat/prompt-enhancement
This commit is contained in:
commit
3d261a969d
365 changed files with 23826 additions and 8790 deletions
|
|
@ -5,6 +5,7 @@ export default {
|
|||
testResultsProcessor: 'jest-junit',
|
||||
moduleNameMapper: {
|
||||
'^@src/(.*)$': '<rootDir>/src/$1',
|
||||
'~/(.*)': '<rootDir>/src/$1',
|
||||
},
|
||||
// coverageThreshold: {
|
||||
// global: {
|
||||
|
|
@ -1,6 +1,6 @@
|
|||
{
|
||||
"name": "librechat-mcp",
|
||||
"version": "1.2.2",
|
||||
"name": "@librechat/api",
|
||||
"version": "1.2.4",
|
||||
"type": "commonjs",
|
||||
"description": "MCP services for LibreChat",
|
||||
"main": "dist/index.js",
|
||||
|
|
@ -47,9 +47,11 @@
|
|||
"@rollup/plugin-replace": "^5.0.5",
|
||||
"@rollup/plugin-terser": "^0.4.4",
|
||||
"@rollup/plugin-typescript": "^12.1.2",
|
||||
"@types/bun": "^1.2.15",
|
||||
"@types/diff": "^6.0.0",
|
||||
"@types/express": "^5.0.0",
|
||||
"@types/jest": "^29.5.2",
|
||||
"@types/multer": "^1.4.13",
|
||||
"@types/node": "^20.3.0",
|
||||
"@types/react": "^18.2.18",
|
||||
"@types/winston": "^2.4.4",
|
||||
|
|
@ -66,13 +68,19 @@
|
|||
"publishConfig": {
|
||||
"registry": "https://registry.npmjs.org/"
|
||||
},
|
||||
"dependencies": {
|
||||
"@modelcontextprotocol/sdk": "^1.11.2",
|
||||
"peerDependencies": {
|
||||
"@librechat/agents": "^2.4.41",
|
||||
"@librechat/data-schemas": "*",
|
||||
"@modelcontextprotocol/sdk": "^1.12.3",
|
||||
"axios": "^1.8.2",
|
||||
"diff": "^7.0.0",
|
||||
"eventsource": "^3.0.2",
|
||||
"express": "^4.21.2"
|
||||
},
|
||||
"peerDependencies": {
|
||||
"keyv": "^5.3.2"
|
||||
"express": "^4.21.2",
|
||||
"keyv": "^5.3.2",
|
||||
"librechat-data-provider": "*",
|
||||
"node-fetch": "2.7.0",
|
||||
"tiktoken": "^1.0.15",
|
||||
"undici": "^7.10.0",
|
||||
"zod": "^3.22.4"
|
||||
}
|
||||
}
|
||||
|
|
@ -1,5 +1,6 @@
|
|||
// rollup.config.js
|
||||
import { readFileSync } from 'fs';
|
||||
import json from '@rollup/plugin-json';
|
||||
import terser from '@rollup/plugin-terser';
|
||||
import replace from '@rollup/plugin-replace';
|
||||
import commonjs from '@rollup/plugin-commonjs';
|
||||
|
|
@ -29,15 +30,17 @@ const plugins = [
|
|||
inlineSourceMap: true,
|
||||
}),
|
||||
terser(),
|
||||
json(),
|
||||
];
|
||||
|
||||
const cjsBuild = {
|
||||
input: 'src/index.ts',
|
||||
output: {
|
||||
file: pkg.main,
|
||||
dir: 'dist',
|
||||
format: 'cjs',
|
||||
sourcemap: true,
|
||||
exports: 'named',
|
||||
entryFileNames: '[name].js',
|
||||
},
|
||||
external: [...Object.keys(pkg.dependencies || {}), ...Object.keys(pkg.devDependencies || {})],
|
||||
preserveSymlinks: true,
|
||||
93
packages/api/src/agents/auth.ts
Normal file
93
packages/api/src/agents/auth.ts
Normal file
|
|
@ -0,0 +1,93 @@
|
|||
import { logger } from '@librechat/data-schemas';
|
||||
import type { IPluginAuth, PluginAuthMethods } from '@librechat/data-schemas';
|
||||
import { decrypt } from '../crypto/encryption';
|
||||
|
||||
export interface GetPluginAuthMapParams {
|
||||
userId: string;
|
||||
pluginKeys: string[];
|
||||
throwError?: boolean;
|
||||
findPluginAuthsByKeys: PluginAuthMethods['findPluginAuthsByKeys'];
|
||||
}
|
||||
|
||||
export type PluginAuthMap = Record<string, Record<string, string>>;
|
||||
|
||||
/**
|
||||
* Retrieves and decrypts authentication values for multiple plugins
|
||||
* @returns A map where keys are pluginKeys and values are objects of authField:decryptedValue pairs
|
||||
*/
|
||||
export async function getPluginAuthMap({
|
||||
userId,
|
||||
pluginKeys,
|
||||
throwError = true,
|
||||
findPluginAuthsByKeys,
|
||||
}: GetPluginAuthMapParams): Promise<PluginAuthMap> {
|
||||
try {
|
||||
/** Early return for empty plugin keys */
|
||||
if (!pluginKeys?.length) {
|
||||
return {};
|
||||
}
|
||||
|
||||
/** All plugin auths for current user query */
|
||||
const pluginAuths = await findPluginAuthsByKeys({ userId, pluginKeys });
|
||||
|
||||
/** Group auth records by pluginKey for efficient lookup */
|
||||
const authsByPlugin = new Map<string, IPluginAuth[]>();
|
||||
for (const auth of pluginAuths) {
|
||||
if (!auth.pluginKey) {
|
||||
logger.warn(`[getPluginAuthMap] Missing pluginKey for userId ${userId}`);
|
||||
continue;
|
||||
}
|
||||
const existing = authsByPlugin.get(auth.pluginKey) || [];
|
||||
existing.push(auth);
|
||||
authsByPlugin.set(auth.pluginKey, existing);
|
||||
}
|
||||
|
||||
const authMap: PluginAuthMap = {};
|
||||
const decryptionPromises: Promise<void>[] = [];
|
||||
|
||||
/** Single loop through requested pluginKeys */
|
||||
for (const pluginKey of pluginKeys) {
|
||||
authMap[pluginKey] = {};
|
||||
const auths = authsByPlugin.get(pluginKey) || [];
|
||||
|
||||
for (const auth of auths) {
|
||||
decryptionPromises.push(
|
||||
(async () => {
|
||||
try {
|
||||
const decryptedValue = await decrypt(auth.value);
|
||||
authMap[pluginKey][auth.authField] = decryptedValue;
|
||||
} catch (error) {
|
||||
const message = error instanceof Error ? error.message : 'Unknown error';
|
||||
logger.error(
|
||||
`[getPluginAuthMap] Decryption failed for userId ${userId}, plugin ${pluginKey}, field ${auth.authField}: ${message}`,
|
||||
);
|
||||
|
||||
if (throwError) {
|
||||
throw new Error(
|
||||
`Decryption failed for plugin ${pluginKey}, field ${auth.authField}: ${message}`,
|
||||
);
|
||||
}
|
||||
}
|
||||
})(),
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
await Promise.all(decryptionPromises);
|
||||
return authMap;
|
||||
} catch (error) {
|
||||
if (!throwError) {
|
||||
/** Empty objects for each plugin key on error */
|
||||
return pluginKeys.reduce((acc, key) => {
|
||||
acc[key] = {};
|
||||
return acc;
|
||||
}, {} as PluginAuthMap);
|
||||
}
|
||||
|
||||
const message = error instanceof Error ? error.message : 'Unknown error';
|
||||
logger.error(
|
||||
`[getPluginAuthMap] Failed to fetch auth values for userId ${userId}, plugins: ${pluginKeys.join(', ')}: ${message}`,
|
||||
);
|
||||
throw error;
|
||||
}
|
||||
}
|
||||
24
packages/api/src/agents/config.ts
Normal file
24
packages/api/src/agents/config.ts
Normal file
|
|
@ -0,0 +1,24 @@
|
|||
import { EModelEndpoint, agentsEndpointSchema } from 'librechat-data-provider';
|
||||
import type { TCustomConfig, TAgentsEndpoint } from 'librechat-data-provider';
|
||||
|
||||
/**
|
||||
* Sets up the Agents configuration from the config (`librechat.yaml`) file.
|
||||
* If no agents config is defined, uses the provided defaults or parses empty object.
|
||||
*
|
||||
* @param config - The loaded custom configuration.
|
||||
* @param [defaultConfig] - Default configuration from getConfigDefaults.
|
||||
* @returns The Agents endpoint configuration.
|
||||
*/
|
||||
export function agentsConfigSetup(
|
||||
config: TCustomConfig,
|
||||
defaultConfig: Partial<TAgentsEndpoint>,
|
||||
): Partial<TAgentsEndpoint> {
|
||||
const agentsConfig = config?.endpoints?.[EModelEndpoint.agents];
|
||||
|
||||
if (!agentsConfig) {
|
||||
return defaultConfig || agentsEndpointSchema.parse({});
|
||||
}
|
||||
|
||||
const parsedConfig = agentsEndpointSchema.parse(agentsConfig);
|
||||
return parsedConfig;
|
||||
}
|
||||
4
packages/api/src/agents/index.ts
Normal file
4
packages/api/src/agents/index.ts
Normal file
|
|
@ -0,0 +1,4 @@
|
|||
export * from './config';
|
||||
export * from './memory';
|
||||
export * from './resources';
|
||||
export * from './run';
|
||||
468
packages/api/src/agents/memory.ts
Normal file
468
packages/api/src/agents/memory.ts
Normal file
|
|
@ -0,0 +1,468 @@
|
|||
/** Memories */
|
||||
import { z } from 'zod';
|
||||
import { tool } from '@langchain/core/tools';
|
||||
import { Tools } from 'librechat-data-provider';
|
||||
import { logger } from '@librechat/data-schemas';
|
||||
import { Run, Providers, GraphEvents } from '@librechat/agents';
|
||||
import type {
|
||||
StreamEventData,
|
||||
ToolEndCallback,
|
||||
EventHandler,
|
||||
ToolEndData,
|
||||
LLMConfig,
|
||||
} from '@librechat/agents';
|
||||
import type { TAttachment, MemoryArtifact } from 'librechat-data-provider';
|
||||
import type { ObjectId, MemoryMethods } from '@librechat/data-schemas';
|
||||
import type { BaseMessage } from '@langchain/core/messages';
|
||||
import type { Response as ServerResponse } from 'express';
|
||||
import { Tokenizer } from '~/utils';
|
||||
|
||||
type RequiredMemoryMethods = Pick<
|
||||
MemoryMethods,
|
||||
'setMemory' | 'deleteMemory' | 'getFormattedMemories'
|
||||
>;
|
||||
|
||||
type ToolEndMetadata = Record<string, unknown> & {
|
||||
run_id?: string;
|
||||
thread_id?: string;
|
||||
};
|
||||
|
||||
export interface MemoryConfig {
|
||||
validKeys?: string[];
|
||||
instructions?: string;
|
||||
llmConfig?: Partial<LLMConfig>;
|
||||
tokenLimit?: number;
|
||||
}
|
||||
|
||||
export const memoryInstructions =
|
||||
'The system automatically stores important user information and can update or delete memories based on user requests, enabling dynamic memory management.';
|
||||
|
||||
const getDefaultInstructions = (
|
||||
validKeys?: string[],
|
||||
tokenLimit?: number,
|
||||
) => `Use the \`set_memory\` tool to save important information about the user, but ONLY when the user has explicitly provided this information. If there is nothing to note about the user specifically, END THE TURN IMMEDIATELY.
|
||||
|
||||
The \`delete_memory\` tool should only be used in two scenarios:
|
||||
1. When the user explicitly asks to forget or remove specific information
|
||||
2. When updating existing memories, use the \`set_memory\` tool instead of deleting and re-adding the memory.
|
||||
|
||||
${
|
||||
validKeys && validKeys.length > 0
|
||||
? `CRITICAL INSTRUCTION: Only the following keys are valid for storing memories:
|
||||
${validKeys.map((key) => `- ${key}`).join('\n ')}`
|
||||
: 'You can use any appropriate key to store memories about the user.'
|
||||
}
|
||||
|
||||
${
|
||||
tokenLimit
|
||||
? `⚠️ TOKEN LIMIT: Each memory value must not exceed ${tokenLimit} tokens. Be concise and store only essential information.`
|
||||
: ''
|
||||
}
|
||||
|
||||
⚠️ WARNING ⚠️
|
||||
DO NOT STORE ANY INFORMATION UNLESS THE USER HAS EXPLICITLY PROVIDED IT.
|
||||
ONLY store information the user has EXPLICITLY shared.
|
||||
NEVER guess or assume user information.
|
||||
ALL memory values must be factual statements about THIS specific user.
|
||||
If nothing needs to be stored, DO NOT CALL any memory tools.
|
||||
If you're unsure whether to store something, DO NOT store it.
|
||||
If nothing needs to be stored, END THE TURN IMMEDIATELY.`;
|
||||
|
||||
/**
|
||||
* Creates a memory tool instance with user context
|
||||
*/
|
||||
const createMemoryTool = ({
|
||||
userId,
|
||||
setMemory,
|
||||
validKeys,
|
||||
tokenLimit,
|
||||
totalTokens = 0,
|
||||
}: {
|
||||
userId: string | ObjectId;
|
||||
setMemory: MemoryMethods['setMemory'];
|
||||
validKeys?: string[];
|
||||
tokenLimit?: number;
|
||||
totalTokens?: number;
|
||||
}) => {
|
||||
return tool(
|
||||
async ({ key, value }) => {
|
||||
try {
|
||||
if (validKeys && validKeys.length > 0 && !validKeys.includes(key)) {
|
||||
logger.warn(
|
||||
`Memory Agent failed to set memory: Invalid key "${key}". Must be one of: ${validKeys.join(
|
||||
', ',
|
||||
)}`,
|
||||
);
|
||||
return `Invalid key "${key}". Must be one of: ${validKeys.join(', ')}`;
|
||||
}
|
||||
|
||||
const tokenCount = Tokenizer.getTokenCount(value, 'o200k_base');
|
||||
|
||||
if (tokenLimit && tokenCount > tokenLimit) {
|
||||
logger.warn(
|
||||
`Memory Agent failed to set memory: Value exceeds token limit. Value has ${tokenCount} tokens, but limit is ${tokenLimit}`,
|
||||
);
|
||||
return `Memory value too large: ${tokenCount} tokens exceeds limit of ${tokenLimit}`;
|
||||
}
|
||||
|
||||
if (tokenLimit && totalTokens + tokenCount > tokenLimit) {
|
||||
const remainingCapacity = tokenLimit - totalTokens;
|
||||
logger.warn(
|
||||
`Memory Agent failed to set memory: Would exceed total token limit. Current usage: ${totalTokens}, new memory: ${tokenCount} tokens, limit: ${tokenLimit}`,
|
||||
);
|
||||
return `Cannot add memory: would exceed token limit. Current usage: ${totalTokens}/${tokenLimit} tokens. This memory requires ${tokenCount} tokens, but only ${remainingCapacity} tokens available.`;
|
||||
}
|
||||
|
||||
const artifact: Record<Tools.memory, MemoryArtifact> = {
|
||||
[Tools.memory]: {
|
||||
key,
|
||||
value,
|
||||
tokenCount,
|
||||
type: 'update',
|
||||
},
|
||||
};
|
||||
|
||||
const result = await setMemory({ userId, key, value, tokenCount });
|
||||
if (result.ok) {
|
||||
logger.debug(`Memory set for key "${key}" (${tokenCount} tokens) for user "${userId}"`);
|
||||
return [`Memory set for key "${key}" (${tokenCount} tokens)`, artifact];
|
||||
}
|
||||
logger.warn(`Failed to set memory for key "${key}" for user "${userId}"`);
|
||||
return [`Failed to set memory for key "${key}"`, undefined];
|
||||
} catch (error) {
|
||||
logger.error('Memory Agent failed to set memory', error);
|
||||
return [`Error setting memory for key "${key}"`, undefined];
|
||||
}
|
||||
},
|
||||
{
|
||||
name: 'set_memory',
|
||||
description: 'Saves important information about the user into memory.',
|
||||
responseFormat: 'content_and_artifact',
|
||||
schema: z.object({
|
||||
key: z
|
||||
.string()
|
||||
.describe(
|
||||
validKeys && validKeys.length > 0
|
||||
? `The key of the memory value. Must be one of: ${validKeys.join(', ')}`
|
||||
: 'The key identifier for this memory',
|
||||
),
|
||||
value: z
|
||||
.string()
|
||||
.describe(
|
||||
'Value MUST be a complete sentence that fully describes relevant user information.',
|
||||
),
|
||||
}),
|
||||
},
|
||||
);
|
||||
};
|
||||
|
||||
/**
|
||||
* Creates a delete memory tool instance with user context
|
||||
*/
|
||||
const createDeleteMemoryTool = ({
|
||||
userId,
|
||||
deleteMemory,
|
||||
validKeys,
|
||||
}: {
|
||||
userId: string | ObjectId;
|
||||
deleteMemory: MemoryMethods['deleteMemory'];
|
||||
validKeys?: string[];
|
||||
}) => {
|
||||
return tool(
|
||||
async ({ key }) => {
|
||||
try {
|
||||
if (validKeys && validKeys.length > 0 && !validKeys.includes(key)) {
|
||||
logger.warn(
|
||||
`Memory Agent failed to delete memory: Invalid key "${key}". Must be one of: ${validKeys.join(
|
||||
', ',
|
||||
)}`,
|
||||
);
|
||||
return `Invalid key "${key}". Must be one of: ${validKeys.join(', ')}`;
|
||||
}
|
||||
|
||||
const artifact: Record<Tools.memory, MemoryArtifact> = {
|
||||
[Tools.memory]: {
|
||||
key,
|
||||
type: 'delete',
|
||||
},
|
||||
};
|
||||
|
||||
const result = await deleteMemory({ userId, key });
|
||||
if (result.ok) {
|
||||
logger.debug(`Memory deleted for key "${key}" for user "${userId}"`);
|
||||
return [`Memory deleted for key "${key}"`, artifact];
|
||||
}
|
||||
logger.warn(`Failed to delete memory for key "${key}" for user "${userId}"`);
|
||||
return [`Failed to delete memory for key "${key}"`, undefined];
|
||||
} catch (error) {
|
||||
logger.error('Memory Agent failed to delete memory', error);
|
||||
return [`Error deleting memory for key "${key}"`, undefined];
|
||||
}
|
||||
},
|
||||
{
|
||||
name: 'delete_memory',
|
||||
description:
|
||||
'Deletes specific memory data about the user using the provided key. For updating existing memories, use the `set_memory` tool instead',
|
||||
responseFormat: 'content_and_artifact',
|
||||
schema: z.object({
|
||||
key: z
|
||||
.string()
|
||||
.describe(
|
||||
validKeys && validKeys.length > 0
|
||||
? `The key of the memory to delete. Must be one of: ${validKeys.join(', ')}`
|
||||
: 'The key identifier of the memory to delete',
|
||||
),
|
||||
}),
|
||||
},
|
||||
);
|
||||
};
|
||||
export class BasicToolEndHandler implements EventHandler {
|
||||
private callback?: ToolEndCallback;
|
||||
constructor(callback?: ToolEndCallback) {
|
||||
this.callback = callback;
|
||||
}
|
||||
handle(
|
||||
event: string,
|
||||
data: StreamEventData | undefined,
|
||||
metadata?: Record<string, unknown>,
|
||||
): void {
|
||||
if (!metadata) {
|
||||
console.warn(`Graph or metadata not found in ${event} event`);
|
||||
return;
|
||||
}
|
||||
const toolEndData = data as ToolEndData | undefined;
|
||||
if (!toolEndData?.output) {
|
||||
console.warn('No output found in tool_end event');
|
||||
return;
|
||||
}
|
||||
this.callback?.(toolEndData, metadata);
|
||||
}
|
||||
}
|
||||
|
||||
export async function processMemory({
|
||||
res,
|
||||
userId,
|
||||
setMemory,
|
||||
deleteMemory,
|
||||
messages,
|
||||
memory,
|
||||
messageId,
|
||||
conversationId,
|
||||
validKeys,
|
||||
instructions,
|
||||
llmConfig,
|
||||
tokenLimit,
|
||||
totalTokens = 0,
|
||||
}: {
|
||||
res: ServerResponse;
|
||||
setMemory: MemoryMethods['setMemory'];
|
||||
deleteMemory: MemoryMethods['deleteMemory'];
|
||||
userId: string | ObjectId;
|
||||
memory: string;
|
||||
messageId: string;
|
||||
conversationId: string;
|
||||
messages: BaseMessage[];
|
||||
validKeys?: string[];
|
||||
instructions: string;
|
||||
tokenLimit?: number;
|
||||
totalTokens?: number;
|
||||
llmConfig?: Partial<LLMConfig>;
|
||||
}): Promise<(TAttachment | null)[] | undefined> {
|
||||
try {
|
||||
const memoryTool = createMemoryTool({ userId, tokenLimit, setMemory, validKeys, totalTokens });
|
||||
const deleteMemoryTool = createDeleteMemoryTool({
|
||||
userId,
|
||||
validKeys,
|
||||
deleteMemory,
|
||||
});
|
||||
|
||||
const currentMemoryTokens = totalTokens;
|
||||
|
||||
let memoryStatus = `# Existing memory:\n${memory ?? 'No existing memories'}`;
|
||||
|
||||
if (tokenLimit) {
|
||||
const remainingTokens = tokenLimit - currentMemoryTokens;
|
||||
memoryStatus = `# Memory Status:
|
||||
Current memory usage: ${currentMemoryTokens} tokens
|
||||
Token limit: ${tokenLimit} tokens
|
||||
Remaining capacity: ${remainingTokens} tokens
|
||||
|
||||
# Existing memory:
|
||||
${memory ?? 'No existing memories'}`;
|
||||
}
|
||||
|
||||
const defaultLLMConfig: LLMConfig = {
|
||||
provider: Providers.OPENAI,
|
||||
model: 'gpt-4.1-mini',
|
||||
temperature: 0.4,
|
||||
streaming: false,
|
||||
disableStreaming: true,
|
||||
};
|
||||
|
||||
const finalLLMConfig = {
|
||||
...defaultLLMConfig,
|
||||
...llmConfig,
|
||||
/**
|
||||
* Ensure streaming is always disabled for memory processing
|
||||
*/
|
||||
streaming: false,
|
||||
disableStreaming: true,
|
||||
};
|
||||
|
||||
const artifactPromises: Promise<TAttachment | null>[] = [];
|
||||
const memoryCallback = createMemoryCallback({ res, artifactPromises });
|
||||
const customHandlers = {
|
||||
[GraphEvents.TOOL_END]: new BasicToolEndHandler(memoryCallback),
|
||||
};
|
||||
|
||||
const run = await Run.create({
|
||||
runId: messageId,
|
||||
graphConfig: {
|
||||
type: 'standard',
|
||||
llmConfig: finalLLMConfig,
|
||||
tools: [memoryTool, deleteMemoryTool],
|
||||
instructions,
|
||||
additional_instructions: memoryStatus,
|
||||
toolEnd: true,
|
||||
},
|
||||
customHandlers,
|
||||
returnContent: true,
|
||||
});
|
||||
|
||||
const config = {
|
||||
configurable: {
|
||||
provider: llmConfig?.provider,
|
||||
thread_id: `memory-run-${conversationId}`,
|
||||
},
|
||||
streamMode: 'values',
|
||||
version: 'v2',
|
||||
} as const;
|
||||
|
||||
const inputs = {
|
||||
messages,
|
||||
};
|
||||
const content = await run.processStream(inputs, config);
|
||||
if (content) {
|
||||
logger.debug('Memory Agent processed memory successfully', content);
|
||||
} else {
|
||||
logger.warn('Memory Agent processed memory but returned no content');
|
||||
}
|
||||
return await Promise.all(artifactPromises);
|
||||
} catch (error) {
|
||||
logger.error('Memory Agent failed to process memory', error);
|
||||
}
|
||||
}
|
||||
|
||||
export async function createMemoryProcessor({
|
||||
res,
|
||||
userId,
|
||||
messageId,
|
||||
memoryMethods,
|
||||
conversationId,
|
||||
config = {},
|
||||
}: {
|
||||
res: ServerResponse;
|
||||
messageId: string;
|
||||
conversationId: string;
|
||||
userId: string | ObjectId;
|
||||
memoryMethods: RequiredMemoryMethods;
|
||||
config?: MemoryConfig;
|
||||
}): Promise<[string, (messages: BaseMessage[]) => Promise<(TAttachment | null)[] | undefined>]> {
|
||||
const { validKeys, instructions, llmConfig, tokenLimit } = config;
|
||||
const finalInstructions = instructions || getDefaultInstructions(validKeys, tokenLimit);
|
||||
|
||||
const { withKeys, withoutKeys, totalTokens } = await memoryMethods.getFormattedMemories({
|
||||
userId,
|
||||
});
|
||||
|
||||
return [
|
||||
withoutKeys,
|
||||
async function (messages: BaseMessage[]): Promise<(TAttachment | null)[] | undefined> {
|
||||
try {
|
||||
return await processMemory({
|
||||
res,
|
||||
userId,
|
||||
messages,
|
||||
validKeys,
|
||||
llmConfig,
|
||||
messageId,
|
||||
tokenLimit,
|
||||
conversationId,
|
||||
memory: withKeys,
|
||||
totalTokens: totalTokens || 0,
|
||||
instructions: finalInstructions,
|
||||
setMemory: memoryMethods.setMemory,
|
||||
deleteMemory: memoryMethods.deleteMemory,
|
||||
});
|
||||
} catch (error) {
|
||||
logger.error('Memory Agent failed to process memory', error);
|
||||
}
|
||||
},
|
||||
];
|
||||
}
|
||||
|
||||
async function handleMemoryArtifact({
|
||||
res,
|
||||
data,
|
||||
metadata,
|
||||
}: {
|
||||
res: ServerResponse;
|
||||
data: ToolEndData;
|
||||
metadata?: ToolEndMetadata;
|
||||
}) {
|
||||
const output = data?.output;
|
||||
if (!output) {
|
||||
return null;
|
||||
}
|
||||
|
||||
if (!output.artifact) {
|
||||
return null;
|
||||
}
|
||||
|
||||
const memoryArtifact = output.artifact[Tools.memory] as MemoryArtifact | undefined;
|
||||
if (!memoryArtifact) {
|
||||
return null;
|
||||
}
|
||||
|
||||
const attachment: Partial<TAttachment> = {
|
||||
type: Tools.memory,
|
||||
toolCallId: output.tool_call_id,
|
||||
messageId: metadata?.run_id ?? '',
|
||||
conversationId: metadata?.thread_id ?? '',
|
||||
[Tools.memory]: memoryArtifact,
|
||||
};
|
||||
if (!res.headersSent) {
|
||||
return attachment;
|
||||
}
|
||||
res.write(`event: attachment\ndata: ${JSON.stringify(attachment)}\n\n`);
|
||||
return attachment;
|
||||
}
|
||||
|
||||
/**
|
||||
* Creates a memory callback for handling memory artifacts
|
||||
* @param params - The parameters object
|
||||
* @param params.res - The server response object
|
||||
* @param params.artifactPromises - Array to collect artifact promises
|
||||
* @returns The memory callback function
|
||||
*/
|
||||
export function createMemoryCallback({
|
||||
res,
|
||||
artifactPromises,
|
||||
}: {
|
||||
res: ServerResponse;
|
||||
artifactPromises: Promise<Partial<TAttachment> | null>[];
|
||||
}): ToolEndCallback {
|
||||
return async (data: ToolEndData, metadata?: Record<string, unknown>) => {
|
||||
const output = data?.output;
|
||||
const memoryArtifact = output?.artifact?.[Tools.memory] as MemoryArtifact;
|
||||
if (memoryArtifact == null) {
|
||||
return;
|
||||
}
|
||||
artifactPromises.push(
|
||||
handleMemoryArtifact({ res, data, metadata }).catch((error) => {
|
||||
logger.error('Error processing memory artifact content:', error);
|
||||
return null;
|
||||
}),
|
||||
);
|
||||
};
|
||||
}
|
||||
990
packages/api/src/agents/resources.test.ts
Normal file
990
packages/api/src/agents/resources.test.ts
Normal file
|
|
@ -0,0 +1,990 @@
|
|||
import { primeResources } from './resources';
|
||||
import { logger } from '@librechat/data-schemas';
|
||||
import { EModelEndpoint, EToolResources, AgentCapabilities } from 'librechat-data-provider';
|
||||
import type { Request as ServerRequest } from 'express';
|
||||
import type { TFile } from 'librechat-data-provider';
|
||||
import type { TGetFiles } from './resources';
|
||||
|
||||
// Mock logger
|
||||
jest.mock('@librechat/data-schemas', () => ({
|
||||
logger: {
|
||||
error: jest.fn(),
|
||||
},
|
||||
}));
|
||||
|
||||
describe('primeResources', () => {
|
||||
let mockReq: ServerRequest;
|
||||
let mockGetFiles: jest.MockedFunction<TGetFiles>;
|
||||
let requestFileSet: Set<string>;
|
||||
|
||||
beforeEach(() => {
|
||||
// Reset mocks
|
||||
jest.clearAllMocks();
|
||||
|
||||
// Setup mock request
|
||||
mockReq = {
|
||||
app: {
|
||||
locals: {
|
||||
[EModelEndpoint.agents]: {
|
||||
capabilities: [AgentCapabilities.ocr],
|
||||
},
|
||||
},
|
||||
},
|
||||
} as unknown as ServerRequest;
|
||||
|
||||
// Setup mock getFiles function
|
||||
mockGetFiles = jest.fn();
|
||||
|
||||
// Setup request file set
|
||||
requestFileSet = new Set(['file1', 'file2', 'file3']);
|
||||
});
|
||||
|
||||
describe('when OCR is enabled and tool_resources has OCR file_ids', () => {
|
||||
it('should fetch OCR files and include them in attachments', async () => {
|
||||
const mockOcrFiles: TFile[] = [
|
||||
{
|
||||
user: 'user1',
|
||||
file_id: 'ocr-file-1',
|
||||
filename: 'document.pdf',
|
||||
filepath: '/uploads/document.pdf',
|
||||
object: 'file',
|
||||
type: 'application/pdf',
|
||||
bytes: 1024,
|
||||
embedded: false,
|
||||
usage: 0,
|
||||
},
|
||||
];
|
||||
|
||||
mockGetFiles.mockResolvedValue(mockOcrFiles);
|
||||
|
||||
const tool_resources = {
|
||||
[EToolResources.ocr]: {
|
||||
file_ids: ['ocr-file-1'],
|
||||
},
|
||||
};
|
||||
|
||||
const result = await primeResources({
|
||||
req: mockReq,
|
||||
getFiles: mockGetFiles,
|
||||
requestFileSet,
|
||||
attachments: undefined,
|
||||
tool_resources,
|
||||
});
|
||||
|
||||
expect(mockGetFiles).toHaveBeenCalledWith({ file_id: { $in: ['ocr-file-1'] } }, {}, {});
|
||||
expect(result.attachments).toEqual(mockOcrFiles);
|
||||
expect(result.tool_resources).toEqual(tool_resources);
|
||||
});
|
||||
});
|
||||
|
||||
describe('when OCR is disabled', () => {
|
||||
it('should not fetch OCR files even if tool_resources has OCR file_ids', async () => {
|
||||
(mockReq.app as ServerRequest['app']).locals[EModelEndpoint.agents].capabilities = [];
|
||||
|
||||
const tool_resources = {
|
||||
[EToolResources.ocr]: {
|
||||
file_ids: ['ocr-file-1'],
|
||||
},
|
||||
};
|
||||
|
||||
const result = await primeResources({
|
||||
req: mockReq,
|
||||
getFiles: mockGetFiles,
|
||||
requestFileSet,
|
||||
attachments: undefined,
|
||||
tool_resources,
|
||||
});
|
||||
|
||||
expect(mockGetFiles).not.toHaveBeenCalled();
|
||||
expect(result.attachments).toBeUndefined();
|
||||
expect(result.tool_resources).toEqual(tool_resources);
|
||||
});
|
||||
});
|
||||
|
||||
describe('when attachments are provided', () => {
|
||||
it('should process files with fileIdentifier as execute_code resources', async () => {
|
||||
const mockFiles: TFile[] = [
|
||||
{
|
||||
user: 'user1',
|
||||
file_id: 'file1',
|
||||
filename: 'script.py',
|
||||
filepath: '/uploads/script.py',
|
||||
object: 'file',
|
||||
type: 'text/x-python',
|
||||
bytes: 512,
|
||||
embedded: false,
|
||||
usage: 0,
|
||||
metadata: {
|
||||
fileIdentifier: 'python-script',
|
||||
},
|
||||
},
|
||||
];
|
||||
|
||||
const attachments = Promise.resolve(mockFiles);
|
||||
|
||||
const result = await primeResources({
|
||||
req: mockReq,
|
||||
getFiles: mockGetFiles,
|
||||
requestFileSet,
|
||||
attachments,
|
||||
tool_resources: {},
|
||||
});
|
||||
|
||||
expect(result.attachments).toEqual(mockFiles);
|
||||
expect(result.tool_resources?.[EToolResources.execute_code]?.files).toEqual(mockFiles);
|
||||
});
|
||||
|
||||
it('should process embedded files as file_search resources', async () => {
|
||||
const mockFiles: TFile[] = [
|
||||
{
|
||||
user: 'user1',
|
||||
file_id: 'file2',
|
||||
filename: 'document.txt',
|
||||
filepath: '/uploads/document.txt',
|
||||
object: 'file',
|
||||
type: 'text/plain',
|
||||
bytes: 256,
|
||||
embedded: true,
|
||||
usage: 0,
|
||||
},
|
||||
];
|
||||
|
||||
const attachments = Promise.resolve(mockFiles);
|
||||
|
||||
const result = await primeResources({
|
||||
req: mockReq,
|
||||
getFiles: mockGetFiles,
|
||||
requestFileSet,
|
||||
attachments,
|
||||
tool_resources: {},
|
||||
});
|
||||
|
||||
expect(result.attachments).toEqual(mockFiles);
|
||||
expect(result.tool_resources?.[EToolResources.file_search]?.files).toEqual(mockFiles);
|
||||
});
|
||||
|
||||
it('should process image files in requestFileSet as image_edit resources', async () => {
|
||||
const mockFiles: TFile[] = [
|
||||
{
|
||||
user: 'user1',
|
||||
file_id: 'file1',
|
||||
filename: 'image.png',
|
||||
filepath: '/uploads/image.png',
|
||||
object: 'file',
|
||||
type: 'image/png',
|
||||
bytes: 2048,
|
||||
embedded: false,
|
||||
usage: 0,
|
||||
height: 800,
|
||||
width: 600,
|
||||
},
|
||||
];
|
||||
|
||||
const attachments = Promise.resolve(mockFiles);
|
||||
|
||||
const result = await primeResources({
|
||||
req: mockReq,
|
||||
getFiles: mockGetFiles,
|
||||
requestFileSet,
|
||||
attachments,
|
||||
tool_resources: {},
|
||||
});
|
||||
|
||||
expect(result.attachments).toEqual(mockFiles);
|
||||
expect(result.tool_resources?.[EToolResources.image_edit]?.files).toEqual(mockFiles);
|
||||
});
|
||||
|
||||
it('should not process image files not in requestFileSet', async () => {
|
||||
const mockFiles: TFile[] = [
|
||||
{
|
||||
user: 'user1',
|
||||
file_id: 'file-not-in-set',
|
||||
filename: 'image.png',
|
||||
filepath: '/uploads/image.png',
|
||||
object: 'file',
|
||||
type: 'image/png',
|
||||
bytes: 2048,
|
||||
embedded: false,
|
||||
usage: 0,
|
||||
height: 800,
|
||||
width: 600,
|
||||
},
|
||||
];
|
||||
|
||||
const attachments = Promise.resolve(mockFiles);
|
||||
|
||||
const result = await primeResources({
|
||||
req: mockReq,
|
||||
getFiles: mockGetFiles,
|
||||
requestFileSet,
|
||||
attachments,
|
||||
tool_resources: {},
|
||||
});
|
||||
|
||||
expect(result.attachments).toEqual(mockFiles);
|
||||
expect(result.tool_resources?.[EToolResources.image_edit]).toBeUndefined();
|
||||
});
|
||||
|
||||
it('should not process image files without height and width', async () => {
|
||||
const mockFiles: TFile[] = [
|
||||
{
|
||||
user: 'user1',
|
||||
file_id: 'file1',
|
||||
filename: 'image.png',
|
||||
filepath: '/uploads/image.png',
|
||||
object: 'file',
|
||||
type: 'image/png',
|
||||
bytes: 2048,
|
||||
embedded: false,
|
||||
usage: 0,
|
||||
// Missing height and width
|
||||
},
|
||||
];
|
||||
|
||||
const attachments = Promise.resolve(mockFiles);
|
||||
|
||||
const result = await primeResources({
|
||||
req: mockReq,
|
||||
getFiles: mockGetFiles,
|
||||
requestFileSet,
|
||||
attachments,
|
||||
tool_resources: {},
|
||||
});
|
||||
|
||||
expect(result.attachments).toEqual(mockFiles);
|
||||
expect(result.tool_resources?.[EToolResources.image_edit]).toBeUndefined();
|
||||
});
|
||||
|
||||
it('should filter out null files from attachments', async () => {
|
||||
const mockFiles: Array<TFile | null> = [
|
||||
{
|
||||
user: 'user1',
|
||||
file_id: 'file1',
|
||||
filename: 'valid.txt',
|
||||
filepath: '/uploads/valid.txt',
|
||||
object: 'file',
|
||||
type: 'text/plain',
|
||||
bytes: 256,
|
||||
embedded: false,
|
||||
usage: 0,
|
||||
},
|
||||
null,
|
||||
{
|
||||
user: 'user1',
|
||||
file_id: 'file2',
|
||||
filename: 'valid2.txt',
|
||||
filepath: '/uploads/valid2.txt',
|
||||
object: 'file',
|
||||
type: 'text/plain',
|
||||
bytes: 128,
|
||||
embedded: false,
|
||||
usage: 0,
|
||||
},
|
||||
];
|
||||
|
||||
const attachments = Promise.resolve(mockFiles);
|
||||
|
||||
const result = await primeResources({
|
||||
req: mockReq,
|
||||
getFiles: mockGetFiles,
|
||||
requestFileSet,
|
||||
attachments,
|
||||
tool_resources: {},
|
||||
});
|
||||
|
||||
expect(result.attachments).toHaveLength(2);
|
||||
expect(result.attachments?.[0]?.file_id).toBe('file1');
|
||||
expect(result.attachments?.[1]?.file_id).toBe('file2');
|
||||
});
|
||||
|
||||
it('should merge existing tool_resources with new files', async () => {
|
||||
const mockFiles: TFile[] = [
|
||||
{
|
||||
user: 'user1',
|
||||
file_id: 'file1',
|
||||
filename: 'script.py',
|
||||
filepath: '/uploads/script.py',
|
||||
object: 'file',
|
||||
type: 'text/x-python',
|
||||
bytes: 512,
|
||||
embedded: false,
|
||||
usage: 0,
|
||||
metadata: {
|
||||
fileIdentifier: 'python-script',
|
||||
},
|
||||
},
|
||||
];
|
||||
|
||||
const existingToolResources = {
|
||||
[EToolResources.execute_code]: {
|
||||
files: [
|
||||
{
|
||||
user: 'user1',
|
||||
file_id: 'existing-file',
|
||||
filename: 'existing.py',
|
||||
filepath: '/uploads/existing.py',
|
||||
object: 'file' as const,
|
||||
type: 'text/x-python',
|
||||
bytes: 256,
|
||||
embedded: false,
|
||||
usage: 0,
|
||||
},
|
||||
],
|
||||
},
|
||||
};
|
||||
|
||||
const attachments = Promise.resolve(mockFiles);
|
||||
|
||||
const result = await primeResources({
|
||||
req: mockReq,
|
||||
getFiles: mockGetFiles,
|
||||
requestFileSet,
|
||||
attachments,
|
||||
tool_resources: existingToolResources,
|
||||
});
|
||||
|
||||
expect(result.tool_resources?.[EToolResources.execute_code]?.files).toHaveLength(2);
|
||||
expect(result.tool_resources?.[EToolResources.execute_code]?.files?.[0]?.file_id).toBe(
|
||||
'existing-file',
|
||||
);
|
||||
expect(result.tool_resources?.[EToolResources.execute_code]?.files?.[1]?.file_id).toBe(
|
||||
'file1',
|
||||
);
|
||||
});
|
||||
});
|
||||
|
||||
describe('when both OCR and attachments are provided', () => {
|
||||
it('should include both OCR files and attachment files', async () => {
|
||||
const mockOcrFiles: TFile[] = [
|
||||
{
|
||||
user: 'user1',
|
||||
file_id: 'ocr-file-1',
|
||||
filename: 'document.pdf',
|
||||
filepath: '/uploads/document.pdf',
|
||||
object: 'file',
|
||||
type: 'application/pdf',
|
||||
bytes: 1024,
|
||||
embedded: false,
|
||||
usage: 0,
|
||||
},
|
||||
];
|
||||
|
||||
const mockAttachmentFiles: TFile[] = [
|
||||
{
|
||||
user: 'user1',
|
||||
file_id: 'file1',
|
||||
filename: 'attachment.txt',
|
||||
filepath: '/uploads/attachment.txt',
|
||||
object: 'file',
|
||||
type: 'text/plain',
|
||||
bytes: 256,
|
||||
embedded: false,
|
||||
usage: 0,
|
||||
},
|
||||
];
|
||||
|
||||
mockGetFiles.mockResolvedValue(mockOcrFiles);
|
||||
const attachments = Promise.resolve(mockAttachmentFiles);
|
||||
|
||||
const tool_resources = {
|
||||
[EToolResources.ocr]: {
|
||||
file_ids: ['ocr-file-1'],
|
||||
},
|
||||
};
|
||||
|
||||
const result = await primeResources({
|
||||
req: mockReq,
|
||||
getFiles: mockGetFiles,
|
||||
requestFileSet,
|
||||
attachments,
|
||||
tool_resources,
|
||||
});
|
||||
|
||||
expect(result.attachments).toHaveLength(2);
|
||||
expect(result.attachments?.[0]?.file_id).toBe('ocr-file-1');
|
||||
expect(result.attachments?.[1]?.file_id).toBe('file1');
|
||||
});
|
||||
|
||||
it('should prevent duplicate files when same file exists in OCR and attachments', async () => {
|
||||
const sharedFile: TFile = {
|
||||
user: 'user1',
|
||||
file_id: 'shared-file-id',
|
||||
filename: 'document.pdf',
|
||||
filepath: '/uploads/document.pdf',
|
||||
object: 'file',
|
||||
type: 'application/pdf',
|
||||
bytes: 1024,
|
||||
embedded: false,
|
||||
usage: 0,
|
||||
};
|
||||
|
||||
const mockOcrFiles: TFile[] = [sharedFile];
|
||||
const mockAttachmentFiles: TFile[] = [
|
||||
sharedFile,
|
||||
{
|
||||
user: 'user1',
|
||||
file_id: 'unique-file',
|
||||
filename: 'other.txt',
|
||||
filepath: '/uploads/other.txt',
|
||||
object: 'file',
|
||||
type: 'text/plain',
|
||||
bytes: 256,
|
||||
embedded: false,
|
||||
usage: 0,
|
||||
},
|
||||
];
|
||||
|
||||
mockGetFiles.mockResolvedValue(mockOcrFiles);
|
||||
const attachments = Promise.resolve(mockAttachmentFiles);
|
||||
|
||||
const tool_resources = {
|
||||
[EToolResources.ocr]: {
|
||||
file_ids: ['shared-file-id'],
|
||||
},
|
||||
};
|
||||
|
||||
const result = await primeResources({
|
||||
req: mockReq,
|
||||
getFiles: mockGetFiles,
|
||||
requestFileSet,
|
||||
attachments,
|
||||
tool_resources,
|
||||
});
|
||||
|
||||
// Should only have 2 files, not 3 (no duplicate)
|
||||
expect(result.attachments).toHaveLength(2);
|
||||
expect(result.attachments?.filter((f) => f?.file_id === 'shared-file-id')).toHaveLength(1);
|
||||
expect(result.attachments?.find((f) => f?.file_id === 'unique-file')).toBeDefined();
|
||||
});
|
||||
|
||||
it('should still categorize duplicate files for tool_resources', async () => {
|
||||
const sharedFile: TFile = {
|
||||
user: 'user1',
|
||||
file_id: 'shared-file-id',
|
||||
filename: 'script.py',
|
||||
filepath: '/uploads/script.py',
|
||||
object: 'file',
|
||||
type: 'text/x-python',
|
||||
bytes: 512,
|
||||
embedded: false,
|
||||
usage: 0,
|
||||
metadata: {
|
||||
fileIdentifier: 'python-script',
|
||||
},
|
||||
};
|
||||
|
||||
const mockOcrFiles: TFile[] = [sharedFile];
|
||||
const mockAttachmentFiles: TFile[] = [sharedFile];
|
||||
|
||||
mockGetFiles.mockResolvedValue(mockOcrFiles);
|
||||
const attachments = Promise.resolve(mockAttachmentFiles);
|
||||
|
||||
const tool_resources = {
|
||||
[EToolResources.ocr]: {
|
||||
file_ids: ['shared-file-id'],
|
||||
},
|
||||
};
|
||||
|
||||
const result = await primeResources({
|
||||
req: mockReq,
|
||||
getFiles: mockGetFiles,
|
||||
requestFileSet,
|
||||
attachments,
|
||||
tool_resources,
|
||||
});
|
||||
|
||||
// File should appear only once in attachments
|
||||
expect(result.attachments).toHaveLength(1);
|
||||
expect(result.attachments?.[0]?.file_id).toBe('shared-file-id');
|
||||
|
||||
// But should still be categorized in tool_resources
|
||||
expect(result.tool_resources?.[EToolResources.execute_code]?.files).toHaveLength(1);
|
||||
expect(result.tool_resources?.[EToolResources.execute_code]?.files?.[0]?.file_id).toBe(
|
||||
'shared-file-id',
|
||||
);
|
||||
});
|
||||
|
||||
it('should handle multiple duplicate files', async () => {
|
||||
const file1: TFile = {
|
||||
user: 'user1',
|
||||
file_id: 'file-1',
|
||||
filename: 'doc1.pdf',
|
||||
filepath: '/uploads/doc1.pdf',
|
||||
object: 'file',
|
||||
type: 'application/pdf',
|
||||
bytes: 1024,
|
||||
embedded: false,
|
||||
usage: 0,
|
||||
};
|
||||
|
||||
const file2: TFile = {
|
||||
user: 'user1',
|
||||
file_id: 'file-2',
|
||||
filename: 'doc2.pdf',
|
||||
filepath: '/uploads/doc2.pdf',
|
||||
object: 'file',
|
||||
type: 'application/pdf',
|
||||
bytes: 2048,
|
||||
embedded: false,
|
||||
usage: 0,
|
||||
};
|
||||
|
||||
const uniqueFile: TFile = {
|
||||
user: 'user1',
|
||||
file_id: 'unique-file',
|
||||
filename: 'unique.txt',
|
||||
filepath: '/uploads/unique.txt',
|
||||
object: 'file',
|
||||
type: 'text/plain',
|
||||
bytes: 256,
|
||||
embedded: false,
|
||||
usage: 0,
|
||||
};
|
||||
|
||||
const mockOcrFiles: TFile[] = [file1, file2];
|
||||
const mockAttachmentFiles: TFile[] = [file1, file2, uniqueFile];
|
||||
|
||||
mockGetFiles.mockResolvedValue(mockOcrFiles);
|
||||
const attachments = Promise.resolve(mockAttachmentFiles);
|
||||
|
||||
const tool_resources = {
|
||||
[EToolResources.ocr]: {
|
||||
file_ids: ['file-1', 'file-2'],
|
||||
},
|
||||
};
|
||||
|
||||
const result = await primeResources({
|
||||
req: mockReq,
|
||||
getFiles: mockGetFiles,
|
||||
requestFileSet,
|
||||
attachments,
|
||||
tool_resources,
|
||||
});
|
||||
|
||||
// Should have 3 files total (2 from OCR + 1 unique from attachments)
|
||||
expect(result.attachments).toHaveLength(3);
|
||||
|
||||
// Each file should appear only once
|
||||
const fileIds = result.attachments?.map((f) => f?.file_id);
|
||||
expect(fileIds).toContain('file-1');
|
||||
expect(fileIds).toContain('file-2');
|
||||
expect(fileIds).toContain('unique-file');
|
||||
|
||||
// Check no duplicates
|
||||
const uniqueFileIds = new Set(fileIds);
|
||||
expect(uniqueFileIds.size).toBe(fileIds?.length);
|
||||
});
|
||||
|
||||
it('should handle files without file_id gracefully', async () => {
|
||||
const fileWithoutId: Partial<TFile> = {
|
||||
user: 'user1',
|
||||
filename: 'no-id.txt',
|
||||
filepath: '/uploads/no-id.txt',
|
||||
object: 'file',
|
||||
type: 'text/plain',
|
||||
bytes: 256,
|
||||
embedded: false,
|
||||
usage: 0,
|
||||
};
|
||||
|
||||
const normalFile: TFile = {
|
||||
user: 'user1',
|
||||
file_id: 'normal-file',
|
||||
filename: 'normal.txt',
|
||||
filepath: '/uploads/normal.txt',
|
||||
object: 'file',
|
||||
type: 'text/plain',
|
||||
bytes: 512,
|
||||
embedded: false,
|
||||
usage: 0,
|
||||
};
|
||||
|
||||
const mockOcrFiles: TFile[] = [normalFile];
|
||||
const mockAttachmentFiles = [fileWithoutId as TFile, normalFile];
|
||||
|
||||
mockGetFiles.mockResolvedValue(mockOcrFiles);
|
||||
const attachments = Promise.resolve(mockAttachmentFiles);
|
||||
|
||||
const tool_resources = {
|
||||
[EToolResources.ocr]: {
|
||||
file_ids: ['normal-file'],
|
||||
},
|
||||
};
|
||||
|
||||
const result = await primeResources({
|
||||
req: mockReq,
|
||||
getFiles: mockGetFiles,
|
||||
requestFileSet,
|
||||
attachments,
|
||||
tool_resources,
|
||||
});
|
||||
|
||||
// Should include file without ID and one instance of normal file
|
||||
expect(result.attachments).toHaveLength(2);
|
||||
expect(result.attachments?.filter((f) => f?.file_id === 'normal-file')).toHaveLength(1);
|
||||
expect(result.attachments?.some((f) => !f?.file_id)).toBe(true);
|
||||
});
|
||||
|
||||
it('should prevent duplicates from existing tool_resources', async () => {
|
||||
const existingFile: TFile = {
|
||||
user: 'user1',
|
||||
file_id: 'existing-file',
|
||||
filename: 'existing.py',
|
||||
filepath: '/uploads/existing.py',
|
||||
object: 'file',
|
||||
type: 'text/x-python',
|
||||
bytes: 512,
|
||||
embedded: false,
|
||||
usage: 0,
|
||||
metadata: {
|
||||
fileIdentifier: 'python-script',
|
||||
},
|
||||
};
|
||||
|
||||
const newFile: TFile = {
|
||||
user: 'user1',
|
||||
file_id: 'new-file',
|
||||
filename: 'new.py',
|
||||
filepath: '/uploads/new.py',
|
||||
object: 'file',
|
||||
type: 'text/x-python',
|
||||
bytes: 256,
|
||||
embedded: false,
|
||||
usage: 0,
|
||||
metadata: {
|
||||
fileIdentifier: 'python-script',
|
||||
},
|
||||
};
|
||||
|
||||
const existingToolResources = {
|
||||
[EToolResources.execute_code]: {
|
||||
files: [existingFile],
|
||||
},
|
||||
};
|
||||
|
||||
const attachments = Promise.resolve([existingFile, newFile]);
|
||||
|
||||
const result = await primeResources({
|
||||
req: mockReq,
|
||||
getFiles: mockGetFiles,
|
||||
requestFileSet,
|
||||
attachments,
|
||||
tool_resources: existingToolResources,
|
||||
});
|
||||
|
||||
// Should only add the new file to attachments
|
||||
expect(result.attachments).toHaveLength(1);
|
||||
expect(result.attachments?.[0]?.file_id).toBe('new-file');
|
||||
|
||||
// Should not duplicate the existing file in tool_resources
|
||||
expect(result.tool_resources?.[EToolResources.execute_code]?.files).toHaveLength(2);
|
||||
const fileIds = result.tool_resources?.[EToolResources.execute_code]?.files?.map(
|
||||
(f) => f.file_id,
|
||||
);
|
||||
expect(fileIds).toEqual(['existing-file', 'new-file']);
|
||||
});
|
||||
|
||||
it('should handle duplicates within attachments array', async () => {
|
||||
const duplicatedFile: TFile = {
|
||||
user: 'user1',
|
||||
file_id: 'dup-file',
|
||||
filename: 'duplicate.txt',
|
||||
filepath: '/uploads/duplicate.txt',
|
||||
object: 'file',
|
||||
type: 'text/plain',
|
||||
bytes: 256,
|
||||
embedded: false,
|
||||
usage: 0,
|
||||
};
|
||||
|
||||
const uniqueFile: TFile = {
|
||||
user: 'user1',
|
||||
file_id: 'unique-file',
|
||||
filename: 'unique.txt',
|
||||
filepath: '/uploads/unique.txt',
|
||||
object: 'file',
|
||||
type: 'text/plain',
|
||||
bytes: 128,
|
||||
embedded: false,
|
||||
usage: 0,
|
||||
};
|
||||
|
||||
// Same file appears multiple times in attachments
|
||||
const attachments = Promise.resolve([
|
||||
duplicatedFile,
|
||||
duplicatedFile,
|
||||
uniqueFile,
|
||||
duplicatedFile,
|
||||
]);
|
||||
|
||||
const result = await primeResources({
|
||||
req: mockReq,
|
||||
getFiles: mockGetFiles,
|
||||
requestFileSet,
|
||||
attachments,
|
||||
tool_resources: {},
|
||||
});
|
||||
|
||||
// Should only have 2 unique files
|
||||
expect(result.attachments).toHaveLength(2);
|
||||
const fileIds = result.attachments?.map((f) => f?.file_id);
|
||||
expect(fileIds).toContain('dup-file');
|
||||
expect(fileIds).toContain('unique-file');
|
||||
|
||||
// Verify no duplicates
|
||||
expect(fileIds?.filter((id) => id === 'dup-file')).toHaveLength(1);
|
||||
});
|
||||
|
||||
it('should prevent duplicates across different tool_resource categories', async () => {
|
||||
const multiPurposeFile: TFile = {
|
||||
user: 'user1',
|
||||
file_id: 'multi-file',
|
||||
filename: 'data.txt',
|
||||
filepath: '/uploads/data.txt',
|
||||
object: 'file',
|
||||
type: 'text/plain',
|
||||
bytes: 512,
|
||||
embedded: true, // Will be categorized as file_search
|
||||
usage: 0,
|
||||
};
|
||||
|
||||
const existingToolResources = {
|
||||
[EToolResources.file_search]: {
|
||||
files: [multiPurposeFile],
|
||||
},
|
||||
};
|
||||
|
||||
// Try to add the same file again
|
||||
const attachments = Promise.resolve([multiPurposeFile]);
|
||||
|
||||
const result = await primeResources({
|
||||
req: mockReq,
|
||||
getFiles: mockGetFiles,
|
||||
requestFileSet,
|
||||
attachments,
|
||||
tool_resources: existingToolResources,
|
||||
});
|
||||
|
||||
// Should not add to attachments (already exists)
|
||||
expect(result.attachments).toHaveLength(0);
|
||||
|
||||
// Should not duplicate in file_search
|
||||
expect(result.tool_resources?.[EToolResources.file_search]?.files).toHaveLength(1);
|
||||
expect(result.tool_resources?.[EToolResources.file_search]?.files?.[0]?.file_id).toBe(
|
||||
'multi-file',
|
||||
);
|
||||
});
|
||||
|
||||
it('should handle complex scenario with OCR, existing tool_resources, and attachments', async () => {
|
||||
const ocrFile: TFile = {
|
||||
user: 'user1',
|
||||
file_id: 'ocr-file',
|
||||
filename: 'scan.pdf',
|
||||
filepath: '/uploads/scan.pdf',
|
||||
object: 'file',
|
||||
type: 'application/pdf',
|
||||
bytes: 2048,
|
||||
embedded: false,
|
||||
usage: 0,
|
||||
};
|
||||
|
||||
const existingFile: TFile = {
|
||||
user: 'user1',
|
||||
file_id: 'existing-file',
|
||||
filename: 'code.py',
|
||||
filepath: '/uploads/code.py',
|
||||
object: 'file',
|
||||
type: 'text/x-python',
|
||||
bytes: 512,
|
||||
embedded: false,
|
||||
usage: 0,
|
||||
metadata: {
|
||||
fileIdentifier: 'python-script',
|
||||
},
|
||||
};
|
||||
|
||||
const newFile: TFile = {
|
||||
user: 'user1',
|
||||
file_id: 'new-file',
|
||||
filename: 'image.png',
|
||||
filepath: '/uploads/image.png',
|
||||
object: 'file',
|
||||
type: 'image/png',
|
||||
bytes: 4096,
|
||||
embedded: false,
|
||||
usage: 0,
|
||||
height: 800,
|
||||
width: 600,
|
||||
};
|
||||
|
||||
mockGetFiles.mockResolvedValue([ocrFile, existingFile]); // OCR returns both files
|
||||
const attachments = Promise.resolve([existingFile, ocrFile, newFile]); // Attachments has duplicates
|
||||
|
||||
const existingToolResources = {
|
||||
[EToolResources.ocr]: {
|
||||
file_ids: ['ocr-file', 'existing-file'],
|
||||
},
|
||||
[EToolResources.execute_code]: {
|
||||
files: [existingFile],
|
||||
},
|
||||
};
|
||||
|
||||
requestFileSet.add('new-file'); // Only new-file is in request set
|
||||
|
||||
const result = await primeResources({
|
||||
req: mockReq,
|
||||
getFiles: mockGetFiles,
|
||||
requestFileSet,
|
||||
attachments,
|
||||
tool_resources: existingToolResources,
|
||||
});
|
||||
|
||||
// Should have 3 unique files total
|
||||
expect(result.attachments).toHaveLength(3);
|
||||
const attachmentIds = result.attachments?.map((f) => f?.file_id).sort();
|
||||
expect(attachmentIds).toEqual(['existing-file', 'new-file', 'ocr-file']);
|
||||
|
||||
// Check tool_resources
|
||||
expect(result.tool_resources?.[EToolResources.execute_code]?.files).toHaveLength(1);
|
||||
expect(result.tool_resources?.[EToolResources.image_edit]?.files).toHaveLength(1);
|
||||
expect(result.tool_resources?.[EToolResources.image_edit]?.files?.[0]?.file_id).toBe(
|
||||
'new-file',
|
||||
);
|
||||
});
|
||||
});
|
||||
|
||||
describe('error handling', () => {
|
||||
it('should handle errors gracefully and log them', async () => {
|
||||
const mockFiles: TFile[] = [
|
||||
{
|
||||
user: 'user1',
|
||||
file_id: 'file1',
|
||||
filename: 'test.txt',
|
||||
filepath: '/uploads/test.txt',
|
||||
object: 'file',
|
||||
type: 'text/plain',
|
||||
bytes: 256,
|
||||
embedded: false,
|
||||
usage: 0,
|
||||
},
|
||||
];
|
||||
|
||||
const attachments = Promise.resolve(mockFiles);
|
||||
const error = new Error('Test error');
|
||||
|
||||
// Mock getFiles to throw an error when called for OCR
|
||||
mockGetFiles.mockRejectedValue(error);
|
||||
|
||||
const tool_resources = {
|
||||
[EToolResources.ocr]: {
|
||||
file_ids: ['ocr-file-1'],
|
||||
},
|
||||
};
|
||||
|
||||
const result = await primeResources({
|
||||
req: mockReq,
|
||||
getFiles: mockGetFiles,
|
||||
requestFileSet,
|
||||
attachments,
|
||||
tool_resources,
|
||||
});
|
||||
|
||||
expect(logger.error).toHaveBeenCalledWith('Error priming resources', error);
|
||||
expect(result.attachments).toEqual(mockFiles);
|
||||
expect(result.tool_resources).toEqual(tool_resources);
|
||||
});
|
||||
|
||||
it('should handle promise rejection in attachments', async () => {
|
||||
const error = new Error('Attachment error');
|
||||
const attachments = Promise.reject(error);
|
||||
|
||||
// The function should now handle rejected attachment promises gracefully
|
||||
const result = await primeResources({
|
||||
req: mockReq,
|
||||
getFiles: mockGetFiles,
|
||||
requestFileSet,
|
||||
attachments,
|
||||
tool_resources: {},
|
||||
});
|
||||
|
||||
// Should log both the main error and the attachment error
|
||||
expect(logger.error).toHaveBeenCalledWith('Error priming resources', error);
|
||||
expect(logger.error).toHaveBeenCalledWith(
|
||||
'Error resolving attachments in catch block',
|
||||
error,
|
||||
);
|
||||
|
||||
// Should return empty array when attachments promise is rejected
|
||||
expect(result.attachments).toEqual([]);
|
||||
expect(result.tool_resources).toEqual({});
|
||||
});
|
||||
});
|
||||
|
||||
describe('edge cases', () => {
|
||||
it('should handle missing app.locals gracefully', async () => {
|
||||
const reqWithoutLocals = {} as ServerRequest;
|
||||
|
||||
const result = await primeResources({
|
||||
req: reqWithoutLocals,
|
||||
getFiles: mockGetFiles,
|
||||
requestFileSet,
|
||||
attachments: undefined,
|
||||
tool_resources: {
|
||||
[EToolResources.ocr]: {
|
||||
file_ids: ['ocr-file-1'],
|
||||
},
|
||||
},
|
||||
});
|
||||
|
||||
expect(mockGetFiles).not.toHaveBeenCalled();
|
||||
// When app.locals is missing and there's an error accessing properties,
|
||||
// the function falls back to the catch block which returns an empty array
|
||||
expect(result.attachments).toEqual([]);
|
||||
});
|
||||
|
||||
it('should handle undefined tool_resources', async () => {
|
||||
const result = await primeResources({
|
||||
req: mockReq,
|
||||
getFiles: mockGetFiles,
|
||||
requestFileSet,
|
||||
attachments: undefined,
|
||||
tool_resources: undefined,
|
||||
});
|
||||
|
||||
expect(result.tool_resources).toEqual({});
|
||||
expect(result.attachments).toBeUndefined();
|
||||
});
|
||||
|
||||
it('should handle empty requestFileSet', async () => {
|
||||
const mockFiles: TFile[] = [
|
||||
{
|
||||
user: 'user1',
|
||||
file_id: 'file1',
|
||||
filename: 'image.png',
|
||||
filepath: '/uploads/image.png',
|
||||
object: 'file',
|
||||
type: 'image/png',
|
||||
bytes: 2048,
|
||||
embedded: false,
|
||||
usage: 0,
|
||||
height: 800,
|
||||
width: 600,
|
||||
},
|
||||
];
|
||||
|
||||
const attachments = Promise.resolve(mockFiles);
|
||||
const emptyRequestFileSet = new Set<string>();
|
||||
|
||||
const result = await primeResources({
|
||||
req: mockReq,
|
||||
getFiles: mockGetFiles,
|
||||
requestFileSet: emptyRequestFileSet,
|
||||
attachments,
|
||||
tool_resources: {},
|
||||
});
|
||||
|
||||
expect(result.attachments).toEqual(mockFiles);
|
||||
expect(result.tool_resources?.[EToolResources.image_edit]).toBeUndefined();
|
||||
});
|
||||
});
|
||||
});
|
||||
282
packages/api/src/agents/resources.ts
Normal file
282
packages/api/src/agents/resources.ts
Normal file
|
|
@ -0,0 +1,282 @@
|
|||
import { logger } from '@librechat/data-schemas';
|
||||
import { EModelEndpoint, EToolResources, AgentCapabilities } from 'librechat-data-provider';
|
||||
import type { AgentToolResources, TFile, AgentBaseResource } from 'librechat-data-provider';
|
||||
import type { FilterQuery, QueryOptions, ProjectionType } from 'mongoose';
|
||||
import type { IMongoFile } from '@librechat/data-schemas';
|
||||
import type { Request as ServerRequest } from 'express';
|
||||
|
||||
/**
|
||||
* Function type for retrieving files from the database
|
||||
* @param filter - MongoDB filter query for files
|
||||
* @param _sortOptions - Sorting options (currently unused)
|
||||
* @param selectFields - Field selection options
|
||||
* @returns Promise resolving to array of files
|
||||
*/
|
||||
export type TGetFiles = (
|
||||
filter: FilterQuery<IMongoFile>,
|
||||
_sortOptions: ProjectionType<IMongoFile> | null | undefined,
|
||||
selectFields: QueryOptions<IMongoFile> | null | undefined,
|
||||
) => Promise<Array<TFile>>;
|
||||
|
||||
/**
|
||||
* Helper function to add a file to a specific tool resource category
|
||||
* Prevents duplicate files within the same resource category
|
||||
* @param params - Parameters object
|
||||
* @param params.file - The file to add to the resource
|
||||
* @param params.resourceType - The type of tool resource (e.g., execute_code, file_search, image_edit)
|
||||
* @param params.tool_resources - The agent's tool resources object to update
|
||||
* @param params.processedResourceFiles - Set tracking processed files per resource type
|
||||
*/
|
||||
const addFileToResource = ({
|
||||
file,
|
||||
resourceType,
|
||||
tool_resources,
|
||||
processedResourceFiles,
|
||||
}: {
|
||||
file: TFile;
|
||||
resourceType: EToolResources;
|
||||
tool_resources: AgentToolResources;
|
||||
processedResourceFiles: Set<string>;
|
||||
}): void => {
|
||||
if (!file.file_id) {
|
||||
return;
|
||||
}
|
||||
|
||||
const resourceKey = `${resourceType}:${file.file_id}`;
|
||||
if (processedResourceFiles.has(resourceKey)) {
|
||||
return;
|
||||
}
|
||||
|
||||
const resource = tool_resources[resourceType as keyof AgentToolResources] ?? {};
|
||||
if (!resource.files) {
|
||||
(tool_resources[resourceType as keyof AgentToolResources] as AgentBaseResource) = {
|
||||
...resource,
|
||||
files: [],
|
||||
};
|
||||
}
|
||||
|
||||
// Check if already exists in the files array
|
||||
const resourceFiles = tool_resources[resourceType as keyof AgentToolResources]?.files;
|
||||
const alreadyExists = resourceFiles?.some((f: TFile) => f.file_id === file.file_id);
|
||||
|
||||
if (!alreadyExists) {
|
||||
resourceFiles?.push(file);
|
||||
processedResourceFiles.add(resourceKey);
|
||||
}
|
||||
};
|
||||
|
||||
/**
|
||||
* Categorizes a file into the appropriate tool resource based on its properties
|
||||
* Files are categorized as:
|
||||
* - execute_code: Files with fileIdentifier metadata
|
||||
* - file_search: Files marked as embedded
|
||||
* - image_edit: Image files in the request file set with dimensions
|
||||
* @param params - Parameters object
|
||||
* @param params.file - The file to categorize
|
||||
* @param params.tool_resources - The agent's tool resources to update
|
||||
* @param params.requestFileSet - Set of file IDs from the current request
|
||||
* @param params.processedResourceFiles - Set tracking processed files per resource type
|
||||
*/
|
||||
const categorizeFileForToolResources = ({
|
||||
file,
|
||||
tool_resources,
|
||||
requestFileSet,
|
||||
processedResourceFiles,
|
||||
}: {
|
||||
file: TFile;
|
||||
tool_resources: AgentToolResources;
|
||||
requestFileSet: Set<string>;
|
||||
processedResourceFiles: Set<string>;
|
||||
}): void => {
|
||||
if (file.metadata?.fileIdentifier) {
|
||||
addFileToResource({
|
||||
file,
|
||||
resourceType: EToolResources.execute_code,
|
||||
tool_resources,
|
||||
processedResourceFiles,
|
||||
});
|
||||
return;
|
||||
}
|
||||
|
||||
if (file.embedded === true) {
|
||||
addFileToResource({
|
||||
file,
|
||||
resourceType: EToolResources.file_search,
|
||||
tool_resources,
|
||||
processedResourceFiles,
|
||||
});
|
||||
return;
|
||||
}
|
||||
|
||||
if (
|
||||
requestFileSet.has(file.file_id) &&
|
||||
file.type.startsWith('image') &&
|
||||
file.height &&
|
||||
file.width
|
||||
) {
|
||||
addFileToResource({
|
||||
file,
|
||||
resourceType: EToolResources.image_edit,
|
||||
tool_resources,
|
||||
processedResourceFiles,
|
||||
});
|
||||
}
|
||||
};
|
||||
|
||||
/**
|
||||
* Primes resources for agent execution by processing attachments and tool resources
|
||||
* This function:
|
||||
* 1. Fetches OCR files if OCR is enabled
|
||||
* 2. Processes attachment files
|
||||
* 3. Categorizes files into appropriate tool resources
|
||||
* 4. Prevents duplicate files across all sources
|
||||
*
|
||||
* @param params - Parameters object
|
||||
* @param params.req - Express request object containing app configuration
|
||||
* @param params.getFiles - Function to retrieve files from database
|
||||
* @param params.requestFileSet - Set of file IDs from the current request
|
||||
* @param params.attachments - Promise resolving to array of attachment files
|
||||
* @param params.tool_resources - Existing tool resources for the agent
|
||||
* @returns Promise resolving to processed attachments and updated tool resources
|
||||
*/
|
||||
export const primeResources = async ({
|
||||
req,
|
||||
getFiles,
|
||||
requestFileSet,
|
||||
attachments: _attachments,
|
||||
tool_resources: _tool_resources,
|
||||
}: {
|
||||
req: ServerRequest;
|
||||
requestFileSet: Set<string>;
|
||||
attachments: Promise<Array<TFile | null>> | undefined;
|
||||
tool_resources: AgentToolResources | undefined;
|
||||
getFiles: TGetFiles;
|
||||
}): Promise<{
|
||||
attachments: Array<TFile | undefined> | undefined;
|
||||
tool_resources: AgentToolResources | undefined;
|
||||
}> => {
|
||||
try {
|
||||
/**
|
||||
* Array to collect all unique files that will be returned as attachments
|
||||
* Files are added from OCR results and attachment promises, with duplicates prevented
|
||||
*/
|
||||
const attachments: Array<TFile> = [];
|
||||
/**
|
||||
* Set of file IDs already added to the attachments array
|
||||
* Used to prevent duplicate files from being added multiple times
|
||||
* Pre-populated with files from non-OCR tool_resources to prevent re-adding them
|
||||
*/
|
||||
const attachmentFileIds = new Set<string>();
|
||||
/**
|
||||
* Set tracking which files have been added to specific tool resource categories
|
||||
* Format: "resourceType:fileId" (e.g., "execute_code:file123")
|
||||
* Prevents the same file from being added multiple times to the same resource
|
||||
*/
|
||||
const processedResourceFiles = new Set<string>();
|
||||
/**
|
||||
* The agent's tool resources object that will be updated with categorized files
|
||||
* Initialized from input parameter or empty object if not provided
|
||||
*/
|
||||
const tool_resources = _tool_resources ?? {};
|
||||
|
||||
// Track existing files in tool_resources to prevent duplicates within resources
|
||||
for (const [resourceType, resource] of Object.entries(tool_resources)) {
|
||||
if (resource?.files && Array.isArray(resource.files)) {
|
||||
for (const file of resource.files) {
|
||||
if (file?.file_id) {
|
||||
processedResourceFiles.add(`${resourceType}:${file.file_id}`);
|
||||
// Files from non-OCR resources should not be added to attachments from _attachments
|
||||
if (resourceType !== EToolResources.ocr) {
|
||||
attachmentFileIds.add(file.file_id);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
const isOCREnabled = (req.app.locals?.[EModelEndpoint.agents]?.capabilities ?? []).includes(
|
||||
AgentCapabilities.ocr,
|
||||
);
|
||||
|
||||
if (tool_resources[EToolResources.ocr]?.file_ids && isOCREnabled) {
|
||||
const context = await getFiles(
|
||||
{
|
||||
file_id: { $in: tool_resources.ocr.file_ids },
|
||||
},
|
||||
{},
|
||||
{},
|
||||
);
|
||||
|
||||
for (const file of context) {
|
||||
if (!file?.file_id) {
|
||||
continue;
|
||||
}
|
||||
|
||||
// Clear from attachmentFileIds if it was pre-added
|
||||
attachmentFileIds.delete(file.file_id);
|
||||
|
||||
// Add to attachments
|
||||
attachments.push(file);
|
||||
attachmentFileIds.add(file.file_id);
|
||||
|
||||
// Categorize for tool resources
|
||||
categorizeFileForToolResources({
|
||||
file,
|
||||
tool_resources,
|
||||
requestFileSet,
|
||||
processedResourceFiles,
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
if (!_attachments) {
|
||||
return { attachments: attachments.length > 0 ? attachments : undefined, tool_resources };
|
||||
}
|
||||
|
||||
const files = await _attachments;
|
||||
|
||||
for (const file of files) {
|
||||
if (!file) {
|
||||
continue;
|
||||
}
|
||||
|
||||
categorizeFileForToolResources({
|
||||
file,
|
||||
tool_resources,
|
||||
requestFileSet,
|
||||
processedResourceFiles,
|
||||
});
|
||||
|
||||
if (file.file_id && attachmentFileIds.has(file.file_id)) {
|
||||
continue;
|
||||
}
|
||||
|
||||
attachments.push(file);
|
||||
if (file.file_id) {
|
||||
attachmentFileIds.add(file.file_id);
|
||||
}
|
||||
}
|
||||
|
||||
return { attachments: attachments.length > 0 ? attachments : [], tool_resources };
|
||||
} catch (error) {
|
||||
logger.error('Error priming resources', error);
|
||||
|
||||
// Safely try to get attachments without rethrowing
|
||||
let safeAttachments: Array<TFile | undefined> = [];
|
||||
if (_attachments) {
|
||||
try {
|
||||
const attachmentFiles = await _attachments;
|
||||
safeAttachments = (attachmentFiles?.filter((file) => !!file) ?? []) as Array<TFile>;
|
||||
} catch (attachmentError) {
|
||||
// If attachments promise is also rejected, just use empty array
|
||||
logger.error('Error resolving attachments in catch block', attachmentError);
|
||||
safeAttachments = [];
|
||||
}
|
||||
}
|
||||
|
||||
return {
|
||||
attachments: safeAttachments,
|
||||
tool_resources: _tool_resources,
|
||||
};
|
||||
}
|
||||
};
|
||||
96
packages/api/src/agents/run.ts
Normal file
96
packages/api/src/agents/run.ts
Normal file
|
|
@ -0,0 +1,96 @@
|
|||
import { Run, Providers } from '@librechat/agents';
|
||||
import { providerEndpointMap, KnownEndpoints } from 'librechat-data-provider';
|
||||
import type {
|
||||
StandardGraphConfig,
|
||||
EventHandler,
|
||||
GenericTool,
|
||||
GraphEvents,
|
||||
IState,
|
||||
} from '@librechat/agents';
|
||||
import type { Agent } from 'librechat-data-provider';
|
||||
import type * as t from '~/types';
|
||||
|
||||
const customProviders = new Set([
|
||||
Providers.XAI,
|
||||
Providers.OLLAMA,
|
||||
Providers.DEEPSEEK,
|
||||
Providers.OPENROUTER,
|
||||
]);
|
||||
|
||||
/**
|
||||
* Creates a new Run instance with custom handlers and configuration.
|
||||
*
|
||||
* @param options - The options for creating the Run instance.
|
||||
* @param options.agent - The agent for this run.
|
||||
* @param options.signal - The signal for this run.
|
||||
* @param options.req - The server request.
|
||||
* @param options.runId - Optional run ID; otherwise, a new run ID will be generated.
|
||||
* @param options.customHandlers - Custom event handlers.
|
||||
* @param options.streaming - Whether to use streaming.
|
||||
* @param options.streamUsage - Whether to stream usage information.
|
||||
* @returns {Promise<Run<IState>>} A promise that resolves to a new Run instance.
|
||||
*/
|
||||
export async function createRun({
|
||||
runId,
|
||||
agent,
|
||||
signal,
|
||||
customHandlers,
|
||||
streaming = true,
|
||||
streamUsage = true,
|
||||
}: {
|
||||
agent: Omit<Agent, 'tools'> & { tools?: GenericTool[] };
|
||||
signal: AbortSignal;
|
||||
runId?: string;
|
||||
streaming?: boolean;
|
||||
streamUsage?: boolean;
|
||||
customHandlers?: Record<GraphEvents, EventHandler>;
|
||||
}): Promise<Run<IState>> {
|
||||
const provider =
|
||||
providerEndpointMap[agent.provider as keyof typeof providerEndpointMap] ?? agent.provider;
|
||||
const llmConfig: t.RunLLMConfig = Object.assign(
|
||||
{
|
||||
provider,
|
||||
streaming,
|
||||
streamUsage,
|
||||
},
|
||||
agent.model_parameters,
|
||||
);
|
||||
|
||||
/** Resolves issues with new OpenAI usage field */
|
||||
if (
|
||||
customProviders.has(agent.provider) ||
|
||||
(agent.provider === Providers.OPENAI && agent.endpoint !== agent.provider)
|
||||
) {
|
||||
llmConfig.streamUsage = false;
|
||||
llmConfig.usage = true;
|
||||
}
|
||||
|
||||
let reasoningKey: 'reasoning_content' | 'reasoning' | undefined;
|
||||
if (
|
||||
llmConfig.configuration?.baseURL?.includes(KnownEndpoints.openrouter) ||
|
||||
(agent.endpoint && agent.endpoint.toLowerCase().includes(KnownEndpoints.openrouter))
|
||||
) {
|
||||
reasoningKey = 'reasoning';
|
||||
}
|
||||
|
||||
const graphConfig: StandardGraphConfig = {
|
||||
signal,
|
||||
llmConfig,
|
||||
reasoningKey,
|
||||
tools: agent.tools,
|
||||
instructions: agent.instructions,
|
||||
additional_instructions: agent.additional_instructions,
|
||||
// toolEnd: agent.end_after_tools,
|
||||
};
|
||||
|
||||
// TEMPORARY FOR TESTING
|
||||
if (agent.provider === Providers.ANTHROPIC || agent.provider === Providers.BEDROCK) {
|
||||
graphConfig.streamBuffer = 2000;
|
||||
}
|
||||
|
||||
return Run.create({
|
||||
runId,
|
||||
graphConfig,
|
||||
customHandlers,
|
||||
});
|
||||
}
|
||||
129
packages/api/src/crypto/encryption.ts
Normal file
129
packages/api/src/crypto/encryption.ts
Normal file
|
|
@ -0,0 +1,129 @@
|
|||
import 'dotenv/config';
|
||||
import crypto from 'node:crypto';
|
||||
const { webcrypto } = crypto;
|
||||
|
||||
// Use hex decoding for both key and IV for legacy methods.
|
||||
const key = Buffer.from(process.env.CREDS_KEY ?? '', 'hex');
|
||||
const iv = Buffer.from(process.env.CREDS_IV ?? '', 'hex');
|
||||
const algorithm = 'AES-CBC';
|
||||
|
||||
// --- Legacy v1/v2 Setup: AES-CBC with fixed key and IV ---
|
||||
|
||||
export async function encrypt(value: string) {
|
||||
const cryptoKey = await webcrypto.subtle.importKey('raw', key, { name: algorithm }, false, [
|
||||
'encrypt',
|
||||
]);
|
||||
const encoder = new TextEncoder();
|
||||
const data = encoder.encode(value);
|
||||
const encryptedBuffer = await webcrypto.subtle.encrypt(
|
||||
{ name: algorithm, iv: iv },
|
||||
cryptoKey,
|
||||
data,
|
||||
);
|
||||
return Buffer.from(encryptedBuffer).toString('hex');
|
||||
}
|
||||
|
||||
export async function decrypt(encryptedValue: string) {
|
||||
const cryptoKey = await webcrypto.subtle.importKey('raw', key, { name: algorithm }, false, [
|
||||
'decrypt',
|
||||
]);
|
||||
const encryptedBuffer = Buffer.from(encryptedValue, 'hex');
|
||||
const decryptedBuffer = await webcrypto.subtle.decrypt(
|
||||
{ name: algorithm, iv: iv },
|
||||
cryptoKey,
|
||||
encryptedBuffer,
|
||||
);
|
||||
const decoder = new TextDecoder();
|
||||
return decoder.decode(decryptedBuffer);
|
||||
}
|
||||
|
||||
// --- v2: AES-CBC with a random IV per encryption ---
|
||||
|
||||
export async function encryptV2(value: string) {
|
||||
const gen_iv = webcrypto.getRandomValues(new Uint8Array(16));
|
||||
const cryptoKey = await webcrypto.subtle.importKey('raw', key, { name: algorithm }, false, [
|
||||
'encrypt',
|
||||
]);
|
||||
const encoder = new TextEncoder();
|
||||
const data = encoder.encode(value);
|
||||
const encryptedBuffer = await webcrypto.subtle.encrypt(
|
||||
{ name: algorithm, iv: gen_iv },
|
||||
cryptoKey,
|
||||
data,
|
||||
);
|
||||
return Buffer.from(gen_iv).toString('hex') + ':' + Buffer.from(encryptedBuffer).toString('hex');
|
||||
}
|
||||
|
||||
export async function decryptV2(encryptedValue: string) {
|
||||
const parts = encryptedValue.split(':');
|
||||
if (parts.length === 1) {
|
||||
return parts[0];
|
||||
}
|
||||
const gen_iv = Buffer.from(parts.shift() ?? '', 'hex');
|
||||
const encrypted = parts.join(':');
|
||||
const cryptoKey = await webcrypto.subtle.importKey('raw', key, { name: algorithm }, false, [
|
||||
'decrypt',
|
||||
]);
|
||||
const encryptedBuffer = Buffer.from(encrypted, 'hex');
|
||||
const decryptedBuffer = await webcrypto.subtle.decrypt(
|
||||
{ name: algorithm, iv: gen_iv },
|
||||
cryptoKey,
|
||||
encryptedBuffer,
|
||||
);
|
||||
const decoder = new TextDecoder();
|
||||
return decoder.decode(decryptedBuffer);
|
||||
}
|
||||
|
||||
// --- v3: AES-256-CTR using Node's crypto functions ---
|
||||
const algorithm_v3 = 'aes-256-ctr';
|
||||
|
||||
/**
|
||||
* Encrypts a value using AES-256-CTR.
|
||||
* Note: AES-256 requires a 32-byte key. Ensure that process.env.CREDS_KEY is a 64-character hex string.
|
||||
*
|
||||
* @param value - The plaintext to encrypt.
|
||||
* @returns The encrypted string with a "v3:" prefix.
|
||||
*/
|
||||
export function encryptV3(value: string) {
|
||||
if (key.length !== 32) {
|
||||
throw new Error(`Invalid key length: expected 32 bytes, got ${key.length} bytes`);
|
||||
}
|
||||
const iv_v3 = crypto.randomBytes(16);
|
||||
const cipher = crypto.createCipheriv(algorithm_v3, key, iv_v3);
|
||||
const encrypted = Buffer.concat([cipher.update(value, 'utf8'), cipher.final()]);
|
||||
return `v3:${iv_v3.toString('hex')}:${encrypted.toString('hex')}`;
|
||||
}
|
||||
|
||||
export function decryptV3(encryptedValue: string) {
|
||||
const parts = encryptedValue.split(':');
|
||||
if (parts[0] !== 'v3') {
|
||||
throw new Error('Not a v3 encrypted value');
|
||||
}
|
||||
const iv_v3 = Buffer.from(parts[1], 'hex');
|
||||
const encryptedText = Buffer.from(parts.slice(2).join(':'), 'hex');
|
||||
const decipher = crypto.createDecipheriv(algorithm_v3, key, iv_v3);
|
||||
const decrypted = Buffer.concat([decipher.update(encryptedText), decipher.final()]);
|
||||
return decrypted.toString('utf8');
|
||||
}
|
||||
|
||||
export async function getRandomValues(length: number) {
|
||||
if (!Number.isInteger(length) || length <= 0) {
|
||||
throw new Error('Length must be a positive integer');
|
||||
}
|
||||
const randomValues = new Uint8Array(length);
|
||||
webcrypto.getRandomValues(randomValues);
|
||||
return Buffer.from(randomValues).toString('hex');
|
||||
}
|
||||
|
||||
/**
|
||||
* Computes SHA-256 hash for the given input.
|
||||
* @param input - The input to hash.
|
||||
* @returns The SHA-256 hash of the input.
|
||||
*/
|
||||
export async function hashBackupCode(input: string) {
|
||||
const encoder = new TextEncoder();
|
||||
const data = encoder.encode(input);
|
||||
const hashBuffer = await webcrypto.subtle.digest('SHA-256', data);
|
||||
const hashArray = Array.from(new Uint8Array(hashBuffer));
|
||||
return hashArray.map((b) => b.toString(16).padStart(2, '0')).join('');
|
||||
}
|
||||
1
packages/api/src/crypto/index.ts
Normal file
1
packages/api/src/crypto/index.ts
Normal file
|
|
@ -0,0 +1 @@
|
|||
export * from './encryption';
|
||||
1
packages/api/src/endpoints/index.ts
Normal file
1
packages/api/src/endpoints/index.ts
Normal file
|
|
@ -0,0 +1 @@
|
|||
export * from './openai';
|
||||
2
packages/api/src/endpoints/openai/index.ts
Normal file
2
packages/api/src/endpoints/openai/index.ts
Normal file
|
|
@ -0,0 +1,2 @@
|
|||
export * from './llm';
|
||||
export * from './initialize';
|
||||
176
packages/api/src/endpoints/openai/initialize.ts
Normal file
176
packages/api/src/endpoints/openai/initialize.ts
Normal file
|
|
@ -0,0 +1,176 @@
|
|||
import {
|
||||
ErrorTypes,
|
||||
EModelEndpoint,
|
||||
resolveHeaders,
|
||||
mapModelToAzureConfig,
|
||||
} from 'librechat-data-provider';
|
||||
import type {
|
||||
LLMConfigOptions,
|
||||
UserKeyValues,
|
||||
InitializeOpenAIOptionsParams,
|
||||
OpenAIOptionsResult,
|
||||
} from '~/types';
|
||||
import { createHandleLLMNewToken } from '~/utils/generators';
|
||||
import { getAzureCredentials } from '~/utils/azure';
|
||||
import { isUserProvided } from '~/utils/common';
|
||||
import { getOpenAIConfig } from './llm';
|
||||
|
||||
/**
|
||||
* Initializes OpenAI options for agent usage. This function always returns configuration
|
||||
* options and never creates a client instance (equivalent to optionsOnly=true behavior).
|
||||
*
|
||||
* @param params - Configuration parameters
|
||||
* @returns Promise resolving to OpenAI configuration options
|
||||
* @throws Error if API key is missing or user key has expired
|
||||
*/
|
||||
export const initializeOpenAI = async ({
|
||||
req,
|
||||
overrideModel,
|
||||
endpointOption,
|
||||
overrideEndpoint,
|
||||
getUserKeyValues,
|
||||
checkUserKeyExpiry,
|
||||
}: InitializeOpenAIOptionsParams): Promise<OpenAIOptionsResult> => {
|
||||
const { PROXY, OPENAI_API_KEY, AZURE_API_KEY, OPENAI_REVERSE_PROXY, AZURE_OPENAI_BASEURL } =
|
||||
process.env;
|
||||
|
||||
const { key: expiresAt } = req.body;
|
||||
const modelName = overrideModel ?? req.body.model;
|
||||
const endpoint = overrideEndpoint ?? req.body.endpoint;
|
||||
|
||||
if (!endpoint) {
|
||||
throw new Error('Endpoint is required');
|
||||
}
|
||||
|
||||
const credentials = {
|
||||
[EModelEndpoint.openAI]: OPENAI_API_KEY,
|
||||
[EModelEndpoint.azureOpenAI]: AZURE_API_KEY,
|
||||
};
|
||||
|
||||
const baseURLOptions = {
|
||||
[EModelEndpoint.openAI]: OPENAI_REVERSE_PROXY,
|
||||
[EModelEndpoint.azureOpenAI]: AZURE_OPENAI_BASEURL,
|
||||
};
|
||||
|
||||
const userProvidesKey = isUserProvided(credentials[endpoint as keyof typeof credentials]);
|
||||
const userProvidesURL = isUserProvided(baseURLOptions[endpoint as keyof typeof baseURLOptions]);
|
||||
|
||||
let userValues: UserKeyValues | null = null;
|
||||
if (expiresAt && (userProvidesKey || userProvidesURL)) {
|
||||
checkUserKeyExpiry(expiresAt, endpoint);
|
||||
userValues = await getUserKeyValues({ userId: req.user.id, name: endpoint });
|
||||
}
|
||||
|
||||
let apiKey = userProvidesKey
|
||||
? userValues?.apiKey
|
||||
: credentials[endpoint as keyof typeof credentials];
|
||||
const baseURL = userProvidesURL
|
||||
? userValues?.baseURL
|
||||
: baseURLOptions[endpoint as keyof typeof baseURLOptions];
|
||||
|
||||
const clientOptions: LLMConfigOptions = {
|
||||
proxy: PROXY ?? undefined,
|
||||
reverseProxyUrl: baseURL || undefined,
|
||||
streaming: true,
|
||||
};
|
||||
|
||||
const isAzureOpenAI = endpoint === EModelEndpoint.azureOpenAI;
|
||||
const azureConfig = isAzureOpenAI && req.app.locals[EModelEndpoint.azureOpenAI];
|
||||
|
||||
if (isAzureOpenAI && azureConfig) {
|
||||
const { modelGroupMap, groupMap } = azureConfig;
|
||||
const {
|
||||
azureOptions,
|
||||
baseURL: configBaseURL,
|
||||
headers = {},
|
||||
serverless,
|
||||
} = mapModelToAzureConfig({
|
||||
modelName: modelName || '',
|
||||
modelGroupMap,
|
||||
groupMap,
|
||||
});
|
||||
|
||||
clientOptions.reverseProxyUrl = configBaseURL ?? clientOptions.reverseProxyUrl;
|
||||
clientOptions.headers = resolveHeaders({ ...headers, ...(clientOptions.headers ?? {}) });
|
||||
|
||||
const groupName = modelGroupMap[modelName || '']?.group;
|
||||
if (groupName && groupMap[groupName]) {
|
||||
clientOptions.addParams = groupMap[groupName]?.addParams;
|
||||
clientOptions.dropParams = groupMap[groupName]?.dropParams;
|
||||
}
|
||||
|
||||
apiKey = azureOptions.azureOpenAIApiKey;
|
||||
clientOptions.azure = !serverless ? azureOptions : undefined;
|
||||
|
||||
if (serverless === true) {
|
||||
clientOptions.defaultQuery = azureOptions.azureOpenAIApiVersion
|
||||
? { 'api-version': azureOptions.azureOpenAIApiVersion }
|
||||
: undefined;
|
||||
|
||||
if (!clientOptions.headers) {
|
||||
clientOptions.headers = {};
|
||||
}
|
||||
clientOptions.headers['api-key'] = apiKey;
|
||||
}
|
||||
} else if (isAzureOpenAI) {
|
||||
clientOptions.azure =
|
||||
userProvidesKey && userValues?.apiKey ? JSON.parse(userValues.apiKey) : getAzureCredentials();
|
||||
apiKey = clientOptions.azure?.azureOpenAIApiKey;
|
||||
}
|
||||
|
||||
if (userProvidesKey && !apiKey) {
|
||||
throw new Error(
|
||||
JSON.stringify({
|
||||
type: ErrorTypes.NO_USER_KEY,
|
||||
}),
|
||||
);
|
||||
}
|
||||
|
||||
if (!apiKey) {
|
||||
throw new Error(`${endpoint} API Key not provided.`);
|
||||
}
|
||||
|
||||
const modelOptions = {
|
||||
...endpointOption.model_parameters,
|
||||
model: modelName,
|
||||
user: req.user.id,
|
||||
};
|
||||
|
||||
const finalClientOptions: LLMConfigOptions = {
|
||||
...clientOptions,
|
||||
modelOptions,
|
||||
};
|
||||
|
||||
const options = getOpenAIConfig(apiKey, finalClientOptions, endpoint);
|
||||
|
||||
const openAIConfig = req.app.locals[EModelEndpoint.openAI];
|
||||
const allConfig = req.app.locals.all;
|
||||
const azureRate = modelName?.includes('gpt-4') ? 30 : 17;
|
||||
|
||||
let streamRate: number | undefined;
|
||||
|
||||
if (isAzureOpenAI && azureConfig) {
|
||||
streamRate = azureConfig.streamRate ?? azureRate;
|
||||
} else if (!isAzureOpenAI && openAIConfig) {
|
||||
streamRate = openAIConfig.streamRate;
|
||||
}
|
||||
|
||||
if (allConfig?.streamRate) {
|
||||
streamRate = allConfig.streamRate;
|
||||
}
|
||||
|
||||
if (streamRate) {
|
||||
options.llmConfig.callbacks = [
|
||||
{
|
||||
handleLLMNewToken: createHandleLLMNewToken(streamRate),
|
||||
},
|
||||
];
|
||||
}
|
||||
|
||||
const result: OpenAIOptionsResult = {
|
||||
...options,
|
||||
streamRate,
|
||||
};
|
||||
|
||||
return result;
|
||||
};
|
||||
158
packages/api/src/endpoints/openai/llm.ts
Normal file
158
packages/api/src/endpoints/openai/llm.ts
Normal file
|
|
@ -0,0 +1,158 @@
|
|||
import { ProxyAgent } from 'undici';
|
||||
import { KnownEndpoints } from 'librechat-data-provider';
|
||||
import type * as t from '~/types';
|
||||
import { sanitizeModelName, constructAzureURL } from '~/utils/azure';
|
||||
import { isEnabled } from '~/utils/common';
|
||||
|
||||
/**
|
||||
* Generates configuration options for creating a language model (LLM) instance.
|
||||
* @param apiKey - The API key for authentication.
|
||||
* @param options - Additional options for configuring the LLM.
|
||||
* @param endpoint - The endpoint name
|
||||
* @returns Configuration options for creating an LLM instance.
|
||||
*/
|
||||
export function getOpenAIConfig(
|
||||
apiKey: string,
|
||||
options: t.LLMConfigOptions = {},
|
||||
endpoint?: string | null,
|
||||
): t.LLMConfigResult {
|
||||
const {
|
||||
modelOptions = {},
|
||||
reverseProxyUrl,
|
||||
defaultQuery,
|
||||
headers,
|
||||
proxy,
|
||||
azure,
|
||||
streaming = true,
|
||||
addParams,
|
||||
dropParams,
|
||||
} = options;
|
||||
|
||||
const llmConfig: Partial<t.ClientOptions> & Partial<t.OpenAIParameters> = Object.assign(
|
||||
{
|
||||
streaming,
|
||||
model: modelOptions.model ?? '',
|
||||
},
|
||||
modelOptions,
|
||||
);
|
||||
|
||||
if (addParams && typeof addParams === 'object') {
|
||||
Object.assign(llmConfig, addParams);
|
||||
}
|
||||
|
||||
// Note: OpenAI Web Search models do not support any known parameters besides `max_tokens`
|
||||
if (modelOptions.model && /gpt-4o.*search/.test(modelOptions.model)) {
|
||||
const searchExcludeParams = [
|
||||
'frequency_penalty',
|
||||
'presence_penalty',
|
||||
'temperature',
|
||||
'top_p',
|
||||
'top_k',
|
||||
'stop',
|
||||
'logit_bias',
|
||||
'seed',
|
||||
'response_format',
|
||||
'n',
|
||||
'logprobs',
|
||||
'user',
|
||||
];
|
||||
|
||||
const updatedDropParams = dropParams || [];
|
||||
const combinedDropParams = [...new Set([...updatedDropParams, ...searchExcludeParams])];
|
||||
|
||||
combinedDropParams.forEach((param) => {
|
||||
if (param in llmConfig) {
|
||||
delete llmConfig[param as keyof t.ClientOptions];
|
||||
}
|
||||
});
|
||||
} else if (dropParams && Array.isArray(dropParams)) {
|
||||
dropParams.forEach((param) => {
|
||||
if (param in llmConfig) {
|
||||
delete llmConfig[param as keyof t.ClientOptions];
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
let useOpenRouter = false;
|
||||
const configOptions: t.OpenAIConfiguration = {};
|
||||
|
||||
if (
|
||||
(reverseProxyUrl && reverseProxyUrl.includes(KnownEndpoints.openrouter)) ||
|
||||
(endpoint && endpoint.toLowerCase().includes(KnownEndpoints.openrouter))
|
||||
) {
|
||||
useOpenRouter = true;
|
||||
llmConfig.include_reasoning = true;
|
||||
configOptions.baseURL = reverseProxyUrl;
|
||||
configOptions.defaultHeaders = Object.assign(
|
||||
{
|
||||
'HTTP-Referer': 'https://librechat.ai',
|
||||
'X-Title': 'LibreChat',
|
||||
},
|
||||
headers,
|
||||
);
|
||||
} else if (reverseProxyUrl) {
|
||||
configOptions.baseURL = reverseProxyUrl;
|
||||
if (headers) {
|
||||
configOptions.defaultHeaders = headers;
|
||||
}
|
||||
}
|
||||
|
||||
if (defaultQuery) {
|
||||
configOptions.defaultQuery = defaultQuery;
|
||||
}
|
||||
|
||||
if (proxy) {
|
||||
const proxyAgent = new ProxyAgent(proxy);
|
||||
configOptions.fetchOptions = {
|
||||
dispatcher: proxyAgent,
|
||||
};
|
||||
}
|
||||
|
||||
if (azure) {
|
||||
const useModelName = isEnabled(process.env.AZURE_USE_MODEL_AS_DEPLOYMENT_NAME);
|
||||
const updatedAzure = { ...azure };
|
||||
updatedAzure.azureOpenAIApiDeploymentName = useModelName
|
||||
? sanitizeModelName(llmConfig.model || '')
|
||||
: azure.azureOpenAIApiDeploymentName;
|
||||
|
||||
if (process.env.AZURE_OPENAI_DEFAULT_MODEL) {
|
||||
llmConfig.model = process.env.AZURE_OPENAI_DEFAULT_MODEL;
|
||||
}
|
||||
|
||||
if (configOptions.baseURL) {
|
||||
const azureURL = constructAzureURL({
|
||||
baseURL: configOptions.baseURL,
|
||||
azureOptions: updatedAzure,
|
||||
});
|
||||
updatedAzure.azureOpenAIBasePath = azureURL.split(
|
||||
`/${updatedAzure.azureOpenAIApiDeploymentName}`,
|
||||
)[0];
|
||||
}
|
||||
|
||||
Object.assign(llmConfig, updatedAzure);
|
||||
llmConfig.model = updatedAzure.azureOpenAIApiDeploymentName;
|
||||
} else {
|
||||
llmConfig.apiKey = apiKey;
|
||||
}
|
||||
|
||||
if (process.env.OPENAI_ORGANIZATION && azure) {
|
||||
configOptions.organization = process.env.OPENAI_ORGANIZATION;
|
||||
}
|
||||
|
||||
if (useOpenRouter && llmConfig.reasoning_effort != null) {
|
||||
llmConfig.reasoning = {
|
||||
effort: llmConfig.reasoning_effort,
|
||||
};
|
||||
delete llmConfig.reasoning_effort;
|
||||
}
|
||||
|
||||
if (llmConfig.max_tokens != null) {
|
||||
llmConfig.maxTokens = llmConfig.max_tokens;
|
||||
delete llmConfig.max_tokens;
|
||||
}
|
||||
|
||||
return {
|
||||
llmConfig,
|
||||
configOptions,
|
||||
};
|
||||
}
|
||||
1
packages/api/src/files/index.ts
Normal file
1
packages/api/src/files/index.ts
Normal file
|
|
@ -0,0 +1 @@
|
|||
export * from './mistral/crud';
|
||||
1570
packages/api/src/files/mistral/crud.spec.ts
Normal file
1570
packages/api/src/files/mistral/crud.spec.ts
Normal file
File diff suppressed because it is too large
Load diff
416
packages/api/src/files/mistral/crud.ts
Normal file
416
packages/api/src/files/mistral/crud.ts
Normal file
|
|
@ -0,0 +1,416 @@
|
|||
import * as fs from 'fs';
|
||||
import * as path from 'path';
|
||||
import FormData from 'form-data';
|
||||
import { logger } from '@librechat/data-schemas';
|
||||
import {
|
||||
FileSources,
|
||||
envVarRegex,
|
||||
extractEnvVariable,
|
||||
extractVariableName,
|
||||
} from 'librechat-data-provider';
|
||||
import type { TCustomConfig } from 'librechat-data-provider';
|
||||
import type { Request as ServerRequest } from 'express';
|
||||
import type { AxiosError } from 'axios';
|
||||
import type {
|
||||
MistralFileUploadResponse,
|
||||
MistralSignedUrlResponse,
|
||||
MistralOCRUploadResult,
|
||||
MistralOCRError,
|
||||
OCRResultPage,
|
||||
OCRResult,
|
||||
OCRImage,
|
||||
} from '~/types';
|
||||
import { logAxiosError, createAxiosInstance } from '~/utils/axios';
|
||||
|
||||
const axios = createAxiosInstance();
|
||||
const DEFAULT_MISTRAL_BASE_URL = 'https://api.mistral.ai/v1';
|
||||
const DEFAULT_MISTRAL_MODEL = 'mistral-ocr-latest';
|
||||
|
||||
/** Helper type for auth configuration */
|
||||
interface AuthConfig {
|
||||
apiKey: string;
|
||||
baseURL: string;
|
||||
}
|
||||
|
||||
/** Helper type for OCR request context */
|
||||
interface OCRContext {
|
||||
req: Pick<ServerRequest, 'user' | 'app'> & {
|
||||
user?: { id: string };
|
||||
app: {
|
||||
locals?: {
|
||||
ocr?: TCustomConfig['ocr'];
|
||||
};
|
||||
};
|
||||
};
|
||||
file: Express.Multer.File;
|
||||
loadAuthValues: (params: {
|
||||
userId: string;
|
||||
authFields: string[];
|
||||
optional?: Set<string>;
|
||||
}) => Promise<Record<string, string | undefined>>;
|
||||
}
|
||||
|
||||
/**
|
||||
* Uploads a document to Mistral API using file streaming to avoid loading the entire file into memory
|
||||
* @param params Upload parameters
|
||||
* @param params.filePath The path to the file on disk
|
||||
* @param params.fileName Optional filename to use (defaults to the name from filePath)
|
||||
* @param params.apiKey Mistral API key
|
||||
* @param params.baseURL Mistral API base URL
|
||||
* @returns The response from Mistral API
|
||||
*/
|
||||
export async function uploadDocumentToMistral({
|
||||
apiKey,
|
||||
filePath,
|
||||
baseURL = DEFAULT_MISTRAL_BASE_URL,
|
||||
fileName = '',
|
||||
}: {
|
||||
apiKey: string;
|
||||
filePath: string;
|
||||
baseURL?: string;
|
||||
fileName?: string;
|
||||
}): Promise<MistralFileUploadResponse> {
|
||||
const form = new FormData();
|
||||
form.append('purpose', 'ocr');
|
||||
const actualFileName = fileName || path.basename(filePath);
|
||||
const fileStream = fs.createReadStream(filePath);
|
||||
form.append('file', fileStream, { filename: actualFileName });
|
||||
|
||||
return axios
|
||||
.post(`${baseURL}/files`, form, {
|
||||
headers: {
|
||||
Authorization: `Bearer ${apiKey}`,
|
||||
...form.getHeaders(),
|
||||
},
|
||||
maxBodyLength: Infinity,
|
||||
maxContentLength: Infinity,
|
||||
})
|
||||
.then((res) => res.data)
|
||||
.catch((error) => {
|
||||
throw error;
|
||||
});
|
||||
}
|
||||
|
||||
export async function getSignedUrl({
|
||||
apiKey,
|
||||
fileId,
|
||||
expiry = 24,
|
||||
baseURL = DEFAULT_MISTRAL_BASE_URL,
|
||||
}: {
|
||||
apiKey: string;
|
||||
fileId: string;
|
||||
expiry?: number;
|
||||
baseURL?: string;
|
||||
}): Promise<MistralSignedUrlResponse> {
|
||||
return axios
|
||||
.get(`${baseURL}/files/${fileId}/url?expiry=${expiry}`, {
|
||||
headers: {
|
||||
Authorization: `Bearer ${apiKey}`,
|
||||
},
|
||||
})
|
||||
.then((res) => res.data)
|
||||
.catch((error) => {
|
||||
logger.error('Error fetching signed URL:', error.message);
|
||||
throw error;
|
||||
});
|
||||
}
|
||||
|
||||
/**
|
||||
* @param {Object} params
|
||||
* @param {string} params.apiKey
|
||||
* @param {string} params.url - The document or image URL
|
||||
* @param {string} [params.documentType='document_url'] - 'document_url' or 'image_url'
|
||||
* @param {string} [params.model]
|
||||
* @param {string} [params.baseURL]
|
||||
* @returns {Promise<OCRResult>}
|
||||
*/
|
||||
export async function performOCR({
|
||||
url,
|
||||
apiKey,
|
||||
model = DEFAULT_MISTRAL_MODEL,
|
||||
baseURL = DEFAULT_MISTRAL_BASE_URL,
|
||||
documentType = 'document_url',
|
||||
}: {
|
||||
url: string;
|
||||
apiKey: string;
|
||||
model?: string;
|
||||
baseURL?: string;
|
||||
documentType?: 'document_url' | 'image_url';
|
||||
}): Promise<OCRResult> {
|
||||
const documentKey = documentType === 'image_url' ? 'image_url' : 'document_url';
|
||||
return axios
|
||||
.post(
|
||||
`${baseURL}/ocr`,
|
||||
{
|
||||
model,
|
||||
image_limit: 0,
|
||||
include_image_base64: false,
|
||||
document: {
|
||||
type: documentType,
|
||||
[documentKey]: url,
|
||||
},
|
||||
},
|
||||
{
|
||||
headers: {
|
||||
'Content-Type': 'application/json',
|
||||
Authorization: `Bearer ${apiKey}`,
|
||||
},
|
||||
},
|
||||
)
|
||||
.then((res) => res.data)
|
||||
.catch((error) => {
|
||||
logger.error('Error performing OCR:', error.message);
|
||||
throw error;
|
||||
});
|
||||
}
|
||||
|
||||
/**
|
||||
* Determines if a value needs to be loaded from environment
|
||||
*/
|
||||
function needsEnvLoad(value: string): boolean {
|
||||
return envVarRegex.test(value) || !value.trim();
|
||||
}
|
||||
|
||||
/**
|
||||
* Gets the environment variable name for a config value
|
||||
*/
|
||||
function getEnvVarName(configValue: string, defaultName: string): string {
|
||||
if (!envVarRegex.test(configValue)) {
|
||||
return defaultName;
|
||||
}
|
||||
return extractVariableName(configValue) || defaultName;
|
||||
}
|
||||
|
||||
/**
|
||||
* Resolves a configuration value from either hardcoded or environment
|
||||
*/
|
||||
async function resolveConfigValue(
|
||||
configValue: string,
|
||||
defaultEnvName: string,
|
||||
authValues: Record<string, string | undefined>,
|
||||
defaultValue?: string,
|
||||
): Promise<string> {
|
||||
// If it's a hardcoded value (not env var and not empty), use it directly
|
||||
if (!needsEnvLoad(configValue)) {
|
||||
return configValue;
|
||||
}
|
||||
|
||||
// Otherwise, get from auth values
|
||||
const envVarName = getEnvVarName(configValue, defaultEnvName);
|
||||
return authValues[envVarName] || defaultValue || '';
|
||||
}
|
||||
|
||||
/**
|
||||
* Loads authentication configuration from OCR config
|
||||
*/
|
||||
async function loadAuthConfig(context: OCRContext): Promise<AuthConfig> {
|
||||
const ocrConfig = context.req.app.locals?.ocr;
|
||||
const apiKeyConfig = ocrConfig?.apiKey || '';
|
||||
const baseURLConfig = ocrConfig?.baseURL || '';
|
||||
|
||||
if (!needsEnvLoad(apiKeyConfig) && !needsEnvLoad(baseURLConfig)) {
|
||||
return {
|
||||
apiKey: apiKeyConfig,
|
||||
baseURL: baseURLConfig,
|
||||
};
|
||||
}
|
||||
|
||||
const authFields: string[] = [];
|
||||
|
||||
if (needsEnvLoad(baseURLConfig)) {
|
||||
authFields.push(getEnvVarName(baseURLConfig, 'OCR_BASEURL'));
|
||||
}
|
||||
|
||||
if (needsEnvLoad(apiKeyConfig)) {
|
||||
authFields.push(getEnvVarName(apiKeyConfig, 'OCR_API_KEY'));
|
||||
}
|
||||
|
||||
const authValues = await context.loadAuthValues({
|
||||
userId: context.req.user?.id || '',
|
||||
authFields,
|
||||
optional: new Set(['OCR_BASEURL']),
|
||||
});
|
||||
|
||||
const apiKey = await resolveConfigValue(apiKeyConfig, 'OCR_API_KEY', authValues);
|
||||
const baseURL = await resolveConfigValue(
|
||||
baseURLConfig,
|
||||
'OCR_BASEURL',
|
||||
authValues,
|
||||
DEFAULT_MISTRAL_BASE_URL,
|
||||
);
|
||||
|
||||
return { apiKey, baseURL };
|
||||
}
|
||||
|
||||
/**
|
||||
* Gets the model configuration
|
||||
*/
|
||||
function getModelConfig(ocrConfig: TCustomConfig['ocr']): string {
|
||||
const modelConfig = ocrConfig?.mistralModel || '';
|
||||
|
||||
if (!modelConfig.trim()) {
|
||||
return DEFAULT_MISTRAL_MODEL;
|
||||
}
|
||||
|
||||
if (envVarRegex.test(modelConfig)) {
|
||||
return extractEnvVariable(modelConfig) || DEFAULT_MISTRAL_MODEL;
|
||||
}
|
||||
|
||||
return modelConfig.trim();
|
||||
}
|
||||
|
||||
/**
|
||||
* Determines document type based on file
|
||||
*/
|
||||
function getDocumentType(file: Express.Multer.File): 'image_url' | 'document_url' {
|
||||
const mimetype = (file.mimetype || '').toLowerCase();
|
||||
const originalname = file.originalname || '';
|
||||
const isImage =
|
||||
mimetype.startsWith('image') || /\.(png|jpe?g|gif|bmp|webp|tiff?)$/i.test(originalname);
|
||||
|
||||
return isImage ? 'image_url' : 'document_url';
|
||||
}
|
||||
|
||||
/**
|
||||
* Processes OCR result pages into aggregated text and images
|
||||
*/
|
||||
function processOCRResult(ocrResult: OCRResult): { text: string; images: string[] } {
|
||||
let aggregatedText = '';
|
||||
const images: string[] = [];
|
||||
|
||||
ocrResult.pages.forEach((page: OCRResultPage, index: number) => {
|
||||
if (ocrResult.pages.length > 1) {
|
||||
aggregatedText += `# PAGE ${index + 1}\n`;
|
||||
}
|
||||
|
||||
aggregatedText += page.markdown + '\n\n';
|
||||
|
||||
if (!page.images || page.images.length === 0) {
|
||||
return;
|
||||
}
|
||||
|
||||
page.images.forEach((image: OCRImage) => {
|
||||
if (image.image_base64) {
|
||||
images.push(image.image_base64);
|
||||
}
|
||||
});
|
||||
});
|
||||
|
||||
return { text: aggregatedText, images };
|
||||
}
|
||||
|
||||
/**
|
||||
* Creates an error message for OCR operations
|
||||
*/
|
||||
function createOCRError(error: unknown, baseMessage: string): Error {
|
||||
const axiosError = error as AxiosError<MistralOCRError>;
|
||||
const detail = axiosError?.response?.data?.detail;
|
||||
const message = detail || baseMessage;
|
||||
|
||||
const responseMessage = axiosError?.response?.data?.message;
|
||||
const errorLog = logAxiosError({ error: axiosError, message });
|
||||
const fullMessage = responseMessage ? `${errorLog} - ${responseMessage}` : errorLog;
|
||||
|
||||
return new Error(fullMessage);
|
||||
}
|
||||
|
||||
/**
|
||||
* Uploads a file to the Mistral OCR API and processes the OCR result.
|
||||
*
|
||||
* @param params - The params object.
|
||||
* @param params.req - The request object from Express. It should have a `user` property with an `id`
|
||||
* representing the user
|
||||
* @param params.file - The file object, which is part of the request. The file object should
|
||||
* have a `mimetype` property that tells us the file type
|
||||
* @param params.loadAuthValues - Function to load authentication values
|
||||
* @returns - The result object containing the processed `text` and `images` (not currently used),
|
||||
* along with the `filename` and `bytes` properties.
|
||||
*/
|
||||
export const uploadMistralOCR = async (context: OCRContext): Promise<MistralOCRUploadResult> => {
|
||||
try {
|
||||
const { apiKey, baseURL } = await loadAuthConfig(context);
|
||||
const model = getModelConfig(context.req.app.locals?.ocr);
|
||||
|
||||
const mistralFile = await uploadDocumentToMistral({
|
||||
filePath: context.file.path,
|
||||
fileName: context.file.originalname,
|
||||
apiKey,
|
||||
baseURL,
|
||||
});
|
||||
|
||||
const signedUrlResponse = await getSignedUrl({
|
||||
apiKey,
|
||||
baseURL,
|
||||
fileId: mistralFile.id,
|
||||
});
|
||||
|
||||
const documentType = getDocumentType(context.file);
|
||||
const ocrResult = await performOCR({
|
||||
apiKey,
|
||||
baseURL,
|
||||
model,
|
||||
url: signedUrlResponse.url,
|
||||
documentType,
|
||||
});
|
||||
|
||||
// Process result
|
||||
const { text, images } = processOCRResult(ocrResult);
|
||||
|
||||
return {
|
||||
filename: context.file.originalname,
|
||||
bytes: text.length * 4,
|
||||
filepath: FileSources.mistral_ocr,
|
||||
text,
|
||||
images,
|
||||
};
|
||||
} catch (error) {
|
||||
throw createOCRError(error, 'Error uploading document to Mistral OCR API');
|
||||
}
|
||||
};
|
||||
|
||||
/**
|
||||
* Use Azure Mistral OCR API to processe the OCR result.
|
||||
*
|
||||
* @param params - The params object.
|
||||
* @param params.req - The request object from Express. It should have a `user` property with an `id`
|
||||
* representing the user
|
||||
* @param params.file - The file object, which is part of the request. The file object should
|
||||
* have a `mimetype` property that tells us the file type
|
||||
* @param params.loadAuthValues - Function to load authentication values
|
||||
* @returns - The result object containing the processed `text` and `images` (not currently used),
|
||||
* along with the `filename` and `bytes` properties.
|
||||
*/
|
||||
export const uploadAzureMistralOCR = async (
|
||||
context: OCRContext,
|
||||
): Promise<MistralOCRUploadResult> => {
|
||||
try {
|
||||
const { apiKey, baseURL } = await loadAuthConfig(context);
|
||||
const model = getModelConfig(context.req.app.locals?.ocr);
|
||||
|
||||
const buffer = fs.readFileSync(context.file.path);
|
||||
const base64 = buffer.toString('base64');
|
||||
/** Uses actual mimetype of the file, 'image/jpeg' as fallback since it seems to be accepted regardless of mismatch */
|
||||
const base64Prefix = `data:${context.file.mimetype || 'image/jpeg'};base64,`;
|
||||
|
||||
const documentType = getDocumentType(context.file);
|
||||
const ocrResult = await performOCR({
|
||||
apiKey,
|
||||
baseURL,
|
||||
model,
|
||||
url: `${base64Prefix}${base64}`,
|
||||
documentType,
|
||||
});
|
||||
|
||||
const { text, images } = processOCRResult(ocrResult);
|
||||
|
||||
return {
|
||||
filename: context.file.originalname,
|
||||
bytes: text.length * 4,
|
||||
filepath: FileSources.azure_mistral_ocr,
|
||||
text,
|
||||
images,
|
||||
};
|
||||
} catch (error) {
|
||||
throw createOCRError(error, 'Error uploading document to Azure Mistral OCR API');
|
||||
}
|
||||
};
|
||||
|
|
@ -1,8 +1,8 @@
|
|||
import { FlowStateManager } from './manager';
|
||||
import { Keyv } from 'keyv';
|
||||
import { FlowStateManager } from './manager';
|
||||
import type { FlowState } from './types';
|
||||
|
||||
// Create a mock class without extending Keyv
|
||||
/** Mock class without extending Keyv */
|
||||
class MockKeyv {
|
||||
private store: Map<string, FlowState<string>>;
|
||||
|
||||
|
|
@ -1,28 +1,18 @@
|
|||
import { Keyv } from 'keyv';
|
||||
import { logger } from '@librechat/data-schemas';
|
||||
import type { StoredDataNoRaw } from 'keyv';
|
||||
import type { Logger } from 'winston';
|
||||
import type { FlowState, FlowMetadata, FlowManagerOptions } from './types';
|
||||
|
||||
export class FlowStateManager<T = unknown> {
|
||||
private keyv: Keyv;
|
||||
private ttl: number;
|
||||
private logger: Logger;
|
||||
private intervals: Set<NodeJS.Timeout>;
|
||||
|
||||
private static getDefaultLogger(): Logger {
|
||||
return {
|
||||
error: console.error,
|
||||
warn: console.warn,
|
||||
info: console.info,
|
||||
debug: console.debug,
|
||||
} as Logger;
|
||||
}
|
||||
|
||||
constructor(store: Keyv, options?: FlowManagerOptions) {
|
||||
if (!options) {
|
||||
options = { ttl: 60000 * 3 };
|
||||
}
|
||||
const { ci = false, ttl, logger } = options;
|
||||
const { ci = false, ttl } = options;
|
||||
|
||||
if (!ci && !(store instanceof Keyv)) {
|
||||
throw new Error('Invalid store provided to FlowStateManager');
|
||||
|
|
@ -30,14 +20,13 @@ export class FlowStateManager<T = unknown> {
|
|||
|
||||
this.ttl = ttl;
|
||||
this.keyv = store;
|
||||
this.logger = logger || FlowStateManager.getDefaultLogger();
|
||||
this.intervals = new Set();
|
||||
this.setupCleanupHandlers();
|
||||
}
|
||||
|
||||
private setupCleanupHandlers() {
|
||||
const cleanup = () => {
|
||||
this.logger.info('Cleaning up FlowStateManager intervals...');
|
||||
logger.info('Cleaning up FlowStateManager intervals...');
|
||||
this.intervals.forEach((interval) => clearInterval(interval));
|
||||
this.intervals.clear();
|
||||
process.exit(0);
|
||||
|
|
@ -66,7 +55,7 @@ export class FlowStateManager<T = unknown> {
|
|||
|
||||
let existingState = (await this.keyv.get(flowKey)) as FlowState<T> | undefined;
|
||||
if (existingState) {
|
||||
this.logger.debug(`[${flowKey}] Flow already exists`);
|
||||
logger.debug(`[${flowKey}] Flow already exists`);
|
||||
return this.monitorFlow(flowKey, type, signal);
|
||||
}
|
||||
|
||||
|
|
@ -74,7 +63,7 @@ export class FlowStateManager<T = unknown> {
|
|||
|
||||
existingState = (await this.keyv.get(flowKey)) as FlowState<T> | undefined;
|
||||
if (existingState) {
|
||||
this.logger.debug(`[${flowKey}] Flow exists on 2nd check`);
|
||||
logger.debug(`[${flowKey}] Flow exists on 2nd check`);
|
||||
return this.monitorFlow(flowKey, type, signal);
|
||||
}
|
||||
|
||||
|
|
@ -85,7 +74,7 @@ export class FlowStateManager<T = unknown> {
|
|||
createdAt: Date.now(),
|
||||
};
|
||||
|
||||
this.logger.debug('Creating initial flow state:', flowKey);
|
||||
logger.debug('Creating initial flow state:', flowKey);
|
||||
await this.keyv.set(flowKey, initialState, this.ttl);
|
||||
return this.monitorFlow(flowKey, type, signal);
|
||||
}
|
||||
|
|
@ -102,7 +91,7 @@ export class FlowStateManager<T = unknown> {
|
|||
if (!flowState) {
|
||||
clearInterval(intervalId);
|
||||
this.intervals.delete(intervalId);
|
||||
this.logger.error(`[${flowKey}] Flow state not found`);
|
||||
logger.error(`[${flowKey}] Flow state not found`);
|
||||
reject(new Error(`${type} Flow state not found`));
|
||||
return;
|
||||
}
|
||||
|
|
@ -110,7 +99,7 @@ export class FlowStateManager<T = unknown> {
|
|||
if (signal?.aborted) {
|
||||
clearInterval(intervalId);
|
||||
this.intervals.delete(intervalId);
|
||||
this.logger.warn(`[${flowKey}] Flow aborted`);
|
||||
logger.warn(`[${flowKey}] Flow aborted`);
|
||||
const message = `${type} flow aborted`;
|
||||
await this.keyv.delete(flowKey);
|
||||
reject(new Error(message));
|
||||
|
|
@ -120,7 +109,7 @@ export class FlowStateManager<T = unknown> {
|
|||
if (flowState.status !== 'PENDING') {
|
||||
clearInterval(intervalId);
|
||||
this.intervals.delete(intervalId);
|
||||
this.logger.debug(`[${flowKey}] Flow completed`);
|
||||
logger.debug(`[${flowKey}] Flow completed`);
|
||||
|
||||
if (flowState.status === 'COMPLETED' && flowState.result !== undefined) {
|
||||
resolve(flowState.result);
|
||||
|
|
@ -135,17 +124,15 @@ export class FlowStateManager<T = unknown> {
|
|||
if (elapsedTime >= this.ttl) {
|
||||
clearInterval(intervalId);
|
||||
this.intervals.delete(intervalId);
|
||||
this.logger.error(
|
||||
logger.error(
|
||||
`[${flowKey}] Flow timed out | Elapsed time: ${elapsedTime} | TTL: ${this.ttl}`,
|
||||
);
|
||||
await this.keyv.delete(flowKey);
|
||||
reject(new Error(`${type} flow timed out`));
|
||||
}
|
||||
this.logger.debug(
|
||||
`[${flowKey}] Flow state elapsed time: ${elapsedTime}, checking again...`,
|
||||
);
|
||||
logger.debug(`[${flowKey}] Flow state elapsed time: ${elapsedTime}, checking again...`);
|
||||
} catch (error) {
|
||||
this.logger.error(`[${flowKey}] Error checking flow state:`, error);
|
||||
logger.error(`[${flowKey}] Error checking flow state:`, error);
|
||||
clearInterval(intervalId);
|
||||
this.intervals.delete(intervalId);
|
||||
reject(error);
|
||||
|
|
@ -224,7 +211,7 @@ export class FlowStateManager<T = unknown> {
|
|||
const flowKey = this.getFlowKey(flowId, type);
|
||||
let existingState = (await this.keyv.get(flowKey)) as FlowState<T> | undefined;
|
||||
if (existingState) {
|
||||
this.logger.debug(`[${flowKey}] Flow already exists`);
|
||||
logger.debug(`[${flowKey}] Flow already exists`);
|
||||
return this.monitorFlow(flowKey, type, signal);
|
||||
}
|
||||
|
||||
|
|
@ -232,7 +219,7 @@ export class FlowStateManager<T = unknown> {
|
|||
|
||||
existingState = (await this.keyv.get(flowKey)) as FlowState<T> | undefined;
|
||||
if (existingState) {
|
||||
this.logger.debug(`[${flowKey}] Flow exists on 2nd check`);
|
||||
logger.debug(`[${flowKey}] Flow exists on 2nd check`);
|
||||
return this.monitorFlow(flowKey, type, signal);
|
||||
}
|
||||
|
||||
|
|
@ -242,7 +229,7 @@ export class FlowStateManager<T = unknown> {
|
|||
metadata: {},
|
||||
createdAt: Date.now(),
|
||||
};
|
||||
this.logger.debug(`[${flowKey}] Creating initial flow state`);
|
||||
logger.debug(`[${flowKey}] Creating initial flow state`);
|
||||
await this.keyv.set(flowKey, initialState, this.ttl);
|
||||
|
||||
try {
|
||||
22
packages/api/src/index.ts
Normal file
22
packages/api/src/index.ts
Normal file
|
|
@ -0,0 +1,22 @@
|
|||
/* MCP */
|
||||
export * from './mcp/manager';
|
||||
export * from './mcp/oauth';
|
||||
export * from './mcp/auth';
|
||||
/* Utilities */
|
||||
export * from './mcp/utils';
|
||||
export * from './utils';
|
||||
/* OAuth */
|
||||
export * from './oauth';
|
||||
/* Crypto */
|
||||
export * from './crypto';
|
||||
/* Flow */
|
||||
export * from './flow/manager';
|
||||
/* Agents */
|
||||
export * from './agents';
|
||||
/* Endpoints */
|
||||
export * from './endpoints';
|
||||
/* Files */
|
||||
export * from './files';
|
||||
/* types */
|
||||
export type * from './mcp/types';
|
||||
export type * from './flow/types';
|
||||
58
packages/api/src/mcp/auth.ts
Normal file
58
packages/api/src/mcp/auth.ts
Normal file
|
|
@ -0,0 +1,58 @@
|
|||
import { logger } from '@librechat/data-schemas';
|
||||
import { Constants } from 'librechat-data-provider';
|
||||
import type { PluginAuthMethods } from '@librechat/data-schemas';
|
||||
import type { GenericTool } from '@librechat/agents';
|
||||
import { getPluginAuthMap } from '~/agents/auth';
|
||||
import { mcpToolPattern } from './utils';
|
||||
|
||||
export async function getUserMCPAuthMap({
|
||||
userId,
|
||||
tools,
|
||||
appTools,
|
||||
findPluginAuthsByKeys,
|
||||
}: {
|
||||
userId: string;
|
||||
tools: GenericTool[] | undefined;
|
||||
appTools: Record<string, unknown>;
|
||||
findPluginAuthsByKeys: PluginAuthMethods['findPluginAuthsByKeys'];
|
||||
}) {
|
||||
if (!tools || tools.length === 0) {
|
||||
return {};
|
||||
}
|
||||
|
||||
const uniqueMcpServers = new Set<string>();
|
||||
|
||||
for (const tool of tools) {
|
||||
const toolKey = tool.name;
|
||||
if (toolKey && appTools[toolKey] && mcpToolPattern.test(toolKey)) {
|
||||
const parts = toolKey.split(Constants.mcp_delimiter);
|
||||
const serverName = parts[parts.length - 1];
|
||||
uniqueMcpServers.add(`${Constants.mcp_prefix}${serverName}`);
|
||||
}
|
||||
}
|
||||
|
||||
if (uniqueMcpServers.size === 0) {
|
||||
return {};
|
||||
}
|
||||
|
||||
const mcpPluginKeysToFetch = Array.from(uniqueMcpServers);
|
||||
|
||||
let allMcpCustomUserVars: Record<string, Record<string, string>> = {};
|
||||
try {
|
||||
allMcpCustomUserVars = await getPluginAuthMap({
|
||||
userId,
|
||||
pluginKeys: mcpPluginKeysToFetch,
|
||||
throwError: false,
|
||||
findPluginAuthsByKeys,
|
||||
});
|
||||
} catch (err) {
|
||||
logger.error(
|
||||
`[handleTools] Error batch fetching customUserVars for MCP tools (keys: ${mcpPluginKeysToFetch.join(
|
||||
', ',
|
||||
)}), user ${userId}: ${err instanceof Error ? err.message : 'Unknown error'}`,
|
||||
err,
|
||||
);
|
||||
}
|
||||
|
||||
return allMcpCustomUserVars;
|
||||
}
|
||||
|
|
@ -1,4 +1,5 @@
|
|||
import { EventEmitter } from 'events';
|
||||
import { logger } from '@librechat/data-schemas';
|
||||
import { Client } from '@modelcontextprotocol/sdk/client/index.js';
|
||||
import { SSEClientTransport } from '@modelcontextprotocol/sdk/client/sse.js';
|
||||
import {
|
||||
|
|
@ -10,8 +11,8 @@ import { ResourceListChangedNotificationSchema } from '@modelcontextprotocol/sdk
|
|||
import { StreamableHTTPClientTransport } from '@modelcontextprotocol/sdk/client/streamableHttp.js';
|
||||
import type { Transport } from '@modelcontextprotocol/sdk/shared/transport.js';
|
||||
import type { JSONRPCMessage } from '@modelcontextprotocol/sdk/types.js';
|
||||
import type { Logger } from 'winston';
|
||||
import type * as t from './types/mcp.js';
|
||||
import type { MCPOAuthTokens } from './oauth/types';
|
||||
import type * as t from './types';
|
||||
|
||||
function isStdioOptions(options: t.MCPOptions): options is t.StdioOptions {
|
||||
return 'command' in options;
|
||||
|
|
@ -67,28 +68,33 @@ export class MCPConnection extends EventEmitter {
|
|||
private isReconnecting = false;
|
||||
private isInitializing = false;
|
||||
private reconnectAttempts = 0;
|
||||
iconPath?: string;
|
||||
timeout?: number;
|
||||
private readonly userId?: string;
|
||||
private lastPingTime: number;
|
||||
private oauthTokens?: MCPOAuthTokens | null;
|
||||
private oauthRequired = false;
|
||||
iconPath?: string;
|
||||
timeout?: number;
|
||||
url?: string;
|
||||
|
||||
constructor(
|
||||
serverName: string,
|
||||
private readonly options: t.MCPOptions,
|
||||
private logger?: Logger,
|
||||
userId?: string,
|
||||
oauthTokens?: MCPOAuthTokens | null,
|
||||
) {
|
||||
super();
|
||||
this.serverName = serverName;
|
||||
this.logger = logger;
|
||||
this.userId = userId;
|
||||
this.iconPath = options.iconPath;
|
||||
this.timeout = options.timeout;
|
||||
this.lastPingTime = Date.now();
|
||||
if (oauthTokens) {
|
||||
this.oauthTokens = oauthTokens;
|
||||
}
|
||||
this.client = new Client(
|
||||
{
|
||||
name: 'librechat-mcp-client',
|
||||
version: '1.2.2',
|
||||
name: '@librechat/api-client',
|
||||
version: '1.2.3',
|
||||
},
|
||||
{
|
||||
capabilities: {},
|
||||
|
|
@ -107,11 +113,10 @@ export class MCPConnection extends EventEmitter {
|
|||
public static getInstance(
|
||||
serverName: string,
|
||||
options: t.MCPOptions,
|
||||
logger?: Logger,
|
||||
userId?: string,
|
||||
): MCPConnection {
|
||||
if (!MCPConnection.instance) {
|
||||
MCPConnection.instance = new MCPConnection(serverName, options, logger, userId);
|
||||
MCPConnection.instance = new MCPConnection(serverName, options, userId);
|
||||
}
|
||||
return MCPConnection.instance;
|
||||
}
|
||||
|
|
@ -129,7 +134,7 @@ export class MCPConnection extends EventEmitter {
|
|||
|
||||
private emitError(error: unknown, errorContext: string): void {
|
||||
const errorMessage = error instanceof Error ? error.message : String(error);
|
||||
this.logger?.error(`${this.getLogPrefix()} ${errorContext}: ${errorMessage}`);
|
||||
logger.error(`${this.getLogPrefix()} ${errorContext}: ${errorMessage}`);
|
||||
this.emit('error', new Error(`${errorContext}: ${errorMessage}`));
|
||||
}
|
||||
|
||||
|
|
@ -167,45 +172,52 @@ export class MCPConnection extends EventEmitter {
|
|||
if (!isWebSocketOptions(options)) {
|
||||
throw new Error('Invalid options for websocket transport.');
|
||||
}
|
||||
this.url = options.url;
|
||||
return new WebSocketClientTransport(new URL(options.url));
|
||||
|
||||
case 'sse': {
|
||||
if (!isSSEOptions(options)) {
|
||||
throw new Error('Invalid options for sse transport.');
|
||||
}
|
||||
this.url = options.url;
|
||||
const url = new URL(options.url);
|
||||
this.logger?.info(`${this.getLogPrefix()} Creating SSE transport: ${url.toString()}`);
|
||||
logger.info(`${this.getLogPrefix()} Creating SSE transport: ${url.toString()}`);
|
||||
const abortController = new AbortController();
|
||||
|
||||
/** Add OAuth token to headers if available */
|
||||
const headers = { ...options.headers };
|
||||
if (this.oauthTokens?.access_token) {
|
||||
headers['Authorization'] = `Bearer ${this.oauthTokens.access_token}`;
|
||||
}
|
||||
|
||||
const transport = new SSEClientTransport(url, {
|
||||
requestInit: {
|
||||
headers: options.headers,
|
||||
headers,
|
||||
signal: abortController.signal,
|
||||
},
|
||||
eventSourceInit: {
|
||||
fetch: (url, init) => {
|
||||
const headers = new Headers(Object.assign({}, init?.headers, options.headers));
|
||||
const fetchHeaders = new Headers(Object.assign({}, init?.headers, headers));
|
||||
return fetch(url, {
|
||||
...init,
|
||||
headers,
|
||||
headers: fetchHeaders,
|
||||
});
|
||||
},
|
||||
},
|
||||
});
|
||||
|
||||
transport.onclose = () => {
|
||||
this.logger?.info(`${this.getLogPrefix()} SSE transport closed`);
|
||||
logger.info(`${this.getLogPrefix()} SSE transport closed`);
|
||||
this.emit('connectionChange', 'disconnected');
|
||||
};
|
||||
|
||||
transport.onerror = (error) => {
|
||||
this.logger?.error(`${this.getLogPrefix()} SSE transport error:`, error);
|
||||
logger.error(`${this.getLogPrefix()} SSE transport error:`, error);
|
||||
this.emitError(error, 'SSE transport error:');
|
||||
};
|
||||
|
||||
transport.onmessage = (message) => {
|
||||
this.logger?.info(
|
||||
`${this.getLogPrefix()} Message received: ${JSON.stringify(message)}`,
|
||||
);
|
||||
logger.info(`${this.getLogPrefix()} Message received: ${JSON.stringify(message)}`);
|
||||
};
|
||||
|
||||
this.setupTransportErrorHandlers(transport);
|
||||
|
|
@ -216,33 +228,38 @@ export class MCPConnection extends EventEmitter {
|
|||
if (!isStreamableHTTPOptions(options)) {
|
||||
throw new Error('Invalid options for streamable-http transport.');
|
||||
}
|
||||
this.url = options.url;
|
||||
const url = new URL(options.url);
|
||||
this.logger?.info(
|
||||
logger.info(
|
||||
`${this.getLogPrefix()} Creating streamable-http transport: ${url.toString()}`,
|
||||
);
|
||||
const abortController = new AbortController();
|
||||
|
||||
// Add OAuth token to headers if available
|
||||
const headers = { ...options.headers };
|
||||
if (this.oauthTokens?.access_token) {
|
||||
headers['Authorization'] = `Bearer ${this.oauthTokens.access_token}`;
|
||||
}
|
||||
|
||||
const transport = new StreamableHTTPClientTransport(url, {
|
||||
requestInit: {
|
||||
headers: options.headers,
|
||||
headers,
|
||||
signal: abortController.signal,
|
||||
},
|
||||
});
|
||||
|
||||
transport.onclose = () => {
|
||||
this.logger?.info(`${this.getLogPrefix()} Streamable-http transport closed`);
|
||||
logger.info(`${this.getLogPrefix()} Streamable-http transport closed`);
|
||||
this.emit('connectionChange', 'disconnected');
|
||||
};
|
||||
|
||||
transport.onerror = (error: Error | unknown) => {
|
||||
this.logger?.error(`${this.getLogPrefix()} Streamable-http transport error:`, error);
|
||||
logger.error(`${this.getLogPrefix()} Streamable-http transport error:`, error);
|
||||
this.emitError(error, 'Streamable-http transport error:');
|
||||
};
|
||||
|
||||
transport.onmessage = (message: JSONRPCMessage) => {
|
||||
this.logger?.info(
|
||||
`${this.getLogPrefix()} Message received: ${JSON.stringify(message)}`,
|
||||
);
|
||||
logger.info(`${this.getLogPrefix()} Message received: ${JSON.stringify(message)}`);
|
||||
};
|
||||
|
||||
this.setupTransportErrorHandlers(transport);
|
||||
|
|
@ -271,17 +288,17 @@ export class MCPConnection extends EventEmitter {
|
|||
/**
|
||||
* // FOR DEBUGGING
|
||||
* // this.client.setRequestHandler(PingRequestSchema, async (request, extra) => {
|
||||
* // this.logger?.info(`[MCP][${this.serverName}] PingRequest: ${JSON.stringify(request)}`);
|
||||
* // logger.info(`[MCP][${this.serverName}] PingRequest: ${JSON.stringify(request)}`);
|
||||
* // if (getEventListeners && extra.signal) {
|
||||
* // const listenerCount = getEventListeners(extra.signal, 'abort').length;
|
||||
* // this.logger?.debug(`Signal has ${listenerCount} abort listeners`);
|
||||
* // logger.debug(`Signal has ${listenerCount} abort listeners`);
|
||||
* // }
|
||||
* // return {};
|
||||
* // });
|
||||
*/
|
||||
} else if (state === 'error' && !this.isReconnecting && !this.isInitializing) {
|
||||
this.handleReconnection().catch((error) => {
|
||||
this.logger?.error(`${this.getLogPrefix()} Reconnection handler failed:`, error);
|
||||
logger.error(`${this.getLogPrefix()} Reconnection handler failed:`, error);
|
||||
});
|
||||
}
|
||||
});
|
||||
|
|
@ -290,7 +307,15 @@ export class MCPConnection extends EventEmitter {
|
|||
}
|
||||
|
||||
private async handleReconnection(): Promise<void> {
|
||||
if (this.isReconnecting || this.shouldStopReconnecting || this.isInitializing) {
|
||||
if (
|
||||
this.isReconnecting ||
|
||||
this.shouldStopReconnecting ||
|
||||
this.isInitializing ||
|
||||
this.oauthRequired
|
||||
) {
|
||||
if (this.oauthRequired) {
|
||||
logger.info(`${this.getLogPrefix()} OAuth required, skipping reconnection attempts`);
|
||||
}
|
||||
return;
|
||||
}
|
||||
|
||||
|
|
@ -305,7 +330,7 @@ export class MCPConnection extends EventEmitter {
|
|||
this.reconnectAttempts++;
|
||||
const delay = backoffDelay(this.reconnectAttempts);
|
||||
|
||||
this.logger?.info(
|
||||
logger.info(
|
||||
`${this.getLogPrefix()} Reconnecting ${this.reconnectAttempts}/${this.MAX_RECONNECT_ATTEMPTS} (delay: ${delay}ms)`,
|
||||
);
|
||||
|
||||
|
|
@ -316,13 +341,13 @@ export class MCPConnection extends EventEmitter {
|
|||
this.reconnectAttempts = 0;
|
||||
return;
|
||||
} catch (error) {
|
||||
this.logger?.error(`${this.getLogPrefix()} Reconnection attempt failed:`, error);
|
||||
logger.error(`${this.getLogPrefix()} Reconnection attempt failed:`, error);
|
||||
|
||||
if (
|
||||
this.reconnectAttempts === this.MAX_RECONNECT_ATTEMPTS ||
|
||||
(this.shouldStopReconnecting as boolean)
|
||||
) {
|
||||
this.logger?.error(`${this.getLogPrefix()} Stopping reconnection attempts`);
|
||||
logger.error(`${this.getLogPrefix()} Stopping reconnection attempts`);
|
||||
return;
|
||||
}
|
||||
}
|
||||
|
|
@ -366,18 +391,21 @@ export class MCPConnection extends EventEmitter {
|
|||
await this.client.close();
|
||||
this.transport = null;
|
||||
} catch (error) {
|
||||
this.logger?.warn(`${this.getLogPrefix()} Error closing connection:`, error);
|
||||
logger.warn(`${this.getLogPrefix()} Error closing connection:`, error);
|
||||
}
|
||||
}
|
||||
|
||||
this.transport = this.constructTransport(this.options);
|
||||
this.setupTransportDebugHandlers();
|
||||
|
||||
const connectTimeout = this.options.initTimeout ?? 10000;
|
||||
const connectTimeout = this.options.initTimeout ?? 120000;
|
||||
await Promise.race([
|
||||
this.client.connect(this.transport),
|
||||
new Promise((_resolve, reject) =>
|
||||
setTimeout(() => reject(new Error('Connection timeout')), connectTimeout),
|
||||
setTimeout(
|
||||
() => reject(new Error(`Connection timeout after ${connectTimeout}ms`)),
|
||||
connectTimeout,
|
||||
),
|
||||
),
|
||||
]);
|
||||
|
||||
|
|
@ -385,9 +413,85 @@ export class MCPConnection extends EventEmitter {
|
|||
this.emit('connectionChange', 'connected');
|
||||
this.reconnectAttempts = 0;
|
||||
} catch (error) {
|
||||
// Check if it's an OAuth authentication error
|
||||
if (this.isOAuthError(error)) {
|
||||
logger.warn(`${this.getLogPrefix()} OAuth authentication required`);
|
||||
this.oauthRequired = true;
|
||||
const serverUrl = this.url;
|
||||
logger.debug(`${this.getLogPrefix()} Server URL for OAuth: ${serverUrl}`);
|
||||
|
||||
const oauthTimeout = this.options.initTimeout ?? 60000;
|
||||
/** Promise that will resolve when OAuth is handled */
|
||||
const oauthHandledPromise = new Promise<void>((resolve, reject) => {
|
||||
let timeoutId: NodeJS.Timeout | null = null;
|
||||
let oauthHandledListener: (() => void) | null = null;
|
||||
let oauthFailedListener: ((error: Error) => void) | null = null;
|
||||
|
||||
/** Cleanup function to remove listeners and clear timeout */
|
||||
const cleanup = () => {
|
||||
if (timeoutId) {
|
||||
clearTimeout(timeoutId);
|
||||
}
|
||||
if (oauthHandledListener) {
|
||||
this.off('oauthHandled', oauthHandledListener);
|
||||
}
|
||||
if (oauthFailedListener) {
|
||||
this.off('oauthFailed', oauthFailedListener);
|
||||
}
|
||||
};
|
||||
|
||||
// Success handler
|
||||
oauthHandledListener = () => {
|
||||
cleanup();
|
||||
resolve();
|
||||
};
|
||||
|
||||
// Failure handler
|
||||
oauthFailedListener = (error: Error) => {
|
||||
cleanup();
|
||||
reject(error);
|
||||
};
|
||||
|
||||
// Timeout handler
|
||||
timeoutId = setTimeout(() => {
|
||||
cleanup();
|
||||
reject(new Error(`OAuth handling timeout after ${oauthTimeout}ms`));
|
||||
}, oauthTimeout);
|
||||
|
||||
// Listen for both success and failure events
|
||||
this.once('oauthHandled', oauthHandledListener);
|
||||
this.once('oauthFailed', oauthFailedListener);
|
||||
});
|
||||
|
||||
// Emit the event
|
||||
this.emit('oauthRequired', {
|
||||
serverName: this.serverName,
|
||||
error,
|
||||
serverUrl,
|
||||
userId: this.userId,
|
||||
});
|
||||
|
||||
try {
|
||||
// Wait for OAuth to be handled
|
||||
await oauthHandledPromise;
|
||||
// Reset the oauthRequired flag
|
||||
this.oauthRequired = false;
|
||||
// Don't throw the error - just return so connection can be retried
|
||||
logger.info(
|
||||
`${this.getLogPrefix()} OAuth handled successfully, connection will be retried`,
|
||||
);
|
||||
return;
|
||||
} catch (oauthError) {
|
||||
// OAuth failed or timed out
|
||||
this.oauthRequired = false;
|
||||
logger.error(`${this.getLogPrefix()} OAuth handling failed:`, oauthError);
|
||||
// Re-throw the original authentication error
|
||||
throw error;
|
||||
}
|
||||
}
|
||||
|
||||
this.connectionState = 'error';
|
||||
this.emit('connectionChange', 'error');
|
||||
this.lastError = error instanceof Error ? error : new Error(String(error));
|
||||
throw error;
|
||||
} finally {
|
||||
this.connectPromise = null;
|
||||
|
|
@ -403,7 +507,7 @@ export class MCPConnection extends EventEmitter {
|
|||
}
|
||||
|
||||
this.transport.onmessage = (msg) => {
|
||||
this.logger?.debug(`${this.getLogPrefix()} Transport received: ${JSON.stringify(msg)}`);
|
||||
logger.debug(`${this.getLogPrefix()} Transport received: ${JSON.stringify(msg)}`);
|
||||
};
|
||||
|
||||
const originalSend = this.transport.send.bind(this.transport);
|
||||
|
|
@ -414,7 +518,7 @@ export class MCPConnection extends EventEmitter {
|
|||
}
|
||||
this.lastPingTime = Date.now();
|
||||
}
|
||||
this.logger?.debug(`${this.getLogPrefix()} Transport sending: ${JSON.stringify(msg)}`);
|
||||
logger.debug(`${this.getLogPrefix()} Transport sending: ${JSON.stringify(msg)}`);
|
||||
return originalSend(msg);
|
||||
};
|
||||
}
|
||||
|
|
@ -427,14 +531,24 @@ export class MCPConnection extends EventEmitter {
|
|||
throw new Error('Connection not established');
|
||||
}
|
||||
} catch (error) {
|
||||
this.logger?.error(`${this.getLogPrefix()} Connection failed:`, error);
|
||||
logger.error(`${this.getLogPrefix()} Connection failed:`, error);
|
||||
throw error;
|
||||
}
|
||||
}
|
||||
|
||||
private setupTransportErrorHandlers(transport: Transport): void {
|
||||
transport.onerror = (error) => {
|
||||
this.logger?.error(`${this.getLogPrefix()} Transport error:`, error);
|
||||
logger.error(`${this.getLogPrefix()} Transport error:`, error);
|
||||
|
||||
// Check if it's an OAuth authentication error
|
||||
if (error && typeof error === 'object' && 'code' in error) {
|
||||
const errorCode = (error as unknown as { code?: number }).code;
|
||||
if (errorCode === 401 || errorCode === 403) {
|
||||
logger.warn(`${this.getLogPrefix()} OAuth authentication error detected`);
|
||||
this.emit('oauthError', error);
|
||||
}
|
||||
}
|
||||
|
||||
this.emit('connectionChange', 'error');
|
||||
};
|
||||
}
|
||||
|
|
@ -562,22 +676,36 @@ export class MCPConnection extends EventEmitter {
|
|||
// }
|
||||
// }
|
||||
|
||||
// Public getters for state information
|
||||
public getConnectionState(): t.ConnectionState {
|
||||
return this.connectionState;
|
||||
}
|
||||
|
||||
public async isConnected(): Promise<boolean> {
|
||||
try {
|
||||
await this.client.ping();
|
||||
return this.connectionState === 'connected';
|
||||
} catch (error) {
|
||||
this.logger?.error(`${this.getLogPrefix()} Ping failed:`, error);
|
||||
logger.error(`${this.getLogPrefix()} Ping failed:`, error);
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
public getLastError(): Error | null {
|
||||
return this.lastError;
|
||||
public setOAuthTokens(tokens: MCPOAuthTokens): void {
|
||||
this.oauthTokens = tokens;
|
||||
}
|
||||
|
||||
private isOAuthError(error: unknown): boolean {
|
||||
if (!error || typeof error !== 'object') {
|
||||
return false;
|
||||
}
|
||||
|
||||
// Check for SSE error with 401 status
|
||||
if ('message' in error && typeof error.message === 'string') {
|
||||
return error.message.includes('401') || error.message.includes('Non-200 status code (401)');
|
||||
}
|
||||
|
||||
// Check for error code
|
||||
if ('code' in error) {
|
||||
const code = (error as { code?: number }).code;
|
||||
return code === 401 || code === 403;
|
||||
}
|
||||
|
||||
return false;
|
||||
}
|
||||
}
|
||||
9
packages/api/src/mcp/enum.ts
Normal file
9
packages/api/src/mcp/enum.ts
Normal file
|
|
@ -0,0 +1,9 @@
|
|||
export enum CONSTANTS {
|
||||
mcp_delimiter = '_mcp_',
|
||||
/** System user ID for app-level OAuth tokens (all zeros ObjectId) */
|
||||
SYSTEM_USER_ID = '000000000000000000000000',
|
||||
}
|
||||
|
||||
export function isSystemUserId(userId?: string): boolean {
|
||||
return userId === CONSTANTS.SYSTEM_USER_ID;
|
||||
}
|
||||
1108
packages/api/src/mcp/manager.ts
Normal file
1108
packages/api/src/mcp/manager.ts
Normal file
File diff suppressed because it is too large
Load diff
603
packages/api/src/mcp/oauth/handler.ts
Normal file
603
packages/api/src/mcp/oauth/handler.ts
Normal file
|
|
@ -0,0 +1,603 @@
|
|||
import { randomBytes } from 'crypto';
|
||||
import { logger } from '@librechat/data-schemas';
|
||||
import {
|
||||
discoverOAuthMetadata,
|
||||
registerClient,
|
||||
startAuthorization,
|
||||
exchangeAuthorization,
|
||||
discoverOAuthProtectedResourceMetadata,
|
||||
} from '@modelcontextprotocol/sdk/client/auth.js';
|
||||
import { OAuthMetadataSchema } from '@modelcontextprotocol/sdk/shared/auth.js';
|
||||
import type { MCPOptions } from 'librechat-data-provider';
|
||||
import type { FlowStateManager } from '~/flow/manager';
|
||||
import type {
|
||||
OAuthClientInformation,
|
||||
OAuthProtectedResourceMetadata,
|
||||
MCPOAuthFlowMetadata,
|
||||
MCPOAuthTokens,
|
||||
OAuthMetadata,
|
||||
} from './types';
|
||||
|
||||
/** Type for the OAuth metadata from the SDK */
|
||||
type SDKOAuthMetadata = Parameters<typeof registerClient>[1]['metadata'];
|
||||
|
||||
export class MCPOAuthHandler {
|
||||
private static readonly FLOW_TYPE = 'mcp_oauth';
|
||||
private static readonly FLOW_TTL = 10 * 60 * 1000; // 10 minutes
|
||||
|
||||
/**
|
||||
* Discovers OAuth metadata from the server
|
||||
*/
|
||||
private static async discoverMetadata(serverUrl: string): Promise<{
|
||||
metadata: OAuthMetadata;
|
||||
resourceMetadata?: OAuthProtectedResourceMetadata;
|
||||
authServerUrl: URL;
|
||||
}> {
|
||||
logger.debug(`[MCPOAuth] discoverMetadata called with serverUrl: ${serverUrl}`);
|
||||
|
||||
let authServerUrl = new URL(serverUrl);
|
||||
let resourceMetadata: OAuthProtectedResourceMetadata | undefined;
|
||||
|
||||
try {
|
||||
// Try to discover resource metadata first
|
||||
logger.debug(
|
||||
`[MCPOAuth] Attempting to discover protected resource metadata from ${serverUrl}`,
|
||||
);
|
||||
resourceMetadata = await discoverOAuthProtectedResourceMetadata(serverUrl);
|
||||
|
||||
if (resourceMetadata?.authorization_servers?.length) {
|
||||
authServerUrl = new URL(resourceMetadata.authorization_servers[0]);
|
||||
logger.debug(
|
||||
`[MCPOAuth] Found authorization server from resource metadata: ${authServerUrl}`,
|
||||
);
|
||||
} else {
|
||||
logger.debug(`[MCPOAuth] No authorization servers found in resource metadata`);
|
||||
}
|
||||
} catch (error) {
|
||||
logger.debug('[MCPOAuth] Resource metadata discovery failed, continuing with server URL', {
|
||||
error,
|
||||
});
|
||||
}
|
||||
|
||||
// Discover OAuth metadata
|
||||
logger.debug(`[MCPOAuth] Discovering OAuth metadata from ${authServerUrl}`);
|
||||
const rawMetadata = await discoverOAuthMetadata(authServerUrl);
|
||||
|
||||
if (!rawMetadata) {
|
||||
logger.error(`[MCPOAuth] Failed to discover OAuth metadata from ${authServerUrl}`);
|
||||
throw new Error('Failed to discover OAuth metadata');
|
||||
}
|
||||
|
||||
logger.debug(`[MCPOAuth] OAuth metadata discovered successfully`);
|
||||
const metadata = await OAuthMetadataSchema.parseAsync(rawMetadata);
|
||||
|
||||
logger.debug(`[MCPOAuth] OAuth metadata parsed successfully`);
|
||||
return {
|
||||
metadata: metadata as unknown as OAuthMetadata,
|
||||
resourceMetadata,
|
||||
authServerUrl,
|
||||
};
|
||||
}
|
||||
|
||||
/**
|
||||
* Registers an OAuth client dynamically
|
||||
*/
|
||||
private static async registerOAuthClient(
|
||||
serverUrl: string,
|
||||
metadata: OAuthMetadata,
|
||||
resourceMetadata?: OAuthProtectedResourceMetadata,
|
||||
redirectUri?: string,
|
||||
): Promise<OAuthClientInformation> {
|
||||
logger.debug(`[MCPOAuth] Starting client registration for ${serverUrl}, server metadata:`, {
|
||||
grant_types_supported: metadata.grant_types_supported,
|
||||
response_types_supported: metadata.response_types_supported,
|
||||
token_endpoint_auth_methods_supported: metadata.token_endpoint_auth_methods_supported,
|
||||
scopes_supported: metadata.scopes_supported,
|
||||
});
|
||||
|
||||
/** Client metadata based on what the server supports */
|
||||
const clientMetadata = {
|
||||
client_name: 'LibreChat MCP Client',
|
||||
redirect_uris: [redirectUri || this.getDefaultRedirectUri()],
|
||||
grant_types: ['authorization_code'] as string[],
|
||||
response_types: ['code'] as string[],
|
||||
token_endpoint_auth_method: 'client_secret_basic',
|
||||
scope: undefined as string | undefined,
|
||||
};
|
||||
|
||||
const supportedGrantTypes = metadata.grant_types_supported || ['authorization_code'];
|
||||
const requestedGrantTypes = ['authorization_code'];
|
||||
|
||||
if (supportedGrantTypes.includes('refresh_token')) {
|
||||
requestedGrantTypes.push('refresh_token');
|
||||
logger.debug(
|
||||
`[MCPOAuth] Server ${serverUrl} supports \`refresh_token\` grant type, adding to request`,
|
||||
);
|
||||
} else {
|
||||
logger.debug(`[MCPOAuth] Server ${serverUrl} does not support \`refresh_token\` grant type`);
|
||||
}
|
||||
clientMetadata.grant_types = requestedGrantTypes;
|
||||
|
||||
clientMetadata.response_types = metadata.response_types_supported || ['code'];
|
||||
|
||||
if (metadata.token_endpoint_auth_methods_supported) {
|
||||
// Prefer client_secret_basic if supported, otherwise use the first supported method
|
||||
if (metadata.token_endpoint_auth_methods_supported.includes('client_secret_basic')) {
|
||||
clientMetadata.token_endpoint_auth_method = 'client_secret_basic';
|
||||
} else if (metadata.token_endpoint_auth_methods_supported.includes('client_secret_post')) {
|
||||
clientMetadata.token_endpoint_auth_method = 'client_secret_post';
|
||||
} else if (metadata.token_endpoint_auth_methods_supported.includes('none')) {
|
||||
clientMetadata.token_endpoint_auth_method = 'none';
|
||||
} else {
|
||||
clientMetadata.token_endpoint_auth_method =
|
||||
metadata.token_endpoint_auth_methods_supported[0];
|
||||
}
|
||||
}
|
||||
|
||||
const availableScopes = resourceMetadata?.scopes_supported || metadata.scopes_supported;
|
||||
if (availableScopes) {
|
||||
clientMetadata.scope = availableScopes.join(' ');
|
||||
}
|
||||
|
||||
logger.debug(`[MCPOAuth] Registering client for ${serverUrl} with metadata:`, clientMetadata);
|
||||
|
||||
const clientInfo = await registerClient(serverUrl, {
|
||||
metadata: metadata as unknown as SDKOAuthMetadata,
|
||||
clientMetadata,
|
||||
});
|
||||
|
||||
logger.debug(`[MCPOAuth] Client registered successfully for ${serverUrl}:`, {
|
||||
client_id: clientInfo.client_id,
|
||||
has_client_secret: !!clientInfo.client_secret,
|
||||
grant_types: clientInfo.grant_types,
|
||||
scope: clientInfo.scope,
|
||||
});
|
||||
|
||||
return clientInfo;
|
||||
}
|
||||
|
||||
/**
|
||||
* Initiates the OAuth flow for an MCP server
|
||||
*/
|
||||
static async initiateOAuthFlow(
|
||||
serverName: string,
|
||||
serverUrl: string,
|
||||
userId: string,
|
||||
config: MCPOptions['oauth'] | undefined,
|
||||
): Promise<{ authorizationUrl: string; flowId: string; flowMetadata: MCPOAuthFlowMetadata }> {
|
||||
logger.debug(`[MCPOAuth] initiateOAuthFlow called for ${serverName} with URL: ${serverUrl}`);
|
||||
|
||||
const flowId = this.generateFlowId(userId, serverName);
|
||||
const state = this.generateState();
|
||||
|
||||
logger.debug(`[MCPOAuth] Generated flowId: ${flowId}, state: ${state}`);
|
||||
|
||||
try {
|
||||
// Check if we have pre-configured OAuth settings
|
||||
if (config?.authorization_url && config?.token_url && config?.client_id) {
|
||||
logger.debug(`[MCPOAuth] Using pre-configured OAuth settings for ${serverName}`);
|
||||
/** Metadata based on pre-configured settings */
|
||||
const metadata: OAuthMetadata = {
|
||||
authorization_endpoint: config.authorization_url,
|
||||
token_endpoint: config.token_url,
|
||||
issuer: serverUrl,
|
||||
scopes_supported: config.scope?.split(' '),
|
||||
};
|
||||
|
||||
const clientInfo: OAuthClientInformation = {
|
||||
client_id: config.client_id,
|
||||
client_secret: config.client_secret,
|
||||
redirect_uris: [config.redirect_uri || this.getDefaultRedirectUri(serverName)],
|
||||
scope: config.scope,
|
||||
};
|
||||
|
||||
logger.debug(`[MCPOAuth] Starting authorization with pre-configured settings`);
|
||||
const { authorizationUrl, codeVerifier } = await startAuthorization(serverUrl, {
|
||||
metadata: metadata as unknown as SDKOAuthMetadata,
|
||||
clientInformation: clientInfo,
|
||||
redirectUrl: clientInfo.redirect_uris?.[0] || this.getDefaultRedirectUri(serverName),
|
||||
scope: config.scope,
|
||||
});
|
||||
|
||||
/** Add state parameter with flowId to the authorization URL */
|
||||
authorizationUrl.searchParams.set('state', flowId);
|
||||
logger.debug(`[MCPOAuth] Added state parameter to authorization URL`);
|
||||
|
||||
const flowMetadata: MCPOAuthFlowMetadata = {
|
||||
serverName,
|
||||
userId,
|
||||
serverUrl,
|
||||
state,
|
||||
codeVerifier,
|
||||
clientInfo,
|
||||
metadata,
|
||||
};
|
||||
|
||||
logger.debug(`[MCPOAuth] Authorization URL generated: ${authorizationUrl.toString()}`);
|
||||
return {
|
||||
authorizationUrl: authorizationUrl.toString(),
|
||||
flowId,
|
||||
flowMetadata,
|
||||
};
|
||||
}
|
||||
|
||||
logger.debug(`[MCPOAuth] Starting auto-discovery of OAuth metadata from ${serverUrl}`);
|
||||
const { metadata, resourceMetadata, authServerUrl } = await this.discoverMetadata(serverUrl);
|
||||
|
||||
logger.debug(`[MCPOAuth] OAuth metadata discovered, auth server URL: ${authServerUrl}`);
|
||||
|
||||
/** Dynamic client registration based on the discovered metadata */
|
||||
const redirectUri = config?.redirect_uri || this.getDefaultRedirectUri(serverName);
|
||||
logger.debug(`[MCPOAuth] Registering OAuth client with redirect URI: ${redirectUri}`);
|
||||
|
||||
const clientInfo = await this.registerOAuthClient(
|
||||
authServerUrl.toString(),
|
||||
metadata,
|
||||
resourceMetadata,
|
||||
redirectUri,
|
||||
);
|
||||
|
||||
logger.debug(`[MCPOAuth] Client registered with ID: ${clientInfo.client_id}`);
|
||||
|
||||
/** Authorization Scope */
|
||||
const scope =
|
||||
config?.scope ||
|
||||
resourceMetadata?.scopes_supported?.join(' ') ||
|
||||
metadata.scopes_supported?.join(' ');
|
||||
|
||||
logger.debug(`[MCPOAuth] Starting authorization with scope: ${scope}`);
|
||||
|
||||
let authorizationUrl: URL;
|
||||
let codeVerifier: string;
|
||||
|
||||
try {
|
||||
logger.debug(`[MCPOAuth] Calling startAuthorization...`);
|
||||
const authResult = await startAuthorization(serverUrl, {
|
||||
metadata: metadata as unknown as SDKOAuthMetadata,
|
||||
clientInformation: clientInfo,
|
||||
redirectUrl: redirectUri,
|
||||
scope,
|
||||
});
|
||||
|
||||
authorizationUrl = authResult.authorizationUrl;
|
||||
codeVerifier = authResult.codeVerifier;
|
||||
|
||||
logger.debug(`[MCPOAuth] startAuthorization completed successfully`);
|
||||
logger.debug(`[MCPOAuth] Authorization URL: ${authorizationUrl.toString()}`);
|
||||
|
||||
/** Add state parameter with flowId to the authorization URL */
|
||||
authorizationUrl.searchParams.set('state', flowId);
|
||||
logger.debug(`[MCPOAuth] Added state parameter to authorization URL`);
|
||||
} catch (error) {
|
||||
logger.error(`[MCPOAuth] startAuthorization failed:`, error);
|
||||
throw error;
|
||||
}
|
||||
|
||||
const flowMetadata: MCPOAuthFlowMetadata = {
|
||||
serverName,
|
||||
userId,
|
||||
serverUrl,
|
||||
state,
|
||||
codeVerifier,
|
||||
clientInfo,
|
||||
metadata,
|
||||
resourceMetadata,
|
||||
};
|
||||
|
||||
logger.debug(
|
||||
`[MCPOAuth] Authorization URL generated for ${serverName}: ${authorizationUrl.toString()}`,
|
||||
);
|
||||
|
||||
const result = {
|
||||
authorizationUrl: authorizationUrl.toString(),
|
||||
flowId,
|
||||
flowMetadata,
|
||||
};
|
||||
|
||||
logger.debug(
|
||||
`[MCPOAuth] Returning from initiateOAuthFlow with result ${flowId} for ${serverName}`,
|
||||
result,
|
||||
);
|
||||
return result;
|
||||
} catch (error) {
|
||||
logger.error('[MCPOAuth] Failed to initiate OAuth flow', { error, serverName, userId });
|
||||
throw error;
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Completes the OAuth flow by exchanging the authorization code for tokens
|
||||
*/
|
||||
static async completeOAuthFlow(
|
||||
flowId: string,
|
||||
authorizationCode: string,
|
||||
flowManager: FlowStateManager<MCPOAuthTokens>,
|
||||
): Promise<MCPOAuthTokens> {
|
||||
try {
|
||||
/** Flow state which contains our metadata */
|
||||
const flowState = await flowManager.getFlowState(flowId, this.FLOW_TYPE);
|
||||
if (!flowState) {
|
||||
throw new Error('OAuth flow not found');
|
||||
}
|
||||
|
||||
const flowMetadata = flowState.metadata as MCPOAuthFlowMetadata;
|
||||
if (!flowMetadata) {
|
||||
throw new Error('OAuth flow metadata not found');
|
||||
}
|
||||
|
||||
const metadata = flowMetadata;
|
||||
if (!metadata.metadata || !metadata.clientInfo || !metadata.codeVerifier) {
|
||||
throw new Error('Invalid flow metadata');
|
||||
}
|
||||
|
||||
const tokens = await exchangeAuthorization(metadata.serverUrl, {
|
||||
metadata: metadata.metadata as unknown as SDKOAuthMetadata,
|
||||
clientInformation: metadata.clientInfo,
|
||||
authorizationCode,
|
||||
codeVerifier: metadata.codeVerifier,
|
||||
redirectUri: metadata.clientInfo.redirect_uris?.[0] || this.getDefaultRedirectUri(),
|
||||
});
|
||||
|
||||
logger.debug('[MCPOAuth] Raw tokens from exchange:', {
|
||||
access_token: tokens.access_token ? '[REDACTED]' : undefined,
|
||||
refresh_token: tokens.refresh_token ? '[REDACTED]' : undefined,
|
||||
expires_in: tokens.expires_in,
|
||||
token_type: tokens.token_type,
|
||||
scope: tokens.scope,
|
||||
});
|
||||
|
||||
const mcpTokens: MCPOAuthTokens = {
|
||||
...tokens,
|
||||
obtained_at: Date.now(),
|
||||
expires_at: tokens.expires_in ? Date.now() + tokens.expires_in * 1000 : undefined,
|
||||
};
|
||||
|
||||
/** Now complete the flow with the tokens */
|
||||
await flowManager.completeFlow(flowId, this.FLOW_TYPE, mcpTokens);
|
||||
|
||||
return mcpTokens;
|
||||
} catch (error) {
|
||||
logger.error('[MCPOAuth] Failed to complete OAuth flow', { error, flowId });
|
||||
await flowManager.failFlow(flowId, this.FLOW_TYPE, error as Error);
|
||||
throw error;
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Gets the OAuth flow metadata
|
||||
*/
|
||||
static async getFlowState(
|
||||
flowId: string,
|
||||
flowManager: FlowStateManager<MCPOAuthTokens>,
|
||||
): Promise<MCPOAuthFlowMetadata | null> {
|
||||
const flowState = await flowManager.getFlowState(flowId, this.FLOW_TYPE);
|
||||
if (!flowState) {
|
||||
return null;
|
||||
}
|
||||
return flowState.metadata as MCPOAuthFlowMetadata;
|
||||
}
|
||||
|
||||
/**
|
||||
* Generates a flow ID for the OAuth flow
|
||||
* @returns Consistent ID so concurrent requests share the same flow
|
||||
*/
|
||||
public static generateFlowId(userId: string, serverName: string): string {
|
||||
return `${userId}:${serverName}`;
|
||||
}
|
||||
|
||||
/**
|
||||
* Generates a secure state parameter
|
||||
*/
|
||||
private static generateState(): string {
|
||||
return randomBytes(32).toString('base64url');
|
||||
}
|
||||
|
||||
/**
|
||||
* Gets the default redirect URI for a server
|
||||
*/
|
||||
private static getDefaultRedirectUri(serverName?: string): string {
|
||||
const baseUrl = process.env.DOMAIN_SERVER || 'http://localhost:3080';
|
||||
return serverName
|
||||
? `${baseUrl}/api/mcp/${serverName}/oauth/callback`
|
||||
: `${baseUrl}/api/mcp/oauth/callback`;
|
||||
}
|
||||
|
||||
/**
|
||||
* Refreshes OAuth tokens using a refresh token
|
||||
*/
|
||||
static async refreshOAuthTokens(
|
||||
refreshToken: string,
|
||||
metadata: { serverName: string; serverUrl?: string; clientInfo?: OAuthClientInformation },
|
||||
config?: MCPOptions['oauth'],
|
||||
): Promise<MCPOAuthTokens> {
|
||||
logger.debug(`[MCPOAuth] Refreshing tokens for ${metadata.serverName}`);
|
||||
|
||||
try {
|
||||
/** If we have stored client information from the original flow, use that first */
|
||||
if (metadata.clientInfo?.client_id) {
|
||||
logger.debug(
|
||||
`[MCPOAuth] Using stored client information for token refresh for ${metadata.serverName}`,
|
||||
);
|
||||
logger.debug(
|
||||
`[MCPOAuth] Client ID: ${metadata.clientInfo.client_id} for ${metadata.serverName}`,
|
||||
);
|
||||
logger.debug(
|
||||
`[MCPOAuth] Has client secret: ${!!metadata.clientInfo.client_secret} for ${metadata.serverName}`,
|
||||
);
|
||||
logger.debug(`[MCPOAuth] Stored client info for ${metadata.serverName}:`, {
|
||||
client_id: metadata.clientInfo.client_id,
|
||||
has_client_secret: !!metadata.clientInfo.client_secret,
|
||||
grant_types: metadata.clientInfo.grant_types,
|
||||
scope: metadata.clientInfo.scope,
|
||||
});
|
||||
|
||||
/** Use the stored client information and metadata to determine the token URL */
|
||||
let tokenUrl: string;
|
||||
if (config?.token_url) {
|
||||
tokenUrl = config.token_url;
|
||||
} else if (!metadata.serverUrl) {
|
||||
throw new Error('No token URL available for refresh');
|
||||
} else {
|
||||
/** Auto-discover OAuth configuration for refresh */
|
||||
const { metadata: oauthMetadata } = await this.discoverMetadata(metadata.serverUrl);
|
||||
if (!oauthMetadata.token_endpoint) {
|
||||
throw new Error('No token endpoint found in OAuth metadata');
|
||||
}
|
||||
tokenUrl = oauthMetadata.token_endpoint;
|
||||
}
|
||||
|
||||
const body = new URLSearchParams({
|
||||
grant_type: 'refresh_token',
|
||||
refresh_token: refreshToken,
|
||||
});
|
||||
|
||||
/** Add scope if available */
|
||||
if (metadata.clientInfo.scope) {
|
||||
body.append('scope', metadata.clientInfo.scope);
|
||||
}
|
||||
|
||||
const headers: HeadersInit = {
|
||||
'Content-Type': 'application/x-www-form-urlencoded',
|
||||
Accept: 'application/json',
|
||||
};
|
||||
|
||||
/** Use client_secret for authentication if available */
|
||||
if (metadata.clientInfo.client_secret) {
|
||||
const clientAuth = Buffer.from(
|
||||
`${metadata.clientInfo.client_id}:${metadata.clientInfo.client_secret}`,
|
||||
).toString('base64');
|
||||
headers['Authorization'] = `Basic ${clientAuth}`;
|
||||
} else {
|
||||
/** For public clients, client_id must be in the body */
|
||||
body.append('client_id', metadata.clientInfo.client_id);
|
||||
}
|
||||
|
||||
logger.debug(`[MCPOAuth] Refresh request to: ${tokenUrl}`, {
|
||||
body: body.toString(),
|
||||
headers,
|
||||
});
|
||||
|
||||
const response = await fetch(tokenUrl, {
|
||||
method: 'POST',
|
||||
headers,
|
||||
body,
|
||||
});
|
||||
|
||||
if (!response.ok) {
|
||||
const errorText = await response.text();
|
||||
throw new Error(
|
||||
`Token refresh failed: ${response.status} ${response.statusText} - ${errorText}`,
|
||||
);
|
||||
}
|
||||
|
||||
const tokens = await response.json();
|
||||
|
||||
return {
|
||||
...tokens,
|
||||
obtained_at: Date.now(),
|
||||
expires_at: tokens.expires_in ? Date.now() + tokens.expires_in * 1000 : undefined,
|
||||
};
|
||||
}
|
||||
|
||||
// Fallback: If we have pre-configured OAuth settings, use them
|
||||
if (config?.token_url && config?.client_id) {
|
||||
logger.debug(`[MCPOAuth] Using pre-configured OAuth settings for token refresh`);
|
||||
|
||||
const tokenUrl = new URL(config.token_url);
|
||||
const clientAuth = config.client_secret
|
||||
? Buffer.from(`${config.client_id}:${config.client_secret}`).toString('base64')
|
||||
: null;
|
||||
|
||||
const body = new URLSearchParams({
|
||||
grant_type: 'refresh_token',
|
||||
refresh_token: refreshToken,
|
||||
});
|
||||
|
||||
if (config.scope) {
|
||||
body.append('scope', config.scope);
|
||||
}
|
||||
|
||||
const headers: HeadersInit = {
|
||||
'Content-Type': 'application/x-www-form-urlencoded',
|
||||
Accept: 'application/json',
|
||||
};
|
||||
|
||||
if (clientAuth) {
|
||||
headers['Authorization'] = `Basic ${clientAuth}`;
|
||||
} else {
|
||||
// Use client_id in body for public clients
|
||||
body.append('client_id', config.client_id);
|
||||
}
|
||||
|
||||
const response = await fetch(tokenUrl, {
|
||||
method: 'POST',
|
||||
headers,
|
||||
body,
|
||||
});
|
||||
|
||||
if (!response.ok) {
|
||||
const errorText = await response.text();
|
||||
throw new Error(
|
||||
`Token refresh failed: ${response.status} ${response.statusText} - ${errorText}`,
|
||||
);
|
||||
}
|
||||
|
||||
const tokens = await response.json();
|
||||
|
||||
return {
|
||||
...tokens,
|
||||
obtained_at: Date.now(),
|
||||
expires_at: tokens.expires_in ? Date.now() + tokens.expires_in * 1000 : undefined,
|
||||
};
|
||||
}
|
||||
|
||||
/** For auto-discovered OAuth, we need the server URL */
|
||||
if (!metadata.serverUrl) {
|
||||
throw new Error('Server URL required for auto-discovered OAuth token refresh');
|
||||
}
|
||||
|
||||
/** Auto-discover OAuth configuration for refresh */
|
||||
const { metadata: oauthMetadata } = await this.discoverMetadata(metadata.serverUrl);
|
||||
|
||||
if (!oauthMetadata.token_endpoint) {
|
||||
throw new Error('No token endpoint found in OAuth metadata');
|
||||
}
|
||||
|
||||
const tokenUrl = new URL(oauthMetadata.token_endpoint);
|
||||
|
||||
const body = new URLSearchParams({
|
||||
grant_type: 'refresh_token',
|
||||
refresh_token: refreshToken,
|
||||
});
|
||||
|
||||
const headers: HeadersInit = {
|
||||
'Content-Type': 'application/x-www-form-urlencoded',
|
||||
Accept: 'application/json',
|
||||
};
|
||||
|
||||
const response = await fetch(tokenUrl, {
|
||||
method: 'POST',
|
||||
headers,
|
||||
body,
|
||||
});
|
||||
|
||||
if (!response.ok) {
|
||||
const errorText = await response.text();
|
||||
throw new Error(
|
||||
`Token refresh failed: ${response.status} ${response.statusText} - ${errorText}`,
|
||||
);
|
||||
}
|
||||
|
||||
const tokens = await response.json();
|
||||
|
||||
return {
|
||||
...tokens,
|
||||
obtained_at: Date.now(),
|
||||
expires_at: tokens.expires_in ? Date.now() + tokens.expires_in * 1000 : undefined,
|
||||
};
|
||||
} catch (error) {
|
||||
logger.error(`[MCPOAuth] Failed to refresh tokens for ${metadata.serverName}`, error);
|
||||
throw error;
|
||||
}
|
||||
}
|
||||
}
|
||||
3
packages/api/src/mcp/oauth/index.ts
Normal file
3
packages/api/src/mcp/oauth/index.ts
Normal file
|
|
@ -0,0 +1,3 @@
|
|||
export * from './types';
|
||||
export * from './handler';
|
||||
export * from './tokens';
|
||||
382
packages/api/src/mcp/oauth/tokens.ts
Normal file
382
packages/api/src/mcp/oauth/tokens.ts
Normal file
|
|
@ -0,0 +1,382 @@
|
|||
import { logger } from '@librechat/data-schemas';
|
||||
import type { OAuthTokens, OAuthClientInformation } from '@modelcontextprotocol/sdk/shared/auth.js';
|
||||
import type { TokenMethods, IToken } from '@librechat/data-schemas';
|
||||
import type { MCPOAuthTokens, ExtendedOAuthTokens } from './types';
|
||||
import { encryptV2, decryptV2 } from '~/crypto';
|
||||
import { isSystemUserId } from '~/mcp/enum';
|
||||
|
||||
interface StoreTokensParams {
|
||||
userId: string;
|
||||
serverName: string;
|
||||
tokens: OAuthTokens | ExtendedOAuthTokens | MCPOAuthTokens;
|
||||
createToken: TokenMethods['createToken'];
|
||||
updateToken?: TokenMethods['updateToken'];
|
||||
findToken?: TokenMethods['findToken'];
|
||||
clientInfo?: OAuthClientInformation;
|
||||
/** Optional: Pass existing token state to avoid duplicate DB calls */
|
||||
existingTokens?: {
|
||||
accessToken?: IToken | null;
|
||||
refreshToken?: IToken | null;
|
||||
clientInfoToken?: IToken | null;
|
||||
};
|
||||
}
|
||||
|
||||
interface GetTokensParams {
|
||||
userId: string;
|
||||
serverName: string;
|
||||
findToken: TokenMethods['findToken'];
|
||||
refreshTokens?: (
|
||||
refreshToken: string,
|
||||
metadata: { userId: string; serverName: string; identifier: string },
|
||||
) => Promise<MCPOAuthTokens>;
|
||||
createToken?: TokenMethods['createToken'];
|
||||
updateToken?: TokenMethods['updateToken'];
|
||||
}
|
||||
|
||||
export class MCPTokenStorage {
|
||||
static getLogPrefix(userId: string, serverName: string): string {
|
||||
return isSystemUserId(userId)
|
||||
? `[MCP][${serverName}]`
|
||||
: `[MCP][User: ${userId}][${serverName}]`;
|
||||
}
|
||||
|
||||
/**
|
||||
* Stores OAuth tokens for an MCP server
|
||||
*
|
||||
* @param params.existingTokens - Optional: Pass existing token state to avoid duplicate DB calls.
|
||||
* This is useful when refreshing tokens, as getTokens() already has the token state.
|
||||
*/
|
||||
static async storeTokens({
|
||||
userId,
|
||||
serverName,
|
||||
tokens,
|
||||
createToken,
|
||||
updateToken,
|
||||
findToken,
|
||||
clientInfo,
|
||||
existingTokens,
|
||||
}: StoreTokensParams): Promise<void> {
|
||||
const logPrefix = this.getLogPrefix(userId, serverName);
|
||||
|
||||
try {
|
||||
const identifier = `mcp:${serverName}`;
|
||||
|
||||
// Encrypt and store access token
|
||||
const encryptedAccessToken = await encryptV2(tokens.access_token);
|
||||
|
||||
logger.debug(
|
||||
`${logPrefix} Token expires_in: ${'expires_in' in tokens ? tokens.expires_in : 'N/A'}, expires_at: ${'expires_at' in tokens ? tokens.expires_at : 'N/A'}`,
|
||||
);
|
||||
|
||||
// Handle both expires_in and expires_at formats
|
||||
let accessTokenExpiry: Date;
|
||||
if ('expires_at' in tokens && tokens.expires_at) {
|
||||
/** MCPOAuthTokens format - already has calculated expiry */
|
||||
logger.debug(`${logPrefix} Using expires_at: ${tokens.expires_at}`);
|
||||
accessTokenExpiry = new Date(tokens.expires_at);
|
||||
} else if (tokens.expires_in) {
|
||||
/** Standard OAuthTokens format - calculate expiry */
|
||||
logger.debug(`${logPrefix} Using expires_in: ${tokens.expires_in}`);
|
||||
accessTokenExpiry = new Date(Date.now() + tokens.expires_in * 1000);
|
||||
} else {
|
||||
/** No expiry provided - default to 1 year */
|
||||
logger.debug(`${logPrefix} No expiry provided, using default`);
|
||||
accessTokenExpiry = new Date(Date.now() + 365 * 24 * 60 * 60 * 1000);
|
||||
}
|
||||
|
||||
logger.debug(`${logPrefix} Calculated expiry date: ${accessTokenExpiry.toISOString()}`);
|
||||
logger.debug(
|
||||
`${logPrefix} Date object: ${JSON.stringify({
|
||||
time: accessTokenExpiry.getTime(),
|
||||
valid: !isNaN(accessTokenExpiry.getTime()),
|
||||
iso: accessTokenExpiry.toISOString(),
|
||||
})}`,
|
||||
);
|
||||
|
||||
// Ensure the date is valid before passing to createToken
|
||||
if (isNaN(accessTokenExpiry.getTime())) {
|
||||
logger.error(`${logPrefix} Invalid expiry date calculated, using default`);
|
||||
accessTokenExpiry = new Date(Date.now() + 365 * 24 * 60 * 60 * 1000);
|
||||
}
|
||||
|
||||
// Calculate expiresIn (seconds from now)
|
||||
const expiresIn = Math.floor((accessTokenExpiry.getTime() - Date.now()) / 1000);
|
||||
|
||||
const accessTokenData = {
|
||||
userId,
|
||||
type: 'mcp_oauth',
|
||||
identifier,
|
||||
token: encryptedAccessToken,
|
||||
expiresIn: expiresIn > 0 ? expiresIn : 365 * 24 * 60 * 60, // Default to 1 year if negative
|
||||
};
|
||||
|
||||
// Check if token already exists and update if it does
|
||||
if (findToken && updateToken) {
|
||||
// Use provided existing token state if available, otherwise look it up
|
||||
const existingToken =
|
||||
existingTokens?.accessToken !== undefined
|
||||
? existingTokens.accessToken
|
||||
: await findToken({ userId, identifier });
|
||||
|
||||
if (existingToken) {
|
||||
await updateToken({ userId, identifier }, accessTokenData);
|
||||
logger.debug(`${logPrefix} Updated existing access token`);
|
||||
} else {
|
||||
await createToken(accessTokenData);
|
||||
logger.debug(`${logPrefix} Created new access token`);
|
||||
}
|
||||
} else {
|
||||
// Create new token if it's initial store or update methods not provided
|
||||
await createToken(accessTokenData);
|
||||
logger.debug(`${logPrefix} Created access token (no update methods available)`);
|
||||
}
|
||||
|
||||
// Store refresh token if available
|
||||
if (tokens.refresh_token) {
|
||||
const encryptedRefreshToken = await encryptV2(tokens.refresh_token);
|
||||
const extendedTokens = tokens as ExtendedOAuthTokens;
|
||||
const refreshTokenExpiry = extendedTokens.refresh_token_expires_in
|
||||
? new Date(Date.now() + extendedTokens.refresh_token_expires_in * 1000)
|
||||
: new Date(Date.now() + 365 * 24 * 60 * 60 * 1000); // Default to 1 year
|
||||
|
||||
/** Calculated expiresIn for refresh token */
|
||||
const refreshExpiresIn = Math.floor((refreshTokenExpiry.getTime() - Date.now()) / 1000);
|
||||
|
||||
const refreshTokenData = {
|
||||
userId,
|
||||
type: 'mcp_oauth_refresh',
|
||||
identifier: `${identifier}:refresh`,
|
||||
token: encryptedRefreshToken,
|
||||
expiresIn: refreshExpiresIn > 0 ? refreshExpiresIn : 365 * 24 * 60 * 60,
|
||||
};
|
||||
|
||||
// Check if refresh token already exists and update if it does
|
||||
if (findToken && updateToken) {
|
||||
// Use provided existing token state if available, otherwise look it up
|
||||
const existingRefreshToken =
|
||||
existingTokens?.refreshToken !== undefined
|
||||
? existingTokens.refreshToken
|
||||
: await findToken({
|
||||
userId,
|
||||
identifier: `${identifier}:refresh`,
|
||||
});
|
||||
|
||||
if (existingRefreshToken) {
|
||||
await updateToken({ userId, identifier: `${identifier}:refresh` }, refreshTokenData);
|
||||
logger.debug(`${logPrefix} Updated existing refresh token`);
|
||||
} else {
|
||||
await createToken(refreshTokenData);
|
||||
logger.debug(`${logPrefix} Created new refresh token`);
|
||||
}
|
||||
} else {
|
||||
await createToken(refreshTokenData);
|
||||
logger.debug(`${logPrefix} Created refresh token (no update methods available)`);
|
||||
}
|
||||
}
|
||||
|
||||
/** Store client information if provided */
|
||||
if (clientInfo) {
|
||||
logger.debug(`${logPrefix} Storing client info:`, {
|
||||
client_id: clientInfo.client_id,
|
||||
has_client_secret: !!clientInfo.client_secret,
|
||||
});
|
||||
const encryptedClientInfo = await encryptV2(JSON.stringify(clientInfo));
|
||||
|
||||
const clientInfoData = {
|
||||
userId,
|
||||
type: 'mcp_oauth_client',
|
||||
identifier: `${identifier}:client`,
|
||||
token: encryptedClientInfo,
|
||||
expiresIn: 365 * 24 * 60 * 60,
|
||||
};
|
||||
|
||||
// Check if client info already exists and update if it does
|
||||
if (findToken && updateToken) {
|
||||
// Use provided existing token state if available, otherwise look it up
|
||||
const existingClientInfo =
|
||||
existingTokens?.clientInfoToken !== undefined
|
||||
? existingTokens.clientInfoToken
|
||||
: await findToken({
|
||||
userId,
|
||||
identifier: `${identifier}:client`,
|
||||
});
|
||||
|
||||
if (existingClientInfo) {
|
||||
await updateToken({ userId, identifier: `${identifier}:client` }, clientInfoData);
|
||||
logger.debug(`${logPrefix} Updated existing client info`);
|
||||
} else {
|
||||
await createToken(clientInfoData);
|
||||
logger.debug(`${logPrefix} Created new client info`);
|
||||
}
|
||||
} else {
|
||||
await createToken(clientInfoData);
|
||||
logger.debug(`${logPrefix} Created client info (no update methods available)`);
|
||||
}
|
||||
}
|
||||
|
||||
logger.debug(`${logPrefix} Stored OAuth tokens`);
|
||||
} catch (error) {
|
||||
const logPrefix = this.getLogPrefix(userId, serverName);
|
||||
logger.error(`${logPrefix} Failed to store tokens`, error);
|
||||
throw error;
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Retrieves OAuth tokens for an MCP server
|
||||
*/
|
||||
static async getTokens({
|
||||
userId,
|
||||
serverName,
|
||||
findToken,
|
||||
createToken,
|
||||
updateToken,
|
||||
refreshTokens,
|
||||
}: GetTokensParams): Promise<MCPOAuthTokens | null> {
|
||||
const logPrefix = this.getLogPrefix(userId, serverName);
|
||||
|
||||
try {
|
||||
const identifier = `mcp:${serverName}`;
|
||||
|
||||
// Get access token
|
||||
const accessTokenData = await findToken({
|
||||
userId,
|
||||
type: 'mcp_oauth',
|
||||
identifier,
|
||||
});
|
||||
|
||||
/** Check if access token is missing or expired */
|
||||
const isMissing = !accessTokenData;
|
||||
const isExpired = accessTokenData?.expiresAt && new Date() >= accessTokenData.expiresAt;
|
||||
|
||||
if (isMissing || isExpired) {
|
||||
logger.info(`${logPrefix} Access token ${isMissing ? 'missing' : 'expired'}`);
|
||||
|
||||
/** Refresh data if we have a refresh token and refresh function */
|
||||
const refreshTokenData = await findToken({
|
||||
userId,
|
||||
type: 'mcp_oauth_refresh',
|
||||
identifier: `${identifier}:refresh`,
|
||||
});
|
||||
|
||||
if (!refreshTokenData) {
|
||||
logger.info(
|
||||
`${logPrefix} Access token ${isMissing ? 'missing' : 'expired'} and no refresh token available`,
|
||||
);
|
||||
return null;
|
||||
}
|
||||
|
||||
if (!refreshTokens) {
|
||||
logger.warn(
|
||||
`${logPrefix} Access token ${isMissing ? 'missing' : 'expired'}, refresh token available but no \`refreshTokens\` provided`,
|
||||
);
|
||||
return null;
|
||||
}
|
||||
|
||||
if (!createToken) {
|
||||
logger.warn(
|
||||
`${logPrefix} Access token ${isMissing ? 'missing' : 'expired'}, refresh token available but no \`createToken\` function provided`,
|
||||
);
|
||||
return null;
|
||||
}
|
||||
|
||||
try {
|
||||
logger.info(`${logPrefix} Attempting to refresh token`);
|
||||
const decryptedRefreshToken = await decryptV2(refreshTokenData.token);
|
||||
|
||||
/** Client information if available */
|
||||
let clientInfo;
|
||||
let clientInfoData;
|
||||
try {
|
||||
clientInfoData = await findToken({
|
||||
userId,
|
||||
type: 'mcp_oauth_client',
|
||||
identifier: `${identifier}:client`,
|
||||
});
|
||||
if (clientInfoData) {
|
||||
const decryptedClientInfo = await decryptV2(clientInfoData.token);
|
||||
clientInfo = JSON.parse(decryptedClientInfo);
|
||||
logger.debug(`${logPrefix} Retrieved client info:`, {
|
||||
client_id: clientInfo.client_id,
|
||||
has_client_secret: !!clientInfo.client_secret,
|
||||
});
|
||||
}
|
||||
} catch {
|
||||
logger.debug(`${logPrefix} No client info found`);
|
||||
}
|
||||
|
||||
const metadata = {
|
||||
userId,
|
||||
serverName,
|
||||
identifier,
|
||||
clientInfo,
|
||||
};
|
||||
|
||||
const newTokens = await refreshTokens(decryptedRefreshToken, metadata);
|
||||
|
||||
// Store the refreshed tokens (handles both create and update)
|
||||
// Pass existing token state to avoid duplicate DB calls
|
||||
await this.storeTokens({
|
||||
userId,
|
||||
serverName,
|
||||
tokens: newTokens,
|
||||
createToken,
|
||||
updateToken,
|
||||
findToken,
|
||||
clientInfo,
|
||||
existingTokens: {
|
||||
accessToken: accessTokenData, // We know this is expired/missing
|
||||
refreshToken: refreshTokenData, // We already have this
|
||||
clientInfoToken: clientInfoData, // We already looked this up
|
||||
},
|
||||
});
|
||||
|
||||
logger.info(`${logPrefix} Successfully refreshed and stored OAuth tokens`);
|
||||
return newTokens;
|
||||
} catch (refreshError) {
|
||||
logger.error(`${logPrefix} Failed to refresh tokens`, refreshError);
|
||||
// Check if it's an unauthorized_client error (refresh not supported)
|
||||
const errorMessage =
|
||||
refreshError instanceof Error ? refreshError.message : String(refreshError);
|
||||
if (errorMessage.includes('unauthorized_client')) {
|
||||
logger.info(
|
||||
`${logPrefix} Server does not support refresh tokens for this client. New authentication required.`,
|
||||
);
|
||||
}
|
||||
return null;
|
||||
}
|
||||
}
|
||||
|
||||
// If we reach here, access token should exist and be valid
|
||||
if (!accessTokenData) {
|
||||
return null;
|
||||
}
|
||||
|
||||
const decryptedAccessToken = await decryptV2(accessTokenData.token);
|
||||
|
||||
/** Get refresh token if available */
|
||||
const refreshTokenData = await findToken({
|
||||
userId,
|
||||
type: 'mcp_oauth_refresh',
|
||||
identifier: `${identifier}:refresh`,
|
||||
});
|
||||
|
||||
const tokens: MCPOAuthTokens = {
|
||||
access_token: decryptedAccessToken,
|
||||
token_type: 'Bearer',
|
||||
obtained_at: accessTokenData.createdAt.getTime(),
|
||||
expires_at: accessTokenData.expiresAt?.getTime(),
|
||||
};
|
||||
|
||||
if (refreshTokenData) {
|
||||
tokens.refresh_token = await decryptV2(refreshTokenData.token);
|
||||
}
|
||||
|
||||
logger.debug(`${logPrefix} Loaded existing OAuth tokens from storage`);
|
||||
return tokens;
|
||||
} catch (error) {
|
||||
logger.error(`${logPrefix} Failed to retrieve tokens`, error);
|
||||
return null;
|
||||
}
|
||||
}
|
||||
}
|
||||
98
packages/api/src/mcp/oauth/types.ts
Normal file
98
packages/api/src/mcp/oauth/types.ts
Normal file
|
|
@ -0,0 +1,98 @@
|
|||
import type { OAuthTokens } from '@modelcontextprotocol/sdk/shared/auth.js';
|
||||
import type { FlowMetadata } from '~/flow/types';
|
||||
|
||||
export interface OAuthMetadata {
|
||||
/** OAuth authorization endpoint */
|
||||
authorization_endpoint: string;
|
||||
/** OAuth token endpoint */
|
||||
token_endpoint: string;
|
||||
/** OAuth issuer */
|
||||
issuer?: string;
|
||||
/** Supported scopes */
|
||||
scopes_supported?: string[];
|
||||
/** Response types supported */
|
||||
response_types_supported?: string[];
|
||||
/** Grant types supported */
|
||||
grant_types_supported?: string[];
|
||||
/** Token endpoint auth methods supported */
|
||||
token_endpoint_auth_methods_supported?: string[];
|
||||
/** Code challenge methods supported */
|
||||
code_challenge_methods_supported?: string[];
|
||||
}
|
||||
|
||||
export interface OAuthProtectedResourceMetadata {
|
||||
/** Resource identifier */
|
||||
resource: string;
|
||||
/** Authorization servers */
|
||||
authorization_servers?: string[];
|
||||
/** Scopes supported by the resource */
|
||||
scopes_supported?: string[];
|
||||
}
|
||||
|
||||
export interface OAuthClientInformation {
|
||||
/** Client ID */
|
||||
client_id: string;
|
||||
/** Client secret (optional for public clients) */
|
||||
client_secret?: string;
|
||||
/** Client name */
|
||||
client_name?: string;
|
||||
/** Redirect URIs */
|
||||
redirect_uris?: string[];
|
||||
/** Grant types */
|
||||
grant_types?: string[];
|
||||
/** Response types */
|
||||
response_types?: string[];
|
||||
/** Scope */
|
||||
scope?: string;
|
||||
/** Token endpoint auth method */
|
||||
token_endpoint_auth_method?: string;
|
||||
}
|
||||
|
||||
export interface MCPOAuthState {
|
||||
/** Current step in the OAuth flow */
|
||||
step: 'discovery' | 'registration' | 'authorization' | 'token_exchange' | 'complete' | 'error';
|
||||
/** Server name */
|
||||
serverName: string;
|
||||
/** User ID */
|
||||
userId: string;
|
||||
/** OAuth metadata from discovery */
|
||||
metadata?: OAuthMetadata;
|
||||
/** Resource metadata */
|
||||
resourceMetadata?: OAuthProtectedResourceMetadata;
|
||||
/** Client information */
|
||||
clientInfo?: OAuthClientInformation;
|
||||
/** Authorization URL */
|
||||
authorizationUrl?: string;
|
||||
/** Code verifier for PKCE */
|
||||
codeVerifier?: string;
|
||||
/** State parameter for OAuth flow */
|
||||
state?: string;
|
||||
/** Error information */
|
||||
error?: string;
|
||||
/** Timestamp */
|
||||
timestamp: number;
|
||||
}
|
||||
|
||||
export interface MCPOAuthFlowMetadata extends FlowMetadata {
|
||||
serverName: string;
|
||||
userId: string;
|
||||
serverUrl: string;
|
||||
state: string;
|
||||
codeVerifier?: string;
|
||||
clientInfo?: OAuthClientInformation;
|
||||
metadata?: OAuthMetadata;
|
||||
resourceMetadata?: OAuthProtectedResourceMetadata;
|
||||
}
|
||||
|
||||
export interface MCPOAuthTokens extends OAuthTokens {
|
||||
/** When the tokens were obtained */
|
||||
obtained_at: number;
|
||||
/** Calculated expiry time */
|
||||
expires_at?: number;
|
||||
}
|
||||
|
||||
/** Extended OAuth tokens that may include refresh token expiry */
|
||||
export interface ExtendedOAuthTokens extends OAuthTokens {
|
||||
/** Refresh token expiry in seconds (non-standard, some providers include this) */
|
||||
refresh_token_expires_in?: number;
|
||||
}
|
||||
|
|
@ -1,4 +1,4 @@
|
|||
import type * as t from './types/mcp';
|
||||
import type * as t from './types';
|
||||
const RECOGNIZED_PROVIDERS = new Set([
|
||||
'google',
|
||||
'anthropic',
|
||||
|
|
@ -8,14 +8,21 @@ import {
|
|||
StreamableHTTPOptionsSchema,
|
||||
} from 'librechat-data-provider';
|
||||
import type { JsonSchemaType, TPlugin } from 'librechat-data-provider';
|
||||
import { ToolSchema, ListToolsResultSchema } from '@modelcontextprotocol/sdk/types.js';
|
||||
import type * as t from '@modelcontextprotocol/sdk/types.js';
|
||||
|
||||
export type StdioOptions = z.infer<typeof StdioOptionsSchema>;
|
||||
export type WebSocketOptions = z.infer<typeof WebSocketOptionsSchema>;
|
||||
export type SSEOptions = z.infer<typeof SSEOptionsSchema>;
|
||||
export type StreamableHTTPOptions = z.infer<typeof StreamableHTTPOptionsSchema>;
|
||||
export type MCPOptions = z.infer<typeof MCPOptionsSchema>;
|
||||
export type MCPOptions = z.infer<typeof MCPOptionsSchema> & {
|
||||
customUserVars?: Record<
|
||||
string,
|
||||
{
|
||||
title: string;
|
||||
description: string;
|
||||
}
|
||||
>;
|
||||
};
|
||||
export type MCPServers = z.infer<typeof MCPServersSchema>;
|
||||
export interface MCPResource {
|
||||
uri: string;
|
||||
|
|
@ -45,8 +52,8 @@ export interface MCPPrompt {
|
|||
|
||||
export type ConnectionState = 'disconnected' | 'connecting' | 'connected' | 'error';
|
||||
|
||||
export type MCPTool = z.infer<typeof ToolSchema>;
|
||||
export type MCPToolListResponse = z.infer<typeof ListToolsResultSchema>;
|
||||
export type MCPTool = z.infer<typeof t.ToolSchema>;
|
||||
export type MCPToolListResponse = z.infer<typeof t.ListToolsResultSchema>;
|
||||
export type ToolContentPart = t.TextContent | t.ImageContent | t.EmbeddedResource | t.AudioContent;
|
||||
export type ImageContent = Extract<ToolContentPart, { type: 'image' }>;
|
||||
export type MCPToolCallResponse =
|
||||
|
|
@ -1,3 +1,6 @@
|
|||
import { Constants } from 'librechat-data-provider';
|
||||
|
||||
export const mcpToolPattern = new RegExp(`^.+${Constants.mcp_delimiter}.+$`);
|
||||
/**
|
||||
* Normalizes a server name to match the pattern ^[a-zA-Z0-9_.-]+$
|
||||
* This is required for Azure OpenAI models with Tool Calling
|
||||
1
packages/api/src/oauth/index.ts
Normal file
1
packages/api/src/oauth/index.ts
Normal file
|
|
@ -0,0 +1 @@
|
|||
export * from './tokens';
|
||||
324
packages/api/src/oauth/tokens.ts
Normal file
324
packages/api/src/oauth/tokens.ts
Normal file
|
|
@ -0,0 +1,324 @@
|
|||
import axios from 'axios';
|
||||
import { logger } from '@librechat/data-schemas';
|
||||
import { TokenExchangeMethodEnum } from 'librechat-data-provider';
|
||||
import type { TokenMethods } from '@librechat/data-schemas';
|
||||
import type { AxiosError } from 'axios';
|
||||
import { encryptV2, decryptV2 } from '~/crypto';
|
||||
import { logAxiosError } from '~/utils';
|
||||
|
||||
export function createHandleOAuthToken({
|
||||
findToken,
|
||||
updateToken,
|
||||
createToken,
|
||||
}: {
|
||||
findToken: TokenMethods['findToken'];
|
||||
updateToken: TokenMethods['updateToken'];
|
||||
createToken: TokenMethods['createToken'];
|
||||
}) {
|
||||
/**
|
||||
* Handles the OAuth token by creating or updating the token.
|
||||
* @param fields
|
||||
* @param fields.userId - The user's ID.
|
||||
* @param fields.token - The full token to store.
|
||||
* @param fields.identifier - Unique, alternative identifier for the token.
|
||||
* @param fields.expiresIn - The number of seconds until the token expires.
|
||||
* @param fields.metadata - Additional metadata to store with the token.
|
||||
* @param [fields.type="oauth"] - The type of token. Default is 'oauth'.
|
||||
*/
|
||||
return async function handleOAuthToken({
|
||||
token,
|
||||
userId,
|
||||
identifier,
|
||||
expiresIn,
|
||||
metadata,
|
||||
type = 'oauth',
|
||||
}: {
|
||||
token: string;
|
||||
userId: string;
|
||||
identifier: string;
|
||||
expiresIn?: number | string | null;
|
||||
metadata?: Record<string, unknown>;
|
||||
type?: string;
|
||||
}) {
|
||||
const encrypedToken = await encryptV2(token);
|
||||
let expiresInNumber = 3600;
|
||||
if (typeof expiresIn === 'number') {
|
||||
expiresInNumber = expiresIn;
|
||||
} else if (expiresIn != null) {
|
||||
expiresInNumber = parseInt(expiresIn, 10) || 3600;
|
||||
}
|
||||
const tokenData = {
|
||||
type,
|
||||
userId,
|
||||
metadata,
|
||||
identifier,
|
||||
token: encrypedToken,
|
||||
expiresIn: expiresInNumber,
|
||||
};
|
||||
|
||||
const existingToken = await findToken({ userId, identifier });
|
||||
if (existingToken) {
|
||||
return await updateToken({ identifier }, tokenData);
|
||||
} else {
|
||||
return await createToken(tokenData);
|
||||
}
|
||||
};
|
||||
}
|
||||
|
||||
/**
|
||||
* Processes the access tokens and stores them in the database.
|
||||
* @param tokenData
|
||||
* @param tokenData.access_token
|
||||
* @param tokenData.expires_in
|
||||
* @param [tokenData.refresh_token]
|
||||
* @param [tokenData.refresh_token_expires_in]
|
||||
* @param metadata
|
||||
* @param metadata.userId
|
||||
* @param metadata.identifier
|
||||
*/
|
||||
async function processAccessTokens(
|
||||
tokenData: {
|
||||
access_token: string;
|
||||
expires_in: number;
|
||||
refresh_token?: string;
|
||||
refresh_token_expires_in?: number;
|
||||
},
|
||||
{ userId, identifier }: { userId: string; identifier: string },
|
||||
{
|
||||
findToken,
|
||||
updateToken,
|
||||
createToken,
|
||||
}: {
|
||||
findToken: TokenMethods['findToken'];
|
||||
updateToken: TokenMethods['updateToken'];
|
||||
createToken: TokenMethods['createToken'];
|
||||
},
|
||||
) {
|
||||
const { access_token, expires_in = 3600, refresh_token, refresh_token_expires_in } = tokenData;
|
||||
if (!access_token) {
|
||||
logger.error('Access token not found: ', tokenData);
|
||||
throw new Error('Access token not found');
|
||||
}
|
||||
const handleOAuthToken = createHandleOAuthToken({
|
||||
findToken,
|
||||
updateToken,
|
||||
createToken,
|
||||
});
|
||||
await handleOAuthToken({
|
||||
identifier,
|
||||
token: access_token,
|
||||
expiresIn: expires_in,
|
||||
userId,
|
||||
});
|
||||
|
||||
if (refresh_token != null) {
|
||||
logger.debug('Processing refresh token');
|
||||
await handleOAuthToken({
|
||||
token: refresh_token,
|
||||
type: 'oauth_refresh',
|
||||
userId,
|
||||
identifier: `${identifier}:refresh`,
|
||||
expiresIn: refresh_token_expires_in ?? null,
|
||||
});
|
||||
}
|
||||
logger.debug('Access tokens processed');
|
||||
}
|
||||
|
||||
/**
|
||||
* Refreshes the access token using the refresh token.
|
||||
* @param fields
|
||||
* @param fields.userId - The ID of the user.
|
||||
* @param fields.client_url - The URL of the OAuth provider.
|
||||
* @param fields.identifier - The identifier for the token.
|
||||
* @param fields.refresh_token - The refresh token to use.
|
||||
* @param fields.token_exchange_method - The token exchange method ('default_post' or 'basic_auth_header').
|
||||
* @param fields.encrypted_oauth_client_id - The client ID for the OAuth provider.
|
||||
* @param fields.encrypted_oauth_client_secret - The client secret for the OAuth provider.
|
||||
*/
|
||||
export async function refreshAccessToken(
|
||||
{
|
||||
userId,
|
||||
client_url,
|
||||
identifier,
|
||||
refresh_token,
|
||||
token_exchange_method,
|
||||
encrypted_oauth_client_id,
|
||||
encrypted_oauth_client_secret,
|
||||
}: {
|
||||
userId: string;
|
||||
client_url: string;
|
||||
identifier: string;
|
||||
refresh_token: string;
|
||||
token_exchange_method: TokenExchangeMethodEnum;
|
||||
encrypted_oauth_client_id: string;
|
||||
encrypted_oauth_client_secret: string;
|
||||
},
|
||||
{
|
||||
findToken,
|
||||
updateToken,
|
||||
createToken,
|
||||
}: {
|
||||
findToken: TokenMethods['findToken'];
|
||||
updateToken: TokenMethods['updateToken'];
|
||||
createToken: TokenMethods['createToken'];
|
||||
},
|
||||
): Promise<{
|
||||
access_token: string;
|
||||
expires_in: number;
|
||||
refresh_token?: string;
|
||||
refresh_token_expires_in?: number;
|
||||
}> {
|
||||
try {
|
||||
const oauth_client_id = await decryptV2(encrypted_oauth_client_id);
|
||||
const oauth_client_secret = await decryptV2(encrypted_oauth_client_secret);
|
||||
|
||||
const headers: Record<string, string> = {
|
||||
'Content-Type': 'application/x-www-form-urlencoded',
|
||||
Accept: 'application/json',
|
||||
};
|
||||
|
||||
const params = new URLSearchParams({
|
||||
grant_type: 'refresh_token',
|
||||
refresh_token,
|
||||
});
|
||||
|
||||
if (token_exchange_method === TokenExchangeMethodEnum.BasicAuthHeader) {
|
||||
const basicAuth = Buffer.from(`${oauth_client_id}:${oauth_client_secret}`).toString('base64');
|
||||
headers['Authorization'] = `Basic ${basicAuth}`;
|
||||
} else {
|
||||
params.append('client_id', oauth_client_id);
|
||||
params.append('client_secret', oauth_client_secret);
|
||||
}
|
||||
|
||||
const response = await axios({
|
||||
method: 'POST',
|
||||
url: client_url,
|
||||
headers,
|
||||
data: params.toString(),
|
||||
});
|
||||
await processAccessTokens(
|
||||
response.data,
|
||||
{
|
||||
userId,
|
||||
identifier,
|
||||
},
|
||||
{
|
||||
findToken,
|
||||
updateToken,
|
||||
createToken,
|
||||
},
|
||||
);
|
||||
logger.debug(`Access token refreshed successfully for ${identifier}`);
|
||||
return response.data;
|
||||
} catch (error) {
|
||||
const message = 'Error refreshing OAuth tokens';
|
||||
throw new Error(
|
||||
logAxiosError({
|
||||
message,
|
||||
error: error as AxiosError,
|
||||
}),
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Handles the OAuth callback and exchanges the authorization code for tokens.
|
||||
* @param {object} fields
|
||||
* @param {string} fields.code - The authorization code returned by the provider.
|
||||
* @param {string} fields.userId - The ID of the user.
|
||||
* @param {string} fields.identifier - The identifier for the token.
|
||||
* @param {string} fields.client_url - The URL of the OAuth provider.
|
||||
* @param {string} fields.redirect_uri - The redirect URI for the OAuth provider.
|
||||
* @param {string} fields.token_exchange_method - The token exchange method ('default_post' or 'basic_auth_header').
|
||||
* @param {string} fields.encrypted_oauth_client_id - The client ID for the OAuth provider.
|
||||
* @param {string} fields.encrypted_oauth_client_secret - The client secret for the OAuth provider.
|
||||
*/
|
||||
export async function getAccessToken(
|
||||
{
|
||||
code,
|
||||
userId,
|
||||
identifier,
|
||||
client_url,
|
||||
redirect_uri,
|
||||
token_exchange_method,
|
||||
encrypted_oauth_client_id,
|
||||
encrypted_oauth_client_secret,
|
||||
}: {
|
||||
code: string;
|
||||
userId: string;
|
||||
identifier: string;
|
||||
client_url: string;
|
||||
redirect_uri: string;
|
||||
token_exchange_method: TokenExchangeMethodEnum;
|
||||
encrypted_oauth_client_id: string;
|
||||
encrypted_oauth_client_secret: string;
|
||||
},
|
||||
{
|
||||
findToken,
|
||||
updateToken,
|
||||
createToken,
|
||||
}: {
|
||||
findToken: TokenMethods['findToken'];
|
||||
updateToken: TokenMethods['updateToken'];
|
||||
createToken: TokenMethods['createToken'];
|
||||
},
|
||||
): Promise<{
|
||||
access_token: string;
|
||||
expires_in: number;
|
||||
refresh_token?: string;
|
||||
refresh_token_expires_in?: number;
|
||||
}> {
|
||||
const oauth_client_id = await decryptV2(encrypted_oauth_client_id);
|
||||
const oauth_client_secret = await decryptV2(encrypted_oauth_client_secret);
|
||||
|
||||
const headers: Record<string, string> = {
|
||||
'Content-Type': 'application/x-www-form-urlencoded',
|
||||
Accept: 'application/json',
|
||||
};
|
||||
|
||||
const params = new URLSearchParams({
|
||||
code,
|
||||
grant_type: 'authorization_code',
|
||||
redirect_uri,
|
||||
});
|
||||
|
||||
if (token_exchange_method === TokenExchangeMethodEnum.BasicAuthHeader) {
|
||||
const basicAuth = Buffer.from(`${oauth_client_id}:${oauth_client_secret}`).toString('base64');
|
||||
headers['Authorization'] = `Basic ${basicAuth}`;
|
||||
} else {
|
||||
params.append('client_id', oauth_client_id);
|
||||
params.append('client_secret', oauth_client_secret);
|
||||
}
|
||||
|
||||
try {
|
||||
const response = await axios({
|
||||
method: 'POST',
|
||||
url: client_url,
|
||||
headers,
|
||||
data: params.toString(),
|
||||
});
|
||||
|
||||
await processAccessTokens(
|
||||
response.data,
|
||||
{
|
||||
userId,
|
||||
identifier,
|
||||
},
|
||||
{
|
||||
findToken,
|
||||
updateToken,
|
||||
createToken,
|
||||
},
|
||||
);
|
||||
logger.debug(`Access tokens successfully created for ${identifier}`);
|
||||
return response.data;
|
||||
} catch (error) {
|
||||
const message = 'Error exchanging OAuth code';
|
||||
throw new Error(
|
||||
logAxiosError({
|
||||
message,
|
||||
error: error as AxiosError,
|
||||
}),
|
||||
);
|
||||
}
|
||||
}
|
||||
19
packages/api/src/types/azure.ts
Normal file
19
packages/api/src/types/azure.ts
Normal file
|
|
@ -0,0 +1,19 @@
|
|||
/**
|
||||
* Azure OpenAI configuration interface
|
||||
*/
|
||||
export interface AzureOptions {
|
||||
azureOpenAIApiKey?: string;
|
||||
azureOpenAIApiInstanceName?: string;
|
||||
azureOpenAIApiDeploymentName?: string;
|
||||
azureOpenAIApiVersion?: string;
|
||||
azureOpenAIBasePath?: string;
|
||||
}
|
||||
|
||||
/**
|
||||
* Client with azure property for setting deployment name
|
||||
*/
|
||||
export interface GenericClient {
|
||||
azure: {
|
||||
azureOpenAIApiDeploymentName?: string;
|
||||
};
|
||||
}
|
||||
4
packages/api/src/types/events.ts
Normal file
4
packages/api/src/types/events.ts
Normal file
|
|
@ -0,0 +1,4 @@
|
|||
export type ServerSentEvent = {
|
||||
data: string | Record<string, unknown>;
|
||||
event?: string;
|
||||
};
|
||||
5
packages/api/src/types/index.ts
Normal file
5
packages/api/src/types/index.ts
Normal file
|
|
@ -0,0 +1,5 @@
|
|||
export * from './azure';
|
||||
export * from './events';
|
||||
export * from './mistral';
|
||||
export * from './openai';
|
||||
export * from './run';
|
||||
82
packages/api/src/types/mistral.ts
Normal file
82
packages/api/src/types/mistral.ts
Normal file
|
|
@ -0,0 +1,82 @@
|
|||
/**
|
||||
* Mistral OCR API Types
|
||||
* Based on https://docs.mistral.ai/api/#tag/ocr/operation/ocr_v1_ocr_post
|
||||
*/
|
||||
|
||||
export interface MistralFileUploadResponse {
|
||||
id: string;
|
||||
object: string;
|
||||
bytes: number;
|
||||
created_at: number;
|
||||
filename: string;
|
||||
purpose: string;
|
||||
}
|
||||
|
||||
export interface MistralSignedUrlResponse {
|
||||
url: string;
|
||||
expires_at: number;
|
||||
}
|
||||
|
||||
export interface OCRImage {
|
||||
id: string;
|
||||
top_left_x: number;
|
||||
top_left_y: number;
|
||||
bottom_right_x: number;
|
||||
bottom_right_y: number;
|
||||
image_base64: string;
|
||||
image_annotation?: string;
|
||||
}
|
||||
|
||||
export interface PageDimensions {
|
||||
dpi: number;
|
||||
height: number;
|
||||
width: number;
|
||||
}
|
||||
|
||||
export interface OCRResultPage {
|
||||
index: number;
|
||||
markdown: string;
|
||||
images: OCRImage[];
|
||||
dimensions: PageDimensions;
|
||||
}
|
||||
|
||||
export interface OCRUsageInfo {
|
||||
pages_processed: number;
|
||||
doc_size_bytes: number;
|
||||
}
|
||||
|
||||
export interface OCRResult {
|
||||
pages: OCRResultPage[];
|
||||
model: string;
|
||||
document_annotation?: string | null;
|
||||
usage_info: OCRUsageInfo;
|
||||
}
|
||||
|
||||
export interface MistralOCRRequest {
|
||||
model: string;
|
||||
image_limit?: number;
|
||||
include_image_base64?: boolean;
|
||||
document: {
|
||||
type: 'document_url' | 'image_url';
|
||||
document_url?: string;
|
||||
image_url?: string;
|
||||
};
|
||||
}
|
||||
|
||||
export interface MistralOCRError {
|
||||
detail?: string;
|
||||
message?: string;
|
||||
error?: {
|
||||
message?: string;
|
||||
type?: string;
|
||||
code?: string;
|
||||
};
|
||||
}
|
||||
|
||||
export interface MistralOCRUploadResult {
|
||||
filename: string;
|
||||
bytes: number;
|
||||
filepath: string;
|
||||
text: string;
|
||||
images: string[];
|
||||
}
|
||||
97
packages/api/src/types/openai.ts
Normal file
97
packages/api/src/types/openai.ts
Normal file
|
|
@ -0,0 +1,97 @@
|
|||
import { z } from 'zod';
|
||||
import { openAISchema, EModelEndpoint } from 'librechat-data-provider';
|
||||
import type { TEndpointOption, TAzureConfig, TEndpoint } from 'librechat-data-provider';
|
||||
import type { OpenAIClientOptions } from '@librechat/agents';
|
||||
import type { AzureOptions } from './azure';
|
||||
|
||||
export type OpenAIParameters = z.infer<typeof openAISchema>;
|
||||
|
||||
/**
|
||||
* Configuration options for the getLLMConfig function
|
||||
*/
|
||||
export interface LLMConfigOptions {
|
||||
modelOptions?: Partial<OpenAIParameters>;
|
||||
reverseProxyUrl?: string;
|
||||
defaultQuery?: Record<string, string | undefined>;
|
||||
headers?: Record<string, string>;
|
||||
proxy?: string;
|
||||
azure?: AzureOptions;
|
||||
streaming?: boolean;
|
||||
addParams?: Record<string, unknown>;
|
||||
dropParams?: string[];
|
||||
}
|
||||
|
||||
export type OpenAIConfiguration = OpenAIClientOptions['configuration'];
|
||||
|
||||
export type ClientOptions = OpenAIClientOptions & {
|
||||
include_reasoning?: boolean;
|
||||
};
|
||||
|
||||
/**
|
||||
* Return type for getLLMConfig function
|
||||
*/
|
||||
export interface LLMConfigResult {
|
||||
llmConfig: ClientOptions;
|
||||
configOptions: OpenAIConfiguration;
|
||||
}
|
||||
|
||||
/**
|
||||
* Interface for user values retrieved from the database
|
||||
*/
|
||||
export interface UserKeyValues {
|
||||
apiKey?: string;
|
||||
baseURL?: string;
|
||||
}
|
||||
|
||||
/**
|
||||
* Request interface with only the properties we need (avoids Express typing conflicts)
|
||||
*/
|
||||
export interface RequestData {
|
||||
user: {
|
||||
id: string;
|
||||
};
|
||||
body: {
|
||||
model?: string;
|
||||
endpoint?: string;
|
||||
key?: string;
|
||||
};
|
||||
app: {
|
||||
locals: {
|
||||
[EModelEndpoint.azureOpenAI]?: TAzureConfig;
|
||||
[EModelEndpoint.openAI]?: TEndpoint;
|
||||
all?: TEndpoint;
|
||||
};
|
||||
};
|
||||
}
|
||||
|
||||
/**
|
||||
* Function type for getting user key values
|
||||
*/
|
||||
export type GetUserKeyValuesFunction = (params: {
|
||||
userId: string;
|
||||
name: string;
|
||||
}) => Promise<UserKeyValues>;
|
||||
|
||||
/**
|
||||
* Function type for checking user key expiry
|
||||
*/
|
||||
export type CheckUserKeyExpiryFunction = (expiresAt: string, endpoint: string) => void;
|
||||
|
||||
/**
|
||||
* Parameters for the initializeOpenAI function
|
||||
*/
|
||||
export interface InitializeOpenAIOptionsParams {
|
||||
req: RequestData;
|
||||
overrideModel?: string;
|
||||
overrideEndpoint?: string;
|
||||
endpointOption: Partial<TEndpointOption>;
|
||||
getUserKeyValues: GetUserKeyValuesFunction;
|
||||
checkUserKeyExpiry: CheckUserKeyExpiryFunction;
|
||||
}
|
||||
|
||||
/**
|
||||
* Extended LLM config result with stream rate handling
|
||||
*/
|
||||
export interface OpenAIOptionsResult extends LLMConfigResult {
|
||||
streamRate?: number;
|
||||
}
|
||||
10
packages/api/src/types/run.ts
Normal file
10
packages/api/src/types/run.ts
Normal file
|
|
@ -0,0 +1,10 @@
|
|||
import type { AgentModelParameters, EModelEndpoint } from 'librechat-data-provider';
|
||||
import type { OpenAIConfiguration } from './openai';
|
||||
|
||||
export type RunLLMConfig = {
|
||||
provider: EModelEndpoint;
|
||||
streaming: boolean;
|
||||
streamUsage: boolean;
|
||||
usage?: boolean;
|
||||
configuration?: OpenAIConfiguration;
|
||||
} & AgentModelParameters;
|
||||
131
packages/api/src/utils/axios.spec.ts
Normal file
131
packages/api/src/utils/axios.spec.ts
Normal file
|
|
@ -0,0 +1,131 @@
|
|||
import axios from 'axios';
|
||||
import { createAxiosInstance } from './axios';
|
||||
|
||||
jest.mock('axios', () => ({
|
||||
interceptors: {
|
||||
request: { use: jest.fn(), eject: jest.fn() },
|
||||
response: { use: jest.fn(), eject: jest.fn() },
|
||||
},
|
||||
create: jest.fn().mockReturnValue({
|
||||
defaults: {
|
||||
proxy: null,
|
||||
},
|
||||
get: jest.fn().mockResolvedValue({ data: {} }),
|
||||
post: jest.fn().mockResolvedValue({ data: {} }),
|
||||
put: jest.fn().mockResolvedValue({ data: {} }),
|
||||
delete: jest.fn().mockResolvedValue({ data: {} }),
|
||||
}),
|
||||
get: jest.fn().mockResolvedValue({ data: {} }),
|
||||
post: jest.fn().mockResolvedValue({ data: {} }),
|
||||
put: jest.fn().mockResolvedValue({ data: {} }),
|
||||
delete: jest.fn().mockResolvedValue({ data: {} }),
|
||||
reset: jest.fn().mockImplementation(function (this: {
|
||||
get: jest.Mock;
|
||||
post: jest.Mock;
|
||||
put: jest.Mock;
|
||||
delete: jest.Mock;
|
||||
create: jest.Mock;
|
||||
}) {
|
||||
this.get.mockClear();
|
||||
this.post.mockClear();
|
||||
this.put.mockClear();
|
||||
this.delete.mockClear();
|
||||
this.create.mockClear();
|
||||
}),
|
||||
}));
|
||||
|
||||
describe('createAxiosInstance', () => {
|
||||
const originalEnv = process.env;
|
||||
|
||||
beforeEach(() => {
|
||||
// Reset mocks
|
||||
jest.clearAllMocks();
|
||||
// Create a clean copy of process.env
|
||||
process.env = { ...originalEnv };
|
||||
// Default: no proxy
|
||||
delete process.env.proxy;
|
||||
});
|
||||
|
||||
afterAll(() => {
|
||||
// Restore original process.env
|
||||
process.env = originalEnv;
|
||||
});
|
||||
|
||||
test('creates an axios instance without proxy when no proxy env is set', () => {
|
||||
const instance = createAxiosInstance();
|
||||
|
||||
expect(axios.create).toHaveBeenCalledTimes(1);
|
||||
expect(instance.defaults.proxy).toBeNull();
|
||||
});
|
||||
|
||||
test('configures proxy correctly with hostname and protocol', () => {
|
||||
process.env.proxy = 'http://example.com';
|
||||
|
||||
const instance = createAxiosInstance();
|
||||
|
||||
expect(axios.create).toHaveBeenCalledTimes(1);
|
||||
expect(instance.defaults.proxy).toEqual({
|
||||
host: 'example.com',
|
||||
protocol: 'http',
|
||||
});
|
||||
});
|
||||
|
||||
test('configures proxy correctly with hostname, protocol and port', () => {
|
||||
process.env.proxy = 'https://proxy.example.com:8080';
|
||||
|
||||
const instance = createAxiosInstance();
|
||||
|
||||
expect(axios.create).toHaveBeenCalledTimes(1);
|
||||
expect(instance.defaults.proxy).toEqual({
|
||||
host: 'proxy.example.com',
|
||||
protocol: 'https',
|
||||
port: 8080,
|
||||
});
|
||||
});
|
||||
|
||||
test('handles proxy URLs with authentication', () => {
|
||||
process.env.proxy = 'http://user:pass@proxy.example.com:3128';
|
||||
|
||||
const instance = createAxiosInstance();
|
||||
|
||||
expect(axios.create).toHaveBeenCalledTimes(1);
|
||||
expect(instance.defaults.proxy).toEqual({
|
||||
host: 'proxy.example.com',
|
||||
protocol: 'http',
|
||||
port: 3128,
|
||||
// Note: The current implementation doesn't handle auth - if needed, add this functionality
|
||||
});
|
||||
});
|
||||
|
||||
test('throws error when proxy URL is invalid', () => {
|
||||
process.env.proxy = 'invalid-url';
|
||||
|
||||
expect(() => createAxiosInstance()).toThrow('Invalid proxy URL');
|
||||
expect(axios.create).toHaveBeenCalledTimes(1);
|
||||
});
|
||||
|
||||
// If you want to test the actual URL parsing more thoroughly
|
||||
test('handles edge case proxy URLs correctly', () => {
|
||||
// IPv6 address
|
||||
process.env.proxy = 'http://[::1]:8080';
|
||||
|
||||
let instance = createAxiosInstance();
|
||||
|
||||
expect(instance.defaults.proxy).toEqual({
|
||||
host: '::1',
|
||||
protocol: 'http',
|
||||
port: 8080,
|
||||
});
|
||||
|
||||
// URL with path (which should be ignored for proxy config)
|
||||
process.env.proxy = 'http://proxy.example.com:8080/some/path';
|
||||
|
||||
instance = createAxiosInstance();
|
||||
|
||||
expect(instance.defaults.proxy).toEqual({
|
||||
host: 'proxy.example.com',
|
||||
protocol: 'http',
|
||||
port: 8080,
|
||||
});
|
||||
});
|
||||
});
|
||||
77
packages/api/src/utils/axios.ts
Normal file
77
packages/api/src/utils/axios.ts
Normal file
|
|
@ -0,0 +1,77 @@
|
|||
import axios from 'axios';
|
||||
import { logger } from '@librechat/data-schemas';
|
||||
import type { AxiosInstance, AxiosProxyConfig, AxiosError } from 'axios';
|
||||
|
||||
/**
|
||||
* Logs Axios errors based on the error object and a custom message.
|
||||
* @param options - The options object.
|
||||
* @param options.message - The custom message to be logged.
|
||||
* @param options.error - The Axios error object.
|
||||
* @returns The log message.
|
||||
*/
|
||||
export const logAxiosError = ({ message, error }: { message: string; error: AxiosError }) => {
|
||||
let logMessage = message;
|
||||
try {
|
||||
const stack = error.stack || 'No stack trace available';
|
||||
|
||||
if (error.response?.status) {
|
||||
const { status, headers, data } = error.response;
|
||||
logMessage = `${message} The server responded with status ${status}: ${error.message}`;
|
||||
logger.error(logMessage, {
|
||||
status,
|
||||
headers,
|
||||
data,
|
||||
stack,
|
||||
});
|
||||
} else if (error.request) {
|
||||
const { method, url } = error.config || {};
|
||||
logMessage = `${message} No response received for ${method ? method.toUpperCase() : ''} ${url || ''}: ${error.message}`;
|
||||
logger.error(logMessage, {
|
||||
requestInfo: { method, url },
|
||||
stack,
|
||||
});
|
||||
} else if (error?.message?.includes("Cannot read properties of undefined (reading 'status')")) {
|
||||
logMessage = `${message} It appears the request timed out or was unsuccessful: ${error.message}`;
|
||||
logger.error(logMessage, { stack });
|
||||
} else {
|
||||
logMessage = `${message} An error occurred while setting up the request: ${error.message}`;
|
||||
logger.error(logMessage, { stack });
|
||||
}
|
||||
} catch (err: unknown) {
|
||||
logMessage = `Error in logAxiosError: ${(err as Error).message}`;
|
||||
logger.error(logMessage, { stack: (err as Error).stack || 'No stack trace available' });
|
||||
}
|
||||
return logMessage;
|
||||
};
|
||||
|
||||
/**
|
||||
* Creates and configures an Axios instance with optional proxy settings.
|
||||
|
||||
* @returns A configured Axios instance
|
||||
* @throws If there's an issue creating the Axios instance or parsing the proxy URL
|
||||
*/
|
||||
export function createAxiosInstance(): AxiosInstance {
|
||||
const instance = axios.create();
|
||||
|
||||
if (process.env.proxy) {
|
||||
try {
|
||||
const url = new URL(process.env.proxy);
|
||||
|
||||
const proxyConfig: Partial<AxiosProxyConfig> = {
|
||||
host: url.hostname.replace(/^\[|\]$/g, ''),
|
||||
protocol: url.protocol.replace(':', ''),
|
||||
};
|
||||
|
||||
if (url.port) {
|
||||
proxyConfig.port = parseInt(url.port, 10);
|
||||
}
|
||||
|
||||
instance.defaults.proxy = proxyConfig as AxiosProxyConfig;
|
||||
} catch (error) {
|
||||
console.error('Error parsing proxy URL:', error);
|
||||
throw new Error(`Invalid proxy URL: ${process.env.proxy}`);
|
||||
}
|
||||
}
|
||||
|
||||
return instance;
|
||||
}
|
||||
269
packages/api/src/utils/azure.spec.ts
Normal file
269
packages/api/src/utils/azure.spec.ts
Normal file
|
|
@ -0,0 +1,269 @@
|
|||
import {
|
||||
genAzureChatCompletion,
|
||||
getAzureCredentials,
|
||||
constructAzureURL,
|
||||
sanitizeModelName,
|
||||
genAzureEndpoint,
|
||||
} from './azure';
|
||||
import type { GenericClient } from '~/types';
|
||||
|
||||
describe('sanitizeModelName', () => {
|
||||
test('removes periods from the model name', () => {
|
||||
const sanitized = sanitizeModelName('model.name');
|
||||
expect(sanitized).toBe('modelname');
|
||||
});
|
||||
|
||||
test('leaves model name unchanged if no periods are present', () => {
|
||||
const sanitized = sanitizeModelName('modelname');
|
||||
expect(sanitized).toBe('modelname');
|
||||
});
|
||||
});
|
||||
|
||||
describe('genAzureEndpoint', () => {
|
||||
test('generates correct endpoint URL', () => {
|
||||
const url = genAzureEndpoint({
|
||||
azureOpenAIApiInstanceName: 'instanceName',
|
||||
azureOpenAIApiDeploymentName: 'deploymentName',
|
||||
});
|
||||
expect(url).toBe('https://instanceName.openai.azure.com/openai/deployments/deploymentName');
|
||||
});
|
||||
});
|
||||
|
||||
describe('genAzureChatCompletion', () => {
|
||||
// Test with both deployment name and model name provided
|
||||
test('prefers model name over deployment name when both are provided and feature enabled', () => {
|
||||
process.env.AZURE_USE_MODEL_AS_DEPLOYMENT_NAME = 'true';
|
||||
const url = genAzureChatCompletion(
|
||||
{
|
||||
azureOpenAIApiInstanceName: 'instanceName',
|
||||
azureOpenAIApiDeploymentName: 'deploymentName',
|
||||
azureOpenAIApiVersion: 'v1',
|
||||
},
|
||||
'modelName',
|
||||
);
|
||||
expect(url).toBe(
|
||||
'https://instanceName.openai.azure.com/openai/deployments/modelName/chat/completions?api-version=v1',
|
||||
);
|
||||
});
|
||||
|
||||
// Test with only deployment name provided
|
||||
test('uses deployment name when model name is not provided', () => {
|
||||
const url = genAzureChatCompletion({
|
||||
azureOpenAIApiInstanceName: 'instanceName',
|
||||
azureOpenAIApiDeploymentName: 'deploymentName',
|
||||
azureOpenAIApiVersion: 'v1',
|
||||
});
|
||||
expect(url).toBe(
|
||||
'https://instanceName.openai.azure.com/openai/deployments/deploymentName/chat/completions?api-version=v1',
|
||||
);
|
||||
});
|
||||
|
||||
// Test with only model name provided
|
||||
test('uses model name when deployment name is not provided and feature enabled', () => {
|
||||
process.env.AZURE_USE_MODEL_AS_DEPLOYMENT_NAME = 'true';
|
||||
const url = genAzureChatCompletion(
|
||||
{
|
||||
azureOpenAIApiInstanceName: 'instanceName',
|
||||
azureOpenAIApiVersion: 'v1',
|
||||
},
|
||||
'modelName',
|
||||
);
|
||||
expect(url).toBe(
|
||||
'https://instanceName.openai.azure.com/openai/deployments/modelName/chat/completions?api-version=v1',
|
||||
);
|
||||
});
|
||||
|
||||
// Test with neither deployment name nor model name provided
|
||||
test('throws error if neither deployment name nor model name is provided', () => {
|
||||
expect(() => {
|
||||
genAzureChatCompletion({
|
||||
azureOpenAIApiInstanceName: 'instanceName',
|
||||
azureOpenAIApiVersion: 'v1',
|
||||
});
|
||||
}).toThrow(
|
||||
'Either a model name with the `AZURE_USE_MODEL_AS_DEPLOYMENT_NAME` setting or a deployment name must be provided if `AZURE_OPENAI_BASEURL` is omitted.',
|
||||
);
|
||||
});
|
||||
|
||||
// Test with feature disabled but model name provided
|
||||
test('ignores model name and uses deployment name when feature is disabled', () => {
|
||||
process.env.AZURE_USE_MODEL_AS_DEPLOYMENT_NAME = 'false';
|
||||
const url = genAzureChatCompletion(
|
||||
{
|
||||
azureOpenAIApiInstanceName: 'instanceName',
|
||||
azureOpenAIApiDeploymentName: 'deploymentName',
|
||||
azureOpenAIApiVersion: 'v1',
|
||||
},
|
||||
'modelName',
|
||||
);
|
||||
expect(url).toBe(
|
||||
'https://instanceName.openai.azure.com/openai/deployments/deploymentName/chat/completions?api-version=v1',
|
||||
);
|
||||
});
|
||||
|
||||
// Test with sanitized model name
|
||||
test('sanitizes model name when used in URL', () => {
|
||||
process.env.AZURE_USE_MODEL_AS_DEPLOYMENT_NAME = 'true';
|
||||
const url = genAzureChatCompletion(
|
||||
{
|
||||
azureOpenAIApiInstanceName: 'instanceName',
|
||||
azureOpenAIApiVersion: 'v1',
|
||||
},
|
||||
'model.name',
|
||||
);
|
||||
expect(url).toBe(
|
||||
'https://instanceName.openai.azure.com/openai/deployments/modelname/chat/completions?api-version=v1',
|
||||
);
|
||||
});
|
||||
|
||||
// Test with client parameter and model name
|
||||
test('updates client with sanitized model name when provided and feature enabled', () => {
|
||||
process.env.AZURE_USE_MODEL_AS_DEPLOYMENT_NAME = 'true';
|
||||
const clientMock = { azure: {} } as GenericClient;
|
||||
const url = genAzureChatCompletion(
|
||||
{
|
||||
azureOpenAIApiInstanceName: 'instanceName',
|
||||
azureOpenAIApiVersion: 'v1',
|
||||
},
|
||||
'model.name',
|
||||
clientMock,
|
||||
);
|
||||
expect(url).toBe(
|
||||
'https://instanceName.openai.azure.com/openai/deployments/modelname/chat/completions?api-version=v1',
|
||||
);
|
||||
expect(clientMock.azure.azureOpenAIApiDeploymentName).toBe('modelname');
|
||||
});
|
||||
|
||||
// Test with client parameter but without model name
|
||||
test('does not update client when model name is not provided', () => {
|
||||
const clientMock = { azure: {} } as GenericClient;
|
||||
const url = genAzureChatCompletion(
|
||||
{
|
||||
azureOpenAIApiInstanceName: 'instanceName',
|
||||
azureOpenAIApiDeploymentName: 'deploymentName',
|
||||
azureOpenAIApiVersion: 'v1',
|
||||
},
|
||||
undefined,
|
||||
clientMock,
|
||||
);
|
||||
expect(url).toBe(
|
||||
'https://instanceName.openai.azure.com/openai/deployments/deploymentName/chat/completions?api-version=v1',
|
||||
);
|
||||
expect(clientMock.azure.azureOpenAIApiDeploymentName).toBeUndefined();
|
||||
});
|
||||
|
||||
// Test with client parameter and deployment name when feature is disabled
|
||||
test('does not update client when feature is disabled', () => {
|
||||
process.env.AZURE_USE_MODEL_AS_DEPLOYMENT_NAME = 'false';
|
||||
const clientMock = { azure: {} } as GenericClient;
|
||||
const url = genAzureChatCompletion(
|
||||
{
|
||||
azureOpenAIApiInstanceName: 'instanceName',
|
||||
azureOpenAIApiDeploymentName: 'deploymentName',
|
||||
azureOpenAIApiVersion: 'v1',
|
||||
},
|
||||
'modelName',
|
||||
clientMock,
|
||||
);
|
||||
expect(url).toBe(
|
||||
'https://instanceName.openai.azure.com/openai/deployments/deploymentName/chat/completions?api-version=v1',
|
||||
);
|
||||
expect(clientMock.azure.azureOpenAIApiDeploymentName).toBeUndefined();
|
||||
});
|
||||
|
||||
// Reset environment variable after tests
|
||||
afterEach(() => {
|
||||
delete process.env.AZURE_USE_MODEL_AS_DEPLOYMENT_NAME;
|
||||
});
|
||||
});
|
||||
|
||||
describe('getAzureCredentials', () => {
|
||||
beforeEach(() => {
|
||||
process.env.AZURE_API_KEY = 'testApiKey';
|
||||
process.env.AZURE_OPENAI_API_INSTANCE_NAME = 'instanceName';
|
||||
process.env.AZURE_OPENAI_API_DEPLOYMENT_NAME = 'deploymentName';
|
||||
process.env.AZURE_OPENAI_API_VERSION = 'v1';
|
||||
});
|
||||
|
||||
test('retrieves Azure OpenAI API credentials from environment variables', () => {
|
||||
const credentials = getAzureCredentials();
|
||||
expect(credentials).toEqual({
|
||||
azureOpenAIApiKey: 'testApiKey',
|
||||
azureOpenAIApiInstanceName: 'instanceName',
|
||||
azureOpenAIApiDeploymentName: 'deploymentName',
|
||||
azureOpenAIApiVersion: 'v1',
|
||||
});
|
||||
});
|
||||
});
|
||||
|
||||
describe('constructAzureURL', () => {
|
||||
test('replaces both placeholders when both properties are provided', () => {
|
||||
const url = constructAzureURL({
|
||||
baseURL: 'https://example.com/${INSTANCE_NAME}/${DEPLOYMENT_NAME}',
|
||||
azureOptions: {
|
||||
azureOpenAIApiInstanceName: 'instance1',
|
||||
azureOpenAIApiDeploymentName: 'deployment1',
|
||||
},
|
||||
});
|
||||
expect(url).toBe('https://example.com/instance1/deployment1');
|
||||
});
|
||||
|
||||
test('replaces only INSTANCE_NAME when only azureOpenAIApiInstanceName is provided', () => {
|
||||
const url = constructAzureURL({
|
||||
baseURL: 'https://example.com/${INSTANCE_NAME}/${DEPLOYMENT_NAME}',
|
||||
azureOptions: {
|
||||
azureOpenAIApiInstanceName: 'instance2',
|
||||
},
|
||||
});
|
||||
expect(url).toBe('https://example.com/instance2/');
|
||||
});
|
||||
|
||||
test('replaces only DEPLOYMENT_NAME when only azureOpenAIApiDeploymentName is provided', () => {
|
||||
const url = constructAzureURL({
|
||||
baseURL: 'https://example.com/${INSTANCE_NAME}/${DEPLOYMENT_NAME}',
|
||||
azureOptions: {
|
||||
azureOpenAIApiDeploymentName: 'deployment2',
|
||||
},
|
||||
});
|
||||
expect(url).toBe('https://example.com//deployment2');
|
||||
});
|
||||
|
||||
test('does not replace any placeholders when azure object is empty', () => {
|
||||
const url = constructAzureURL({
|
||||
baseURL: 'https://example.com/${INSTANCE_NAME}/${DEPLOYMENT_NAME}',
|
||||
azureOptions: {},
|
||||
});
|
||||
expect(url).toBe('https://example.com//');
|
||||
});
|
||||
|
||||
test('returns baseURL as is when `azureOptions` object is not provided', () => {
|
||||
const url = constructAzureURL({
|
||||
baseURL: 'https://example.com/${INSTANCE_NAME}/${DEPLOYMENT_NAME}',
|
||||
});
|
||||
expect(url).toBe('https://example.com/${INSTANCE_NAME}/${DEPLOYMENT_NAME}');
|
||||
});
|
||||
|
||||
test('returns baseURL as is when no placeholders are set', () => {
|
||||
const url = constructAzureURL({
|
||||
baseURL: 'https://example.com/my_custom_instance/my_deployment',
|
||||
azureOptions: {
|
||||
azureOpenAIApiInstanceName: 'instance1',
|
||||
azureOpenAIApiDeploymentName: 'deployment1',
|
||||
},
|
||||
});
|
||||
expect(url).toBe('https://example.com/my_custom_instance/my_deployment');
|
||||
});
|
||||
|
||||
test('returns regular Azure OpenAI baseURL with placeholders set', () => {
|
||||
const baseURL =
|
||||
'https://${INSTANCE_NAME}.openai.azure.com/openai/deployments/${DEPLOYMENT_NAME}';
|
||||
const url = constructAzureURL({
|
||||
baseURL,
|
||||
azureOptions: {
|
||||
azureOpenAIApiInstanceName: 'instance1',
|
||||
azureOpenAIApiDeploymentName: 'deployment1',
|
||||
},
|
||||
});
|
||||
expect(url).toBe('https://instance1.openai.azure.com/openai/deployments/deployment1');
|
||||
});
|
||||
});
|
||||
120
packages/api/src/utils/azure.ts
Normal file
120
packages/api/src/utils/azure.ts
Normal file
|
|
@ -0,0 +1,120 @@
|
|||
import { isEnabled } from './common';
|
||||
import type { AzureOptions, GenericClient } from '~/types';
|
||||
|
||||
/**
|
||||
* Sanitizes the model name to be used in the URL by removing or replacing disallowed characters.
|
||||
* @param modelName - The model name to be sanitized.
|
||||
* @returns The sanitized model name.
|
||||
*/
|
||||
export const sanitizeModelName = (modelName: string): string => {
|
||||
// Replace periods with empty strings and other disallowed characters as needed.
|
||||
return modelName.replace(/\./g, '');
|
||||
};
|
||||
|
||||
/**
|
||||
* Generates the Azure OpenAI API endpoint URL.
|
||||
* @param params - The parameters object.
|
||||
* @param params.azureOpenAIApiInstanceName - The Azure OpenAI API instance name.
|
||||
* @param params.azureOpenAIApiDeploymentName - The Azure OpenAI API deployment name.
|
||||
* @returns The complete endpoint URL for the Azure OpenAI API.
|
||||
*/
|
||||
export const genAzureEndpoint = ({
|
||||
azureOpenAIApiInstanceName,
|
||||
azureOpenAIApiDeploymentName,
|
||||
}: {
|
||||
azureOpenAIApiInstanceName: string;
|
||||
azureOpenAIApiDeploymentName: string;
|
||||
}): string => {
|
||||
return `https://${azureOpenAIApiInstanceName}.openai.azure.com/openai/deployments/${azureOpenAIApiDeploymentName}`;
|
||||
};
|
||||
|
||||
/**
|
||||
* Generates the Azure OpenAI API chat completion endpoint URL with the API version.
|
||||
* If both deploymentName and modelName are provided, modelName takes precedence.
|
||||
* @param azureConfig - The Azure configuration object.
|
||||
* @param azureConfig.azureOpenAIApiInstanceName - The Azure OpenAI API instance name.
|
||||
* @param azureConfig.azureOpenAIApiDeploymentName - The Azure OpenAI API deployment name (optional).
|
||||
* @param azureConfig.azureOpenAIApiVersion - The Azure OpenAI API version.
|
||||
* @param modelName - The model name to be included in the deployment name (optional).
|
||||
* @param client - The API Client class for optionally setting properties (optional).
|
||||
* @returns The complete chat completion endpoint URL for the Azure OpenAI API.
|
||||
* @throws Error if neither azureOpenAIApiDeploymentName nor modelName is provided.
|
||||
*/
|
||||
export const genAzureChatCompletion = (
|
||||
{
|
||||
azureOpenAIApiInstanceName,
|
||||
azureOpenAIApiDeploymentName,
|
||||
azureOpenAIApiVersion,
|
||||
}: {
|
||||
azureOpenAIApiInstanceName: string;
|
||||
azureOpenAIApiDeploymentName?: string;
|
||||
azureOpenAIApiVersion: string;
|
||||
},
|
||||
modelName?: string,
|
||||
client?: GenericClient,
|
||||
): string => {
|
||||
// Determine the deployment segment of the URL based on provided modelName or azureOpenAIApiDeploymentName
|
||||
let deploymentSegment: string;
|
||||
if (isEnabled(process.env.AZURE_USE_MODEL_AS_DEPLOYMENT_NAME) && modelName) {
|
||||
const sanitizedModelName = sanitizeModelName(modelName);
|
||||
deploymentSegment = sanitizedModelName;
|
||||
if (client && typeof client === 'object') {
|
||||
client.azure.azureOpenAIApiDeploymentName = sanitizedModelName;
|
||||
}
|
||||
} else if (azureOpenAIApiDeploymentName) {
|
||||
deploymentSegment = azureOpenAIApiDeploymentName;
|
||||
} else if (!process.env.AZURE_OPENAI_BASEURL) {
|
||||
throw new Error(
|
||||
'Either a model name with the `AZURE_USE_MODEL_AS_DEPLOYMENT_NAME` setting or a deployment name must be provided if `AZURE_OPENAI_BASEURL` is omitted.',
|
||||
);
|
||||
} else {
|
||||
deploymentSegment = '';
|
||||
}
|
||||
|
||||
return `https://${azureOpenAIApiInstanceName}.openai.azure.com/openai/deployments/${deploymentSegment}/chat/completions?api-version=${azureOpenAIApiVersion}`;
|
||||
};
|
||||
|
||||
/**
|
||||
* Retrieves the Azure OpenAI API credentials from environment variables.
|
||||
* @returns An object containing the Azure OpenAI API credentials.
|
||||
*/
|
||||
export const getAzureCredentials = (): AzureOptions => {
|
||||
return {
|
||||
azureOpenAIApiKey: process.env.AZURE_API_KEY ?? process.env.AZURE_OPENAI_API_KEY,
|
||||
azureOpenAIApiInstanceName: process.env.AZURE_OPENAI_API_INSTANCE_NAME,
|
||||
azureOpenAIApiDeploymentName: process.env.AZURE_OPENAI_API_DEPLOYMENT_NAME,
|
||||
azureOpenAIApiVersion: process.env.AZURE_OPENAI_API_VERSION,
|
||||
};
|
||||
};
|
||||
|
||||
/**
|
||||
* Constructs a URL by replacing placeholders in the baseURL with values from the azure object.
|
||||
* It specifically looks for '${INSTANCE_NAME}' and '${DEPLOYMENT_NAME}' within the baseURL and replaces
|
||||
* them with 'azureOpenAIApiInstanceName' and 'azureOpenAIApiDeploymentName' from the azure object.
|
||||
* If the respective azure property is not provided, the placeholder is replaced with an empty string.
|
||||
*
|
||||
* @param params - The parameters object.
|
||||
* @param params.baseURL - The baseURL to inspect for replacement placeholders.
|
||||
* @param params.azureOptions - The azure options object containing the instance and deployment names.
|
||||
* @returns The complete baseURL with credentials injected for the Azure OpenAI API.
|
||||
*/
|
||||
export function constructAzureURL({
|
||||
baseURL,
|
||||
azureOptions,
|
||||
}: {
|
||||
baseURL: string;
|
||||
azureOptions?: AzureOptions;
|
||||
}): string {
|
||||
let finalURL = baseURL;
|
||||
|
||||
// Replace INSTANCE_NAME and DEPLOYMENT_NAME placeholders with actual values if available
|
||||
if (azureOptions) {
|
||||
finalURL = finalURL.replace('${INSTANCE_NAME}', azureOptions.azureOpenAIApiInstanceName ?? '');
|
||||
finalURL = finalURL.replace(
|
||||
'${DEPLOYMENT_NAME}',
|
||||
azureOptions.azureOpenAIApiDeploymentName ?? '',
|
||||
);
|
||||
}
|
||||
|
||||
return finalURL;
|
||||
}
|
||||
55
packages/api/src/utils/common.spec.ts
Normal file
55
packages/api/src/utils/common.spec.ts
Normal file
|
|
@ -0,0 +1,55 @@
|
|||
/* eslint-disable @typescript-eslint/ban-ts-comment */
|
||||
import { isEnabled } from './common';
|
||||
|
||||
describe('isEnabled', () => {
|
||||
test('should return true when input is "true"', () => {
|
||||
expect(isEnabled('true')).toBe(true);
|
||||
});
|
||||
|
||||
test('should return true when input is "TRUE"', () => {
|
||||
expect(isEnabled('TRUE')).toBe(true);
|
||||
});
|
||||
|
||||
test('should return true when input is true', () => {
|
||||
expect(isEnabled(true)).toBe(true);
|
||||
});
|
||||
|
||||
test('should return false when input is "false"', () => {
|
||||
expect(isEnabled('false')).toBe(false);
|
||||
});
|
||||
|
||||
test('should return false when input is false', () => {
|
||||
expect(isEnabled(false)).toBe(false);
|
||||
});
|
||||
|
||||
test('should return false when input is null', () => {
|
||||
expect(isEnabled(null)).toBe(false);
|
||||
});
|
||||
|
||||
test('should return false when input is undefined', () => {
|
||||
expect(isEnabled()).toBe(false);
|
||||
});
|
||||
|
||||
test('should return false when input is an empty string', () => {
|
||||
expect(isEnabled('')).toBe(false);
|
||||
});
|
||||
|
||||
test('should return false when input is a whitespace string', () => {
|
||||
expect(isEnabled(' ')).toBe(false);
|
||||
});
|
||||
|
||||
test('should return false when input is a number', () => {
|
||||
// @ts-expect-error
|
||||
expect(isEnabled(123)).toBe(false);
|
||||
});
|
||||
|
||||
test('should return false when input is an object', () => {
|
||||
// @ts-expect-error
|
||||
expect(isEnabled({})).toBe(false);
|
||||
});
|
||||
|
||||
test('should return false when input is an array', () => {
|
||||
// @ts-expect-error
|
||||
expect(isEnabled([])).toBe(false);
|
||||
});
|
||||
});
|
||||
48
packages/api/src/utils/common.ts
Normal file
48
packages/api/src/utils/common.ts
Normal file
|
|
@ -0,0 +1,48 @@
|
|||
/**
|
||||
* Checks if the given value is truthy by being either the boolean `true` or a string
|
||||
* that case-insensitively matches 'true'.
|
||||
*
|
||||
* @param value - The value to check.
|
||||
* @returns Returns `true` if the value is the boolean `true` or a case-insensitive
|
||||
* match for the string 'true', otherwise returns `false`.
|
||||
* @example
|
||||
*
|
||||
* isEnabled("True"); // returns true
|
||||
* isEnabled("TRUE"); // returns true
|
||||
* isEnabled(true); // returns true
|
||||
* isEnabled("false"); // returns false
|
||||
* isEnabled(false); // returns false
|
||||
* isEnabled(null); // returns false
|
||||
* isEnabled(); // returns false
|
||||
*/
|
||||
export function isEnabled(value?: string | boolean | null | undefined): boolean {
|
||||
if (typeof value === 'boolean') {
|
||||
return value;
|
||||
}
|
||||
if (typeof value === 'string') {
|
||||
return value.toLowerCase().trim() === 'true';
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
/**
|
||||
* Checks if the provided value is 'user_provided'.
|
||||
*
|
||||
* @param value - The value to check.
|
||||
* @returns - Returns true if the value is 'user_provided', otherwise false.
|
||||
*/
|
||||
export const isUserProvided = (value?: string): boolean => value === 'user_provided';
|
||||
|
||||
/**
|
||||
* @param values
|
||||
*/
|
||||
export function optionalChainWithEmptyCheck(
|
||||
...values: (string | number | undefined)[]
|
||||
): string | number | undefined {
|
||||
for (const value of values) {
|
||||
if (value !== undefined && value !== null && value !== '') {
|
||||
return value;
|
||||
}
|
||||
}
|
||||
return values[values.length - 1];
|
||||
}
|
||||
16
packages/api/src/utils/events.ts
Normal file
16
packages/api/src/utils/events.ts
Normal file
|
|
@ -0,0 +1,16 @@
|
|||
import type { Response as ServerResponse } from 'express';
|
||||
import type { ServerSentEvent } from '~/types';
|
||||
|
||||
/**
|
||||
* Sends message data in Server Sent Events format.
|
||||
* @param res - The server response.
|
||||
* @param event - The message event.
|
||||
* @param event.event - The type of event.
|
||||
* @param event.data - The message to be sent.
|
||||
*/
|
||||
export function sendEvent(res: ServerResponse, event: ServerSentEvent): void {
|
||||
if (typeof event.data === 'string' && event.data.length === 0) {
|
||||
return;
|
||||
}
|
||||
res.write(`event: message\ndata: ${JSON.stringify(event)}\n\n`);
|
||||
}
|
||||
115
packages/api/src/utils/files.spec.ts
Normal file
115
packages/api/src/utils/files.spec.ts
Normal file
|
|
@ -0,0 +1,115 @@
|
|||
import { sanitizeFilename } from './files';
|
||||
|
||||
jest.mock('node:crypto', () => {
|
||||
const actualModule = jest.requireActual('node:crypto');
|
||||
return {
|
||||
...actualModule,
|
||||
randomBytes: jest.fn().mockReturnValue(Buffer.from('abc123', 'hex')),
|
||||
};
|
||||
});
|
||||
|
||||
describe('sanitizeFilename', () => {
|
||||
test('removes directory components (1/2)', () => {
|
||||
expect(sanitizeFilename('/path/to/file.txt')).toBe('file.txt');
|
||||
});
|
||||
|
||||
test('removes directory components (2/2)', () => {
|
||||
expect(sanitizeFilename('../../../../file.txt')).toBe('file.txt');
|
||||
});
|
||||
|
||||
test('replaces non-alphanumeric characters', () => {
|
||||
expect(sanitizeFilename('file name@#$.txt')).toBe('file_name___.txt');
|
||||
});
|
||||
|
||||
test('preserves dots and hyphens', () => {
|
||||
expect(sanitizeFilename('file-name.with.dots.txt')).toBe('file-name.with.dots.txt');
|
||||
});
|
||||
|
||||
test('prepends underscore to filenames starting with a dot', () => {
|
||||
expect(sanitizeFilename('.hiddenfile')).toBe('_.hiddenfile');
|
||||
});
|
||||
|
||||
test('truncates long filenames', () => {
|
||||
const longName = 'a'.repeat(300) + '.txt';
|
||||
const result = sanitizeFilename(longName);
|
||||
expect(result.length).toBe(255);
|
||||
expect(result).toMatch(/^a+-abc123\.txt$/);
|
||||
});
|
||||
|
||||
test('handles filenames with no extension', () => {
|
||||
const longName = 'a'.repeat(300);
|
||||
const result = sanitizeFilename(longName);
|
||||
expect(result.length).toBe(255);
|
||||
expect(result).toMatch(/^a+-abc123$/);
|
||||
});
|
||||
|
||||
test('handles empty input', () => {
|
||||
expect(sanitizeFilename('')).toBe('_');
|
||||
});
|
||||
|
||||
test('handles input with only special characters', () => {
|
||||
expect(sanitizeFilename('@#$%^&*')).toBe('_______');
|
||||
});
|
||||
});
|
||||
|
||||
describe('sanitizeFilename with real crypto', () => {
|
||||
// Temporarily unmock crypto for these tests
|
||||
beforeAll(() => {
|
||||
jest.resetModules();
|
||||
jest.unmock('node:crypto');
|
||||
});
|
||||
|
||||
afterAll(() => {
|
||||
jest.resetModules();
|
||||
jest.mock('node:crypto', () => {
|
||||
const actualModule = jest.requireActual('node:crypto');
|
||||
return {
|
||||
...actualModule,
|
||||
randomBytes: jest.fn().mockReturnValue(Buffer.from('abc123', 'hex')),
|
||||
};
|
||||
});
|
||||
});
|
||||
|
||||
test('truncates long filenames with real crypto', async () => {
|
||||
const { sanitizeFilename: realSanitizeFilename } = await import('./files');
|
||||
const longName = 'b'.repeat(300) + '.pdf';
|
||||
const result = realSanitizeFilename(longName);
|
||||
|
||||
expect(result.length).toBe(255);
|
||||
expect(result).toMatch(/^b+-[a-f0-9]{6}\.pdf$/);
|
||||
expect(result.endsWith('.pdf')).toBe(true);
|
||||
});
|
||||
|
||||
test('handles filenames with no extension with real crypto', async () => {
|
||||
const { sanitizeFilename: realSanitizeFilename } = await import('./files');
|
||||
const longName = 'c'.repeat(300);
|
||||
const result = realSanitizeFilename(longName);
|
||||
|
||||
expect(result.length).toBe(255);
|
||||
expect(result).toMatch(/^c+-[a-f0-9]{6}$/);
|
||||
expect(result).not.toContain('.');
|
||||
});
|
||||
|
||||
test('generates unique suffixes for identical long filenames', async () => {
|
||||
const { sanitizeFilename: realSanitizeFilename } = await import('./files');
|
||||
const longName = 'd'.repeat(300) + '.doc';
|
||||
const result1 = realSanitizeFilename(longName);
|
||||
const result2 = realSanitizeFilename(longName);
|
||||
|
||||
expect(result1.length).toBe(255);
|
||||
expect(result2.length).toBe(255);
|
||||
expect(result1).not.toBe(result2); // Should be different due to random suffix
|
||||
expect(result1.endsWith('.doc')).toBe(true);
|
||||
expect(result2.endsWith('.doc')).toBe(true);
|
||||
});
|
||||
|
||||
test('real crypto produces valid hex strings', async () => {
|
||||
const { sanitizeFilename: realSanitizeFilename } = await import('./files');
|
||||
const longName = 'test'.repeat(100) + '.txt';
|
||||
const result = realSanitizeFilename(longName);
|
||||
|
||||
const hexMatch = result.match(/-([a-f0-9]{6})\.txt$/);
|
||||
expect(hexMatch).toBeTruthy();
|
||||
expect(hexMatch![1]).toMatch(/^[a-f0-9]{6}$/);
|
||||
});
|
||||
});
|
||||
33
packages/api/src/utils/files.ts
Normal file
33
packages/api/src/utils/files.ts
Normal file
|
|
@ -0,0 +1,33 @@
|
|||
import path from 'path';
|
||||
import crypto from 'node:crypto';
|
||||
|
||||
/**
|
||||
* Sanitize a filename by removing any directory components, replacing non-alphanumeric characters
|
||||
* @param inputName
|
||||
*/
|
||||
export function sanitizeFilename(inputName: string): string {
|
||||
// Remove any directory components
|
||||
let name = path.basename(inputName);
|
||||
|
||||
// Replace any non-alphanumeric characters except for '.' and '-'
|
||||
name = name.replace(/[^a-zA-Z0-9.-]/g, '_');
|
||||
|
||||
// Ensure the name doesn't start with a dot (hidden file in Unix-like systems)
|
||||
if (name.startsWith('.') || name === '') {
|
||||
name = '_' + name;
|
||||
}
|
||||
|
||||
// Limit the length of the filename
|
||||
const MAX_LENGTH = 255;
|
||||
if (name.length > MAX_LENGTH) {
|
||||
const ext = path.extname(name);
|
||||
const nameWithoutExt = path.basename(name, ext);
|
||||
name =
|
||||
nameWithoutExt.slice(0, MAX_LENGTH - ext.length - 7) +
|
||||
'-' +
|
||||
crypto.randomBytes(3).toString('hex') +
|
||||
ext;
|
||||
}
|
||||
|
||||
return name;
|
||||
}
|
||||
75
packages/api/src/utils/generators.ts
Normal file
75
packages/api/src/utils/generators.ts
Normal file
|
|
@ -0,0 +1,75 @@
|
|||
import fetch from 'node-fetch';
|
||||
import { logger } from '@librechat/data-schemas';
|
||||
import { GraphEvents, sleep } from '@librechat/agents';
|
||||
import type { Response as ServerResponse } from 'express';
|
||||
import type { ServerSentEvent } from '~/types';
|
||||
import { sendEvent } from './events';
|
||||
|
||||
/**
|
||||
* Makes a function to make HTTP request and logs the process.
|
||||
* @param params
|
||||
* @param params.directEndpoint - Whether to use a direct endpoint.
|
||||
* @param params.reverseProxyUrl - The reverse proxy URL to use for the request.
|
||||
* @returns A promise that resolves to the response of the fetch request.
|
||||
*/
|
||||
export function createFetch({
|
||||
directEndpoint = false,
|
||||
reverseProxyUrl = '',
|
||||
}: {
|
||||
directEndpoint?: boolean;
|
||||
reverseProxyUrl?: string;
|
||||
}) {
|
||||
/**
|
||||
* Makes an HTTP request and logs the process.
|
||||
* @param url - The URL to make the request to. Can be a string or a Request object.
|
||||
* @param init - Optional init options for the request.
|
||||
* @returns A promise that resolves to the response of the fetch request.
|
||||
*/
|
||||
return async function (
|
||||
_url: fetch.RequestInfo,
|
||||
init: fetch.RequestInit,
|
||||
): Promise<fetch.Response> {
|
||||
let url = _url;
|
||||
if (directEndpoint) {
|
||||
url = reverseProxyUrl;
|
||||
}
|
||||
logger.debug(`Making request to ${url}`);
|
||||
if (typeof Bun !== 'undefined') {
|
||||
return await fetch(url, init);
|
||||
}
|
||||
return await fetch(url, init);
|
||||
};
|
||||
}
|
||||
|
||||
/**
|
||||
* Creates event handlers for stream events that don't capture client references
|
||||
* @param res - The response object to send events to
|
||||
* @returns Object containing handler functions
|
||||
*/
|
||||
export function createStreamEventHandlers(res: ServerResponse) {
|
||||
return {
|
||||
[GraphEvents.ON_RUN_STEP]: function (event: ServerSentEvent) {
|
||||
if (res) {
|
||||
sendEvent(res, event);
|
||||
}
|
||||
},
|
||||
[GraphEvents.ON_MESSAGE_DELTA]: function (event: ServerSentEvent) {
|
||||
if (res) {
|
||||
sendEvent(res, event);
|
||||
}
|
||||
},
|
||||
[GraphEvents.ON_REASONING_DELTA]: function (event: ServerSentEvent) {
|
||||
if (res) {
|
||||
sendEvent(res, event);
|
||||
}
|
||||
},
|
||||
};
|
||||
}
|
||||
|
||||
export function createHandleLLMNewToken(streamRate: number) {
|
||||
return async function () {
|
||||
if (streamRate) {
|
||||
await sleep(streamRate);
|
||||
}
|
||||
};
|
||||
}
|
||||
8
packages/api/src/utils/index.ts
Normal file
8
packages/api/src/utils/index.ts
Normal file
|
|
@ -0,0 +1,8 @@
|
|||
export * from './axios';
|
||||
export * from './azure';
|
||||
export * from './common';
|
||||
export * from './events';
|
||||
export * from './files';
|
||||
export * from './generators';
|
||||
export * from './openid';
|
||||
export { default as Tokenizer } from './tokenizer';
|
||||
51
packages/api/src/utils/openid.ts
Normal file
51
packages/api/src/utils/openid.ts
Normal file
|
|
@ -0,0 +1,51 @@
|
|||
/**
|
||||
* Helper function to safely log sensitive data when debug mode is enabled
|
||||
* @param obj - Object to stringify
|
||||
* @param maxLength - Maximum length of the stringified output
|
||||
* @returns Stringified object with sensitive data masked
|
||||
*/
|
||||
export function safeStringify(obj: unknown, maxLength = 1000): string {
|
||||
try {
|
||||
const str = JSON.stringify(obj, (key, value) => {
|
||||
// Mask sensitive values
|
||||
if (
|
||||
key === 'client_secret' ||
|
||||
key === 'Authorization' ||
|
||||
key.toLowerCase().includes('token') ||
|
||||
key.toLowerCase().includes('password')
|
||||
) {
|
||||
return typeof value === 'string' && value.length > 6
|
||||
? `${value.substring(0, 3)}...${value.substring(value.length - 3)}`
|
||||
: '***MASKED***';
|
||||
}
|
||||
return value;
|
||||
});
|
||||
|
||||
if (str && str.length > maxLength) {
|
||||
return `${str.substring(0, maxLength)}... (truncated)`;
|
||||
}
|
||||
return str;
|
||||
} catch (error) {
|
||||
return `[Error stringifying object: ${(error as Error).message}]`;
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Helper to log headers without revealing sensitive information
|
||||
* @param headers - Headers object to log
|
||||
* @returns Stringified headers with sensitive data masked
|
||||
*/
|
||||
export function logHeaders(headers: Headers | undefined | null): string {
|
||||
const headerObj: Record<string, string> = {};
|
||||
if (!headers || typeof headers.entries !== 'function') {
|
||||
return 'No headers available';
|
||||
}
|
||||
for (const [key, value] of headers.entries()) {
|
||||
if (key.toLowerCase() === 'authorization' || key.toLowerCase().includes('secret')) {
|
||||
headerObj[key] = '***MASKED***';
|
||||
} else {
|
||||
headerObj[key] = value;
|
||||
}
|
||||
}
|
||||
return safeStringify(headerObj);
|
||||
}
|
||||
143
packages/api/src/utils/tokenizer.spec.ts
Normal file
143
packages/api/src/utils/tokenizer.spec.ts
Normal file
|
|
@ -0,0 +1,143 @@
|
|||
/**
|
||||
* @file Tokenizer.spec.cjs
|
||||
*
|
||||
* Tests the real TokenizerSingleton (no mocking of `tiktoken`).
|
||||
* Make sure to install `tiktoken` and have it configured properly.
|
||||
*/
|
||||
|
||||
import { logger } from '@librechat/data-schemas';
|
||||
import type { Tiktoken } from 'tiktoken';
|
||||
import Tokenizer from './tokenizer';
|
||||
|
||||
jest.mock('@librechat/data-schemas', () => ({
|
||||
logger: {
|
||||
error: jest.fn(),
|
||||
},
|
||||
}));
|
||||
|
||||
describe('Tokenizer', () => {
|
||||
it('should be a singleton (same instance)', async () => {
|
||||
const AnotherTokenizer = await import('./tokenizer'); // same path
|
||||
expect(Tokenizer).toBe(AnotherTokenizer.default);
|
||||
});
|
||||
|
||||
describe('getTokenizer', () => {
|
||||
it('should create an encoder for an explicit model name (e.g., "gpt-4")', () => {
|
||||
// The real `encoding_for_model` will be called internally
|
||||
// as soon as we pass isModelName = true.
|
||||
const tokenizer = Tokenizer.getTokenizer('gpt-4', true);
|
||||
|
||||
// Basic sanity checks
|
||||
expect(tokenizer).toBeDefined();
|
||||
// You can optionally check certain properties from `tiktoken` if they exist
|
||||
// e.g., expect(typeof tokenizer.encode).toBe('function');
|
||||
});
|
||||
|
||||
it('should create an encoder for a known encoding (e.g., "cl100k_base")', () => {
|
||||
// The real `get_encoding` will be called internally
|
||||
// as soon as we pass isModelName = false.
|
||||
const tokenizer = Tokenizer.getTokenizer('cl100k_base', false);
|
||||
|
||||
expect(tokenizer).toBeDefined();
|
||||
// e.g., expect(typeof tokenizer.encode).toBe('function');
|
||||
});
|
||||
|
||||
it('should return cached tokenizer if previously fetched', () => {
|
||||
const tokenizer1 = Tokenizer.getTokenizer('cl100k_base', false);
|
||||
const tokenizer2 = Tokenizer.getTokenizer('cl100k_base', false);
|
||||
// Should be the exact same instance from the cache
|
||||
expect(tokenizer1).toBe(tokenizer2);
|
||||
});
|
||||
});
|
||||
|
||||
describe('freeAndResetAllEncoders', () => {
|
||||
beforeEach(() => {
|
||||
jest.clearAllMocks();
|
||||
});
|
||||
|
||||
it('should free all encoders and reset tokenizerCallsCount to 1', () => {
|
||||
// By creating two different encodings, we populate the cache
|
||||
Tokenizer.getTokenizer('cl100k_base', false);
|
||||
Tokenizer.getTokenizer('r50k_base', false);
|
||||
|
||||
// Now free them
|
||||
Tokenizer.freeAndResetAllEncoders();
|
||||
|
||||
// The internal cache is cleared
|
||||
expect(Tokenizer.tokenizersCache['cl100k_base']).toBeUndefined();
|
||||
expect(Tokenizer.tokenizersCache['r50k_base']).toBeUndefined();
|
||||
|
||||
// tokenizerCallsCount is reset to 1
|
||||
expect(Tokenizer.tokenizerCallsCount).toBe(1);
|
||||
});
|
||||
|
||||
it('should catch and log errors if freeing fails', () => {
|
||||
// Mock logger.error before the test
|
||||
const mockLoggerError = jest.spyOn(logger, 'error');
|
||||
|
||||
// Set up a problematic tokenizer in the cache
|
||||
Tokenizer.tokenizersCache['cl100k_base'] = {
|
||||
free() {
|
||||
throw new Error('Intentional free error');
|
||||
},
|
||||
} as unknown as Tiktoken;
|
||||
|
||||
// Should not throw uncaught errors
|
||||
Tokenizer.freeAndResetAllEncoders();
|
||||
|
||||
// Verify logger.error was called with correct arguments
|
||||
expect(mockLoggerError).toHaveBeenCalledWith(
|
||||
'[Tokenizer] Free and reset encoders error',
|
||||
expect.any(Error),
|
||||
);
|
||||
|
||||
// Clean up
|
||||
mockLoggerError.mockRestore();
|
||||
Tokenizer.tokenizersCache = {};
|
||||
});
|
||||
});
|
||||
|
||||
describe('getTokenCount', () => {
|
||||
beforeEach(() => {
|
||||
jest.clearAllMocks();
|
||||
Tokenizer.freeAndResetAllEncoders();
|
||||
});
|
||||
|
||||
it('should return the number of tokens in the given text', () => {
|
||||
const text = 'Hello, world!';
|
||||
const count = Tokenizer.getTokenCount(text, 'cl100k_base');
|
||||
expect(count).toBeGreaterThan(0);
|
||||
});
|
||||
|
||||
it('should reset encoders if an error is thrown', () => {
|
||||
// We can simulate an error by temporarily overriding the selected tokenizer's `encode` method.
|
||||
const tokenizer = Tokenizer.getTokenizer('cl100k_base', false);
|
||||
const originalEncode = tokenizer.encode;
|
||||
tokenizer.encode = () => {
|
||||
throw new Error('Forced error');
|
||||
};
|
||||
|
||||
// Despite the forced error, the code should catch and reset, then re-encode
|
||||
const count = Tokenizer.getTokenCount('Hello again', 'cl100k_base');
|
||||
expect(count).toBeGreaterThan(0);
|
||||
|
||||
// Restore the original encode
|
||||
tokenizer.encode = originalEncode;
|
||||
});
|
||||
|
||||
it('should reset tokenizers after 25 calls', () => {
|
||||
// Spy on freeAndResetAllEncoders
|
||||
const resetSpy = jest.spyOn(Tokenizer, 'freeAndResetAllEncoders');
|
||||
|
||||
// Make 24 calls; should NOT reset yet
|
||||
for (let i = 0; i < 24; i++) {
|
||||
Tokenizer.getTokenCount('test text', 'cl100k_base');
|
||||
}
|
||||
expect(resetSpy).not.toHaveBeenCalled();
|
||||
|
||||
// 25th call triggers the reset
|
||||
Tokenizer.getTokenCount('the 25th call!', 'cl100k_base');
|
||||
expect(resetSpy).toHaveBeenCalledTimes(1);
|
||||
});
|
||||
});
|
||||
});
|
||||
78
packages/api/src/utils/tokenizer.ts
Normal file
78
packages/api/src/utils/tokenizer.ts
Normal file
|
|
@ -0,0 +1,78 @@
|
|||
import { logger } from '@librechat/data-schemas';
|
||||
import { encoding_for_model as encodingForModel, get_encoding as getEncoding } from 'tiktoken';
|
||||
import type { Tiktoken, TiktokenModel, TiktokenEncoding } from 'tiktoken';
|
||||
|
||||
interface TokenizerOptions {
|
||||
debug?: boolean;
|
||||
}
|
||||
|
||||
class Tokenizer {
|
||||
tokenizersCache: Record<string, Tiktoken>;
|
||||
tokenizerCallsCount: number;
|
||||
private options?: TokenizerOptions;
|
||||
|
||||
constructor() {
|
||||
this.tokenizersCache = {};
|
||||
this.tokenizerCallsCount = 0;
|
||||
}
|
||||
|
||||
getTokenizer(
|
||||
encoding: TiktokenModel | TiktokenEncoding,
|
||||
isModelName = false,
|
||||
extendSpecialTokens: Record<string, number> = {},
|
||||
): Tiktoken {
|
||||
let tokenizer: Tiktoken;
|
||||
if (this.tokenizersCache[encoding]) {
|
||||
tokenizer = this.tokenizersCache[encoding];
|
||||
} else {
|
||||
if (isModelName) {
|
||||
tokenizer = encodingForModel(encoding as TiktokenModel, extendSpecialTokens);
|
||||
} else {
|
||||
tokenizer = getEncoding(encoding as TiktokenEncoding, extendSpecialTokens);
|
||||
}
|
||||
this.tokenizersCache[encoding] = tokenizer;
|
||||
}
|
||||
return tokenizer;
|
||||
}
|
||||
|
||||
freeAndResetAllEncoders(): void {
|
||||
try {
|
||||
Object.keys(this.tokenizersCache).forEach((key) => {
|
||||
if (this.tokenizersCache[key]) {
|
||||
this.tokenizersCache[key].free();
|
||||
delete this.tokenizersCache[key];
|
||||
}
|
||||
});
|
||||
this.tokenizerCallsCount = 1;
|
||||
} catch (error) {
|
||||
logger.error('[Tokenizer] Free and reset encoders error', error);
|
||||
}
|
||||
}
|
||||
|
||||
resetTokenizersIfNecessary(): void {
|
||||
if (this.tokenizerCallsCount >= 25) {
|
||||
if (this.options?.debug) {
|
||||
logger.debug('[Tokenizer] freeAndResetAllEncoders: reached 25 encodings, resetting...');
|
||||
}
|
||||
this.freeAndResetAllEncoders();
|
||||
}
|
||||
this.tokenizerCallsCount++;
|
||||
}
|
||||
|
||||
getTokenCount(text: string, encoding: TiktokenModel | TiktokenEncoding = 'cl100k_base'): number {
|
||||
this.resetTokenizersIfNecessary();
|
||||
try {
|
||||
const tokenizer = this.getTokenizer(encoding);
|
||||
return tokenizer.encode(text, 'all').length;
|
||||
} catch (error) {
|
||||
logger.error('[Tokenizer] Error getting token count:', error);
|
||||
this.freeAndResetAllEncoders();
|
||||
const tokenizer = this.getTokenizer(encoding);
|
||||
return tokenizer.encode(text, 'all').length;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
const TokenizerSingleton = new Tokenizer();
|
||||
|
||||
export default TokenizerSingleton;
|
||||
|
|
@ -18,7 +18,10 @@
|
|||
"isolatedModules": true,
|
||||
"noEmit": true,
|
||||
"sourceMap": true,
|
||||
"baseUrl": "."
|
||||
"baseUrl": ".",
|
||||
"paths": {
|
||||
"~/*": ["./src/*"]
|
||||
}
|
||||
},
|
||||
"ts-node": {
|
||||
"experimentalSpecifierResolution": "node",
|
||||
|
|
@ -1,6 +1,6 @@
|
|||
{
|
||||
"name": "librechat-data-provider",
|
||||
"version": "0.7.86",
|
||||
"version": "0.7.88",
|
||||
"description": "data services for librechat apps",
|
||||
"main": "dist/index.js",
|
||||
"module": "dist/index.es.js",
|
||||
|
|
|
|||
|
|
@ -1,6 +1,8 @@
|
|||
import axios from 'axios';
|
||||
import { z } from 'zod';
|
||||
import { OpenAPIV3 } from 'openapi-types';
|
||||
import axios from 'axios';
|
||||
import type { OpenAPIV3 } from 'openapi-types';
|
||||
import type { ParametersSchema } from '../src/actions';
|
||||
import type { FlowchartSchema } from './openapiSpecs';
|
||||
import {
|
||||
createURL,
|
||||
resolveRef,
|
||||
|
|
@ -15,9 +17,7 @@ import {
|
|||
scholarAIOpenapiSpec,
|
||||
swapidev,
|
||||
} from './openapiSpecs';
|
||||
import { AuthorizationTypeEnum, AuthTypeEnum } from '../src/types/assistants';
|
||||
import type { FlowchartSchema } from './openapiSpecs';
|
||||
import type { ParametersSchema } from '../src/actions';
|
||||
import { AuthorizationTypeEnum, AuthTypeEnum } from '../src/types/agents';
|
||||
|
||||
jest.mock('axios');
|
||||
const mockedAxios = axios as jest.Mocked<typeof axios>;
|
||||
|
|
|
|||
|
|
@ -1,4 +1,28 @@
|
|||
import { StdioOptionsSchema, StreamableHTTPOptionsSchema, processMCPEnv, MCPOptions } from '../src/mcp';
|
||||
import type { TUser } from 'librechat-data-provider';
|
||||
import {
|
||||
StreamableHTTPOptionsSchema,
|
||||
StdioOptionsSchema,
|
||||
processMCPEnv,
|
||||
MCPOptions,
|
||||
} from '../src/mcp';
|
||||
|
||||
// Helper function to create test user objects
|
||||
function createTestUser(
|
||||
overrides: Partial<TUser> & Record<string, unknown> = {},
|
||||
): TUser & Record<string, unknown> {
|
||||
return {
|
||||
id: 'test-user-id',
|
||||
username: 'testuser',
|
||||
email: 'test@example.com',
|
||||
name: 'Test User',
|
||||
avatar: 'https://example.com/avatar.png',
|
||||
provider: 'email',
|
||||
role: 'user',
|
||||
createdAt: new Date('2021-01-01').toISOString(),
|
||||
updatedAt: new Date('2021-01-01').toISOString(),
|
||||
...overrides,
|
||||
};
|
||||
}
|
||||
|
||||
describe('Environment Variable Extraction (MCP)', () => {
|
||||
const originalEnv = process.env;
|
||||
|
|
@ -91,13 +115,13 @@ describe('Environment Variable Extraction (MCP)', () => {
|
|||
|
||||
// Type is now required, so parsing should fail
|
||||
expect(() => StreamableHTTPOptionsSchema.parse(options)).toThrow();
|
||||
|
||||
|
||||
// With type provided, it should pass
|
||||
const validOptions = {
|
||||
type: 'streamable-http' as const,
|
||||
url: 'https://example.com/api',
|
||||
};
|
||||
|
||||
|
||||
const result = StreamableHTTPOptionsSchema.parse(validOptions);
|
||||
expect(result.type).toBe('streamable-http');
|
||||
});
|
||||
|
|
@ -113,7 +137,7 @@ describe('Environment Variable Extraction (MCP)', () => {
|
|||
};
|
||||
|
||||
const result = StreamableHTTPOptionsSchema.parse(options);
|
||||
|
||||
|
||||
expect(result.headers).toEqual(options.headers);
|
||||
});
|
||||
});
|
||||
|
|
@ -165,7 +189,7 @@ describe('Environment Variable Extraction (MCP)', () => {
|
|||
});
|
||||
|
||||
it('should process user ID in headers field', () => {
|
||||
const userId = 'test-user-123';
|
||||
const user = createTestUser({ id: 'test-user-123' });
|
||||
const obj: MCPOptions = {
|
||||
type: 'sse',
|
||||
url: 'https://example.com',
|
||||
|
|
@ -176,7 +200,7 @@ describe('Environment Variable Extraction (MCP)', () => {
|
|||
},
|
||||
};
|
||||
|
||||
const result = processMCPEnv(obj, userId);
|
||||
const result = processMCPEnv(obj, user);
|
||||
|
||||
expect('headers' in result && result.headers).toEqual({
|
||||
Authorization: 'test-api-key-value',
|
||||
|
|
@ -217,15 +241,15 @@ describe('Environment Variable Extraction (MCP)', () => {
|
|||
};
|
||||
|
||||
// Process for two different users
|
||||
const user1Id = 'user-123';
|
||||
const user2Id = 'user-456';
|
||||
const user1 = createTestUser({ id: 'user-123' });
|
||||
const user2 = createTestUser({ id: 'user-456' });
|
||||
|
||||
const resultUser1 = processMCPEnv(baseConfig, user1Id);
|
||||
const resultUser2 = processMCPEnv(baseConfig, user2Id);
|
||||
const resultUser1 = processMCPEnv(baseConfig, user1);
|
||||
const resultUser2 = processMCPEnv(baseConfig, user2);
|
||||
|
||||
// Verify each has the correct user ID
|
||||
expect('headers' in resultUser1 && resultUser1.headers?.['User-Id']).toBe(user1Id);
|
||||
expect('headers' in resultUser2 && resultUser2.headers?.['User-Id']).toBe(user2Id);
|
||||
expect('headers' in resultUser1 && resultUser1.headers?.['User-Id']).toBe('user-123');
|
||||
expect('headers' in resultUser2 && resultUser2.headers?.['User-Id']).toBe('user-456');
|
||||
|
||||
// Verify they're different objects
|
||||
expect(resultUser1).not.toBe(resultUser2);
|
||||
|
|
@ -239,11 +263,11 @@ describe('Environment Variable Extraction (MCP)', () => {
|
|||
expect(baseConfig.headers?.['User-Id']).toBe('{{LIBRECHAT_USER_ID}}');
|
||||
|
||||
// Second user's config should be unchanged
|
||||
expect('headers' in resultUser2 && resultUser2.headers?.['User-Id']).toBe(user2Id);
|
||||
expect('headers' in resultUser2 && resultUser2.headers?.['User-Id']).toBe('user-456');
|
||||
});
|
||||
|
||||
it('should process headers in streamable-http options', () => {
|
||||
const userId = 'test-user-123';
|
||||
const user = createTestUser({ id: 'test-user-123' });
|
||||
const obj: MCPOptions = {
|
||||
type: 'streamable-http',
|
||||
url: 'https://example.com',
|
||||
|
|
@ -254,7 +278,7 @@ describe('Environment Variable Extraction (MCP)', () => {
|
|||
},
|
||||
};
|
||||
|
||||
const result = processMCPEnv(obj, userId);
|
||||
const result = processMCPEnv(obj, user);
|
||||
|
||||
expect('headers' in result && result.headers).toEqual({
|
||||
Authorization: 'test-api-key-value',
|
||||
|
|
@ -262,7 +286,7 @@ describe('Environment Variable Extraction (MCP)', () => {
|
|||
'Content-Type': 'application/json',
|
||||
});
|
||||
});
|
||||
|
||||
|
||||
it('should maintain streamable-http type in processed options', () => {
|
||||
const obj: MCPOptions = {
|
||||
type: 'streamable-http',
|
||||
|
|
@ -273,5 +297,416 @@ describe('Environment Variable Extraction (MCP)', () => {
|
|||
|
||||
expect(result.type).toBe('streamable-http');
|
||||
});
|
||||
|
||||
it('should process dynamic user fields in headers', () => {
|
||||
const user = createTestUser({
|
||||
id: 'user-123',
|
||||
email: 'test@example.com',
|
||||
username: 'testuser',
|
||||
openidId: 'openid-123',
|
||||
googleId: 'google-456',
|
||||
emailVerified: true,
|
||||
role: 'admin',
|
||||
});
|
||||
const obj: MCPOptions = {
|
||||
type: 'sse',
|
||||
url: 'https://example.com',
|
||||
headers: {
|
||||
'User-Email': '{{LIBRECHAT_USER_EMAIL}}',
|
||||
'User-Name': '{{LIBRECHAT_USER_USERNAME}}',
|
||||
OpenID: '{{LIBRECHAT_USER_OPENIDID}}',
|
||||
'Google-ID': '{{LIBRECHAT_USER_GOOGLEID}}',
|
||||
'Email-Verified': '{{LIBRECHAT_USER_EMAILVERIFIED}}',
|
||||
'User-Role': '{{LIBRECHAT_USER_ROLE}}',
|
||||
'Content-Type': 'application/json',
|
||||
},
|
||||
};
|
||||
|
||||
const result = processMCPEnv(obj, user);
|
||||
|
||||
expect('headers' in result && result.headers).toEqual({
|
||||
'User-Email': 'test@example.com',
|
||||
'User-Name': 'testuser',
|
||||
OpenID: 'openid-123',
|
||||
'Google-ID': 'google-456',
|
||||
'Email-Verified': 'true',
|
||||
'User-Role': 'admin',
|
||||
'Content-Type': 'application/json',
|
||||
});
|
||||
});
|
||||
|
||||
it('should handle missing user fields gracefully', () => {
|
||||
const user = createTestUser({
|
||||
id: 'user-123',
|
||||
email: 'test@example.com',
|
||||
username: undefined, // explicitly set to undefined to test missing field
|
||||
});
|
||||
const obj: MCPOptions = {
|
||||
type: 'sse',
|
||||
url: 'https://example.com',
|
||||
headers: {
|
||||
'User-Email': '{{LIBRECHAT_USER_EMAIL}}',
|
||||
'User-Name': '{{LIBRECHAT_USER_USERNAME}}',
|
||||
'Content-Type': 'application/json',
|
||||
},
|
||||
};
|
||||
|
||||
const result = processMCPEnv(obj, user);
|
||||
|
||||
expect('headers' in result && result.headers).toEqual({
|
||||
'User-Email': 'test@example.com',
|
||||
'User-Name': '', // Empty string for missing field
|
||||
'Content-Type': 'application/json',
|
||||
});
|
||||
});
|
||||
|
||||
it('should process user fields in env variables', () => {
|
||||
const user = createTestUser({
|
||||
id: 'user-123',
|
||||
email: 'test@example.com',
|
||||
ldapId: 'ldap-user-123',
|
||||
});
|
||||
const obj: MCPOptions = {
|
||||
command: 'node',
|
||||
args: ['server.js'],
|
||||
env: {
|
||||
USER_EMAIL: '{{LIBRECHAT_USER_EMAIL}}',
|
||||
LDAP_ID: '{{LIBRECHAT_USER_LDAPID}}',
|
||||
API_KEY: '${TEST_API_KEY}',
|
||||
},
|
||||
};
|
||||
|
||||
const result = processMCPEnv(obj, user);
|
||||
|
||||
expect('env' in result && result.env).toEqual({
|
||||
USER_EMAIL: 'test@example.com',
|
||||
LDAP_ID: 'ldap-user-123',
|
||||
API_KEY: 'test-api-key-value',
|
||||
});
|
||||
});
|
||||
|
||||
it('should process user fields in URL', () => {
|
||||
const user = createTestUser({
|
||||
id: 'user-123',
|
||||
username: 'testuser',
|
||||
});
|
||||
const obj: MCPOptions = {
|
||||
type: 'sse',
|
||||
url: 'https://example.com/api/{{LIBRECHAT_USER_USERNAME}}/stream',
|
||||
};
|
||||
|
||||
const result = processMCPEnv(obj, user);
|
||||
|
||||
expect('url' in result && result.url).toBe('https://example.com/api/testuser/stream');
|
||||
});
|
||||
|
||||
it('should handle boolean user fields', () => {
|
||||
const user = createTestUser({
|
||||
id: 'user-123',
|
||||
emailVerified: true,
|
||||
twoFactorEnabled: false,
|
||||
termsAccepted: true,
|
||||
});
|
||||
const obj: MCPOptions = {
|
||||
type: 'sse',
|
||||
url: 'https://example.com',
|
||||
headers: {
|
||||
'Email-Verified': '{{LIBRECHAT_USER_EMAILVERIFIED}}',
|
||||
'Two-Factor': '{{LIBRECHAT_USER_TWOFACTORENABLED}}',
|
||||
'Terms-Accepted': '{{LIBRECHAT_USER_TERMSACCEPTED}}',
|
||||
},
|
||||
};
|
||||
|
||||
const result = processMCPEnv(obj, user);
|
||||
|
||||
expect('headers' in result && result.headers).toEqual({
|
||||
'Email-Verified': 'true',
|
||||
'Two-Factor': 'false',
|
||||
'Terms-Accepted': 'true',
|
||||
});
|
||||
});
|
||||
|
||||
it('should not process sensitive fields like password', () => {
|
||||
const user = createTestUser({
|
||||
id: 'user-123',
|
||||
email: 'test@example.com',
|
||||
password: 'secret-password',
|
||||
});
|
||||
const obj: MCPOptions = {
|
||||
type: 'sse',
|
||||
url: 'https://example.com',
|
||||
headers: {
|
||||
'User-Email': '{{LIBRECHAT_USER_EMAIL}}',
|
||||
'User-Password': '{{LIBRECHAT_USER_PASSWORD}}', // This should not be processed
|
||||
},
|
||||
};
|
||||
|
||||
const result = processMCPEnv(obj, user);
|
||||
|
||||
expect('headers' in result && result.headers).toEqual({
|
||||
'User-Email': 'test@example.com',
|
||||
'User-Password': '{{LIBRECHAT_USER_PASSWORD}}', // Unchanged
|
||||
});
|
||||
});
|
||||
|
||||
it('should handle multiple occurrences of the same placeholder', () => {
|
||||
const user = createTestUser({
|
||||
id: 'user-123',
|
||||
email: 'test@example.com',
|
||||
});
|
||||
const obj: MCPOptions = {
|
||||
type: 'sse',
|
||||
url: 'https://example.com',
|
||||
headers: {
|
||||
'Primary-Email': '{{LIBRECHAT_USER_EMAIL}}',
|
||||
'Secondary-Email': '{{LIBRECHAT_USER_EMAIL}}',
|
||||
'Backup-Email': '{{LIBRECHAT_USER_EMAIL}}',
|
||||
},
|
||||
};
|
||||
|
||||
const result = processMCPEnv(obj, user);
|
||||
|
||||
expect('headers' in result && result.headers).toEqual({
|
||||
'Primary-Email': 'test@example.com',
|
||||
'Secondary-Email': 'test@example.com',
|
||||
'Backup-Email': 'test@example.com',
|
||||
});
|
||||
});
|
||||
|
||||
it('should support both id and _id properties for LIBRECHAT_USER_ID', () => {
|
||||
// Test with 'id' property
|
||||
const userWithId = createTestUser({
|
||||
id: 'user-123',
|
||||
email: 'test@example.com',
|
||||
});
|
||||
const obj1: MCPOptions = {
|
||||
type: 'sse',
|
||||
url: 'https://example.com',
|
||||
headers: {
|
||||
'User-Id': '{{LIBRECHAT_USER_ID}}',
|
||||
},
|
||||
};
|
||||
|
||||
const result1 = processMCPEnv(obj1, userWithId);
|
||||
expect('headers' in result1 && result1.headers?.['User-Id']).toBe('user-123');
|
||||
|
||||
// Test with '_id' property only (should not work since we only check 'id')
|
||||
const userWithUnderscore = createTestUser({
|
||||
id: undefined, // Remove default id to test _id
|
||||
_id: 'user-456',
|
||||
email: 'test@example.com',
|
||||
});
|
||||
const obj2: MCPOptions = {
|
||||
type: 'sse',
|
||||
url: 'https://example.com',
|
||||
headers: {
|
||||
'User-Id': '{{LIBRECHAT_USER_ID}}',
|
||||
},
|
||||
};
|
||||
|
||||
const result2 = processMCPEnv(obj2, userWithUnderscore);
|
||||
// Since we don't check _id, the placeholder should remain unchanged
|
||||
expect('headers' in result2 && result2.headers?.['User-Id']).toBe('{{LIBRECHAT_USER_ID}}');
|
||||
|
||||
// Test with both properties (id takes precedence)
|
||||
const userWithBoth = createTestUser({
|
||||
id: 'user-789',
|
||||
_id: 'user-000',
|
||||
email: 'test@example.com',
|
||||
});
|
||||
const obj3: MCPOptions = {
|
||||
type: 'sse',
|
||||
url: 'https://example.com',
|
||||
headers: {
|
||||
'User-Id': '{{LIBRECHAT_USER_ID}}',
|
||||
},
|
||||
};
|
||||
|
||||
const result3 = processMCPEnv(obj3, userWithBoth);
|
||||
expect('headers' in result3 && result3.headers?.['User-Id']).toBe('user-789');
|
||||
});
|
||||
|
||||
it('should process customUserVars in env field', () => {
|
||||
const user = createTestUser();
|
||||
const customUserVars = {
|
||||
CUSTOM_VAR_1: 'custom-value-1',
|
||||
CUSTOM_VAR_2: 'custom-value-2',
|
||||
};
|
||||
const obj: MCPOptions = {
|
||||
command: 'node',
|
||||
args: ['server.js'],
|
||||
env: {
|
||||
VAR_A: '{{CUSTOM_VAR_1}}',
|
||||
VAR_B: 'Value with {{CUSTOM_VAR_2}}',
|
||||
VAR_C: '${TEST_API_KEY}',
|
||||
VAR_D: '{{LIBRECHAT_USER_EMAIL}}',
|
||||
},
|
||||
};
|
||||
|
||||
const result = processMCPEnv(obj, user, customUserVars);
|
||||
|
||||
expect('env' in result && result.env).toEqual({
|
||||
VAR_A: 'custom-value-1',
|
||||
VAR_B: 'Value with custom-value-2',
|
||||
VAR_C: 'test-api-key-value',
|
||||
VAR_D: 'test@example.com',
|
||||
});
|
||||
});
|
||||
|
||||
it('should process customUserVars in headers field', () => {
|
||||
const user = createTestUser();
|
||||
const customUserVars = {
|
||||
USER_TOKEN: 'user-specific-token',
|
||||
REGION: 'us-west-1',
|
||||
};
|
||||
const obj: MCPOptions = {
|
||||
type: 'sse',
|
||||
url: 'https://example.com/api',
|
||||
headers: {
|
||||
Authorization: 'Bearer {{USER_TOKEN}}',
|
||||
'X-Region': '{{REGION}}',
|
||||
'X-System-Key': '${TEST_API_KEY}',
|
||||
'X-User-Id': '{{LIBRECHAT_USER_ID}}',
|
||||
},
|
||||
};
|
||||
|
||||
const result = processMCPEnv(obj, user, customUserVars);
|
||||
|
||||
expect('headers' in result && result.headers).toEqual({
|
||||
Authorization: 'Bearer user-specific-token',
|
||||
'X-Region': 'us-west-1',
|
||||
'X-System-Key': 'test-api-key-value',
|
||||
'X-User-Id': 'test-user-id',
|
||||
});
|
||||
});
|
||||
|
||||
it('should process customUserVars in URL field', () => {
|
||||
const user = createTestUser();
|
||||
const customUserVars = {
|
||||
API_VERSION: 'v2',
|
||||
TENANT_ID: 'tenant123',
|
||||
};
|
||||
const obj: MCPOptions = {
|
||||
type: 'websocket',
|
||||
url: 'wss://example.com/{{TENANT_ID}}/api/{{API_VERSION}}?user={{LIBRECHAT_USER_ID}}&key=${TEST_API_KEY}',
|
||||
};
|
||||
|
||||
const result = processMCPEnv(obj, user, customUserVars);
|
||||
|
||||
expect('url' in result && result.url).toBe(
|
||||
'wss://example.com/tenant123/api/v2?user=test-user-id&key=test-api-key-value',
|
||||
);
|
||||
});
|
||||
|
||||
it('should prioritize customUserVars over user fields and system env vars if placeholders are the same (though not recommended)', () => {
|
||||
// This tests the order of operations: customUserVars -> userFields -> systemEnv
|
||||
// BUt it's generally not recommended to have overlapping placeholder names.
|
||||
process.env.LIBRECHAT_USER_EMAIL = 'system-email-should-be-overridden';
|
||||
const user = createTestUser({ email: 'user-email-should-be-overridden' });
|
||||
const customUserVars = {
|
||||
LIBRECHAT_USER_EMAIL: 'custom-email-wins',
|
||||
};
|
||||
const obj: MCPOptions = {
|
||||
type: 'sse',
|
||||
url: 'https://example.com/api',
|
||||
headers: {
|
||||
'Test-Email': '{{LIBRECHAT_USER_EMAIL}}', // Placeholder that could match custom, user, or system
|
||||
},
|
||||
};
|
||||
|
||||
const result = processMCPEnv(obj, user, customUserVars);
|
||||
expect('headers' in result && result.headers?.['Test-Email']).toBe('custom-email-wins');
|
||||
|
||||
// Clean up env var
|
||||
delete process.env.LIBRECHAT_USER_EMAIL;
|
||||
});
|
||||
|
||||
it('should handle customUserVars with no matching placeholders', () => {
|
||||
const user = createTestUser();
|
||||
const customUserVars = {
|
||||
UNUSED_VAR: 'unused-value',
|
||||
};
|
||||
const obj: MCPOptions = {
|
||||
command: 'node',
|
||||
args: ['server.js'],
|
||||
env: {
|
||||
API_KEY: '${TEST_API_KEY}',
|
||||
},
|
||||
};
|
||||
|
||||
const result = processMCPEnv(obj, user, customUserVars);
|
||||
expect('env' in result && result.env).toEqual({
|
||||
API_KEY: 'test-api-key-value',
|
||||
});
|
||||
});
|
||||
|
||||
it('should handle placeholders with no matching customUserVars (falling back to user/system vars)', () => {
|
||||
const user = createTestUser({ email: 'user-provided-email@example.com' });
|
||||
// No customUserVars provided or customUserVars is empty
|
||||
const customUserVars = {};
|
||||
const obj: MCPOptions = {
|
||||
type: 'sse',
|
||||
url: 'https://example.com/api',
|
||||
headers: {
|
||||
'User-Email-Header': '{{LIBRECHAT_USER_EMAIL}}', // Should use user.email
|
||||
'System-Key-Header': '${TEST_API_KEY}', // Should use process.env.TEST_API_KEY
|
||||
'Non-Existent-Custom': '{{NON_EXISTENT_CUSTOM_VAR}}', // Should remain as placeholder
|
||||
},
|
||||
};
|
||||
|
||||
const result = processMCPEnv(obj, user, customUserVars);
|
||||
expect('headers' in result && result.headers).toEqual({
|
||||
'User-Email-Header': 'user-provided-email@example.com',
|
||||
'System-Key-Header': 'test-api-key-value',
|
||||
'Non-Existent-Custom': '{{NON_EXISTENT_CUSTOM_VAR}}',
|
||||
});
|
||||
});
|
||||
|
||||
it('should correctly process a mix of all variable types', () => {
|
||||
const user = createTestUser({ id: 'userXYZ', username: 'john.doe' });
|
||||
const customUserVars = {
|
||||
CUSTOM_ENDPOINT_ID: 'ep123',
|
||||
ANOTHER_CUSTOM: 'another_val',
|
||||
};
|
||||
|
||||
const obj = {
|
||||
type: 'streamable-http' as const,
|
||||
url: 'https://{{CUSTOM_ENDPOINT_ID}}.example.com/users/{{LIBRECHAT_USER_USERNAME}}',
|
||||
headers: {
|
||||
'X-Auth-Token': '{{CUSTOM_TOKEN_FROM_USER_SETTINGS}}', // Assuming this would be a custom var
|
||||
'X-User-ID': '{{LIBRECHAT_USER_ID}}',
|
||||
'X-System-Test-Key': '${TEST_API_KEY}', // Using existing env var from beforeEach
|
||||
},
|
||||
env: {
|
||||
PROCESS_MODE: '{{PROCESS_MODE_CUSTOM}}', // Another custom var
|
||||
USER_HOME_DIR: '/home/{{LIBRECHAT_USER_USERNAME}}',
|
||||
SYSTEM_PATH: '${PATH}', // Example of a system env var
|
||||
},
|
||||
};
|
||||
|
||||
// Simulate customUserVars that would be passed, including those for headers and env
|
||||
const allCustomVarsForCall = {
|
||||
...customUserVars,
|
||||
CUSTOM_TOKEN_FROM_USER_SETTINGS: 'secretToken123!',
|
||||
PROCESS_MODE_CUSTOM: 'production',
|
||||
};
|
||||
|
||||
// Cast obj to MCPOptions when calling processMCPEnv.
|
||||
// This acknowledges the object might not strictly conform to one schema in the union,
|
||||
// but we are testing the function's ability to handle these properties if present.
|
||||
const result = processMCPEnv(obj as MCPOptions, user, allCustomVarsForCall);
|
||||
|
||||
expect('url' in result && result.url).toBe('https://ep123.example.com/users/john.doe');
|
||||
expect('headers' in result && result.headers).toEqual({
|
||||
'X-Auth-Token': 'secretToken123!',
|
||||
'X-User-ID': 'userXYZ',
|
||||
'X-System-Test-Key': 'test-api-key-value', // Expecting value of TEST_API_KEY
|
||||
});
|
||||
expect('env' in result && result.env).toEqual({
|
||||
PROCESS_MODE: 'production',
|
||||
USER_HOME_DIR: '/home/john.doe',
|
||||
SYSTEM_PATH: process.env.PATH, // Actual value of PATH from the test environment
|
||||
});
|
||||
});
|
||||
});
|
||||
});
|
||||
|
|
|
|||
|
|
@ -3,15 +3,11 @@ import _axios from 'axios';
|
|||
import { URL } from 'url';
|
||||
import crypto from 'crypto';
|
||||
import { load } from 'js-yaml';
|
||||
import type {
|
||||
FunctionTool,
|
||||
Schema,
|
||||
Reference,
|
||||
ActionMetadata,
|
||||
ActionMetadataRuntime,
|
||||
} from './types/assistants';
|
||||
import type { ActionMetadata, ActionMetadataRuntime } from './types/agents';
|
||||
import type { FunctionTool, Schema, Reference } from './types/assistants';
|
||||
import { AuthTypeEnum, AuthorizationTypeEnum } from './types/agents';
|
||||
import type { OpenAPIV3 } from 'openapi-types';
|
||||
import { Tools, AuthTypeEnum, AuthorizationTypeEnum } from './types/assistants';
|
||||
import { Tools } from './types/assistants';
|
||||
|
||||
export type ParametersSchema = {
|
||||
type: string;
|
||||
|
|
|
|||
|
|
@ -261,6 +261,7 @@ export const getUserPromptPreferences = () => `${prompts()}/preferences`;
|
|||
export const roles = () => '/api/roles';
|
||||
export const getRole = (roleName: string) => `${roles()}/${roleName.toLowerCase()}`;
|
||||
export const updatePromptPermissions = (roleName: string) => `${getRole(roleName)}/prompts`;
|
||||
export const updateMemoryPermissions = (roleName: string) => `${getRole(roleName)}/memories`;
|
||||
export const updateAgentPermissions = (roleName: string) => `${getRole(roleName)}/agents`;
|
||||
|
||||
/* Conversation Tags */
|
||||
|
|
@ -290,3 +291,8 @@ export const confirmTwoFactor = () => '/api/auth/2fa/confirm';
|
|||
export const disableTwoFactor = () => '/api/auth/2fa/disable';
|
||||
export const regenerateBackupCodes = () => '/api/auth/2fa/backup/regenerate';
|
||||
export const verifyTwoFactorTemp = () => '/api/auth/2fa/verify-temp';
|
||||
|
||||
/* Memories */
|
||||
export const memories = () => '/api/memories';
|
||||
export const memory = (key: string) => `${memories()}/${encodeURIComponent(key)}`;
|
||||
export const memoryPreferences = () => `${memories()}/preferences`;
|
||||
|
|
|
|||
|
|
@ -244,21 +244,26 @@ export const defaultAgentCapabilities = [
|
|||
AgentCapabilities.ocr,
|
||||
];
|
||||
|
||||
export const agentsEndpointSChema = baseEndpointSchema.merge(
|
||||
z.object({
|
||||
/* agents specific */
|
||||
recursionLimit: z.number().optional(),
|
||||
disableBuilder: z.boolean().optional(),
|
||||
maxRecursionLimit: z.number().optional(),
|
||||
allowedProviders: z.array(z.union([z.string(), eModelEndpointSchema])).optional(),
|
||||
capabilities: z
|
||||
.array(z.nativeEnum(AgentCapabilities))
|
||||
.optional()
|
||||
.default(defaultAgentCapabilities),
|
||||
}),
|
||||
);
|
||||
export const agentsEndpointSchema = baseEndpointSchema
|
||||
.merge(
|
||||
z.object({
|
||||
/* agents specific */
|
||||
recursionLimit: z.number().optional(),
|
||||
disableBuilder: z.boolean().optional().default(false),
|
||||
maxRecursionLimit: z.number().optional(),
|
||||
allowedProviders: z.array(z.union([z.string(), eModelEndpointSchema])).optional(),
|
||||
capabilities: z
|
||||
.array(z.nativeEnum(AgentCapabilities))
|
||||
.optional()
|
||||
.default(defaultAgentCapabilities),
|
||||
}),
|
||||
)
|
||||
.default({
|
||||
disableBuilder: false,
|
||||
capabilities: defaultAgentCapabilities,
|
||||
});
|
||||
|
||||
export type TAgentsEndpoint = z.infer<typeof agentsEndpointSChema>;
|
||||
export type TAgentsEndpoint = z.infer<typeof agentsEndpointSchema>;
|
||||
|
||||
export const endpointSchema = baseEndpointSchema.merge(
|
||||
z.object({
|
||||
|
|
@ -493,6 +498,7 @@ export const intefaceSchema = z
|
|||
sidePanel: z.boolean().optional(),
|
||||
multiConvo: z.boolean().optional(),
|
||||
bookmarks: z.boolean().optional(),
|
||||
memories: z.boolean().optional(),
|
||||
presets: z.boolean().optional(),
|
||||
prompts: z.boolean().optional(),
|
||||
agents: z.boolean().optional(),
|
||||
|
|
@ -508,6 +514,7 @@ export const intefaceSchema = z
|
|||
presets: true,
|
||||
multiConvo: true,
|
||||
bookmarks: true,
|
||||
memories: true,
|
||||
prompts: true,
|
||||
agents: true,
|
||||
temporaryChat: true,
|
||||
|
|
@ -581,11 +588,24 @@ export type TStartupConfig = {
|
|||
scraperType?: ScraperTypes;
|
||||
rerankerType?: RerankerTypes;
|
||||
};
|
||||
mcpServers?: Record<
|
||||
string,
|
||||
{
|
||||
customUserVars: Record<
|
||||
string,
|
||||
{
|
||||
title: string;
|
||||
description: string;
|
||||
}
|
||||
>;
|
||||
}
|
||||
>;
|
||||
};
|
||||
|
||||
export enum OCRStrategy {
|
||||
MISTRAL_OCR = 'mistral_ocr',
|
||||
CUSTOM_OCR = 'custom_ocr',
|
||||
AZURE_MISTRAL_OCR = 'azure_mistral_ocr',
|
||||
}
|
||||
|
||||
export enum SearchCategories {
|
||||
|
|
@ -649,11 +669,35 @@ export const balanceSchema = z.object({
|
|||
refillAmount: z.number().optional().default(10000),
|
||||
});
|
||||
|
||||
export const memorySchema = z.object({
|
||||
disabled: z.boolean().optional(),
|
||||
validKeys: z.array(z.string()).optional(),
|
||||
tokenLimit: z.number().optional(),
|
||||
personalize: z.boolean().default(true),
|
||||
messageWindowSize: z.number().optional().default(5),
|
||||
agent: z
|
||||
.union([
|
||||
z.object({
|
||||
id: z.string(),
|
||||
}),
|
||||
z.object({
|
||||
provider: z.string(),
|
||||
model: z.string(),
|
||||
instructions: z.string().optional(),
|
||||
model_parameters: z.record(z.any()).optional(),
|
||||
}),
|
||||
])
|
||||
.optional(),
|
||||
});
|
||||
|
||||
export type TMemoryConfig = z.infer<typeof memorySchema>;
|
||||
|
||||
export const configSchema = z.object({
|
||||
version: z.string(),
|
||||
cache: z.boolean().default(true),
|
||||
ocr: ocrSchema.optional(),
|
||||
webSearch: webSearchSchema.optional(),
|
||||
memory: memorySchema.optional(),
|
||||
secureImageLinks: z.boolean().optional(),
|
||||
imageOutputType: z.nativeEnum(EImageOutputType).default(EImageOutputType.PNG),
|
||||
includedTools: z.array(z.string()).optional(),
|
||||
|
|
@ -694,7 +738,7 @@ export const configSchema = z.object({
|
|||
[EModelEndpoint.azureOpenAI]: azureEndpointSchema.optional(),
|
||||
[EModelEndpoint.azureAssistants]: assistantEndpointSchema.optional(),
|
||||
[EModelEndpoint.assistants]: assistantEndpointSchema.optional(),
|
||||
[EModelEndpoint.agents]: agentsEndpointSChema.optional(),
|
||||
[EModelEndpoint.agents]: agentsEndpointSchema.optional(),
|
||||
[EModelEndpoint.custom]: z.array(endpointSchema.partial()).optional(),
|
||||
[EModelEndpoint.bedrock]: baseEndpointSchema.optional(),
|
||||
})
|
||||
|
|
@ -853,7 +897,6 @@ export const defaultModels = {
|
|||
[EModelEndpoint.assistants]: [...sharedOpenAIModels, 'chatgpt-4o-latest'],
|
||||
[EModelEndpoint.agents]: sharedOpenAIModels, // TODO: Add agent models (agentsModels)
|
||||
[EModelEndpoint.google]: [
|
||||
// Shared Google Models between Vertex AI & Gen AI
|
||||
// Gemini 2.0 Models
|
||||
'gemini-2.0-flash-001',
|
||||
'gemini-2.0-flash-exp',
|
||||
|
|
@ -1104,6 +1147,10 @@ export enum CacheKeys {
|
|||
* Key for in-progress flow states.
|
||||
*/
|
||||
FLOWS = 'flows',
|
||||
/**
|
||||
* Key for individual MCP Tool Manifests.
|
||||
*/
|
||||
MCP_TOOLS = 'mcp_tools',
|
||||
/**
|
||||
* Key for pending chat requests (concurrency check)
|
||||
*/
|
||||
|
|
@ -1291,6 +1338,10 @@ export enum SettingsTabValues {
|
|||
* Chat input commands
|
||||
*/
|
||||
COMMANDS = 'commands',
|
||||
/**
|
||||
* Tab for Personalization Settings
|
||||
*/
|
||||
PERSONALIZATION = 'personalization',
|
||||
}
|
||||
|
||||
export enum STTProviders {
|
||||
|
|
@ -1328,7 +1379,7 @@ export enum Constants {
|
|||
/** Key for the app's version. */
|
||||
VERSION = 'v0.7.8',
|
||||
/** Key for the Custom Config's version (librechat.yaml). */
|
||||
CONFIG_VERSION = '1.2.6',
|
||||
CONFIG_VERSION = '1.2.8',
|
||||
/** Standard value for the first message's `parentMessageId` value, to indicate no parent exists. */
|
||||
NO_PARENT = '00000000-0000-0000-0000-000000000000',
|
||||
/** Standard value for the initial conversationId before a request is sent */
|
||||
|
|
@ -1355,6 +1406,8 @@ export enum Constants {
|
|||
GLOBAL_PROJECT_NAME = 'instance',
|
||||
/** Delimiter for MCP tools */
|
||||
mcp_delimiter = '_mcp_',
|
||||
/** Prefix for MCP plugins */
|
||||
mcp_prefix = 'mcp_',
|
||||
/** Placeholder Agent ID for Ephemeral Agents */
|
||||
EPHEMERAL_AGENT_ID = 'ephemeral',
|
||||
}
|
||||
|
|
@ -1398,6 +1451,10 @@ export enum LocalStorageKeys {
|
|||
LAST_CODE_TOGGLE_ = 'LAST_CODE_TOGGLE_',
|
||||
/** Last checked toggle for Web Search per conversation ID */
|
||||
LAST_WEB_SEARCH_TOGGLE_ = 'LAST_WEB_SEARCH_TOGGLE_',
|
||||
/** Key for the last selected agent provider */
|
||||
LAST_AGENT_PROVIDER = 'lastAgentProvider',
|
||||
/** Key for the last selected agent model */
|
||||
LAST_AGENT_MODEL = 'lastAgentModel',
|
||||
}
|
||||
|
||||
export enum ForkOptions {
|
||||
|
|
|
|||
|
|
@ -13,11 +13,11 @@ export default function createPayload(submission: t.TSubmission) {
|
|||
ephemeralAgent,
|
||||
} = submission;
|
||||
const { conversationId } = s.tConvoUpdateSchema.parse(conversation);
|
||||
const { endpoint, endpointType } = endpointOption as {
|
||||
const { endpoint: _e, endpointType } = endpointOption as {
|
||||
endpoint: s.EModelEndpoint;
|
||||
endpointType?: s.EModelEndpoint;
|
||||
};
|
||||
|
||||
const endpoint = _e as s.EModelEndpoint;
|
||||
let server = EndpointURLs[endpointType ?? endpoint];
|
||||
const isEphemeral = s.isEphemeralAgent(endpoint, ephemeralAgent);
|
||||
|
||||
|
|
@ -32,6 +32,7 @@ export default function createPayload(submission: t.TSubmission) {
|
|||
const payload: t.TPayload = {
|
||||
...userMessage,
|
||||
...endpointOption,
|
||||
endpoint,
|
||||
ephemeralAgent: isEphemeral ? ephemeralAgent : undefined,
|
||||
isContinued: !!(isEdited && isContinued),
|
||||
conversationId,
|
||||
|
|
|
|||
|
|
@ -2,6 +2,7 @@ import type { AxiosResponse } from 'axios';
|
|||
import type * as t from './types';
|
||||
import * as endpoints from './api-endpoints';
|
||||
import * as a from './types/assistants';
|
||||
import * as ag from './types/agents';
|
||||
import * as m from './types/mutations';
|
||||
import * as q from './types/queries';
|
||||
import * as f from './types/files';
|
||||
|
|
@ -150,7 +151,11 @@ export const updateUserPlugins = (payload: t.TUpdateUserPlugins) => {
|
|||
|
||||
/* Config */
|
||||
|
||||
export const getStartupConfig = (): Promise<config.TStartupConfig> => {
|
||||
export const getStartupConfig = (): Promise<
|
||||
config.TStartupConfig & {
|
||||
mcpCustomUserVars?: Record<string, { title: string; description: string }>;
|
||||
}
|
||||
> => {
|
||||
return request.get(endpoints.config());
|
||||
};
|
||||
|
||||
|
|
@ -351,7 +356,7 @@ export const updateAction = (data: m.UpdateActionVariables): Promise<m.UpdateAct
|
|||
);
|
||||
};
|
||||
|
||||
export function getActions(): Promise<a.Action[]> {
|
||||
export function getActions(): Promise<ag.Action[]> {
|
||||
return request.get(
|
||||
endpoints.agents({
|
||||
path: 'actions',
|
||||
|
|
@ -407,7 +412,7 @@ export const updateAgent = ({
|
|||
|
||||
export const duplicateAgent = ({
|
||||
agent_id,
|
||||
}: m.DuplicateAgentBody): Promise<{ agent: a.Agent; actions: a.Action[] }> => {
|
||||
}: m.DuplicateAgentBody): Promise<{ agent: a.Agent; actions: ag.Action[] }> => {
|
||||
return request.post(
|
||||
endpoints.agents({
|
||||
path: `${agent_id}/duplicate`,
|
||||
|
|
@ -733,6 +738,12 @@ export function updateAgentPermissions(
|
|||
return request.put(endpoints.updateAgentPermissions(variables.roleName), variables.updates);
|
||||
}
|
||||
|
||||
export function updateMemoryPermissions(
|
||||
variables: m.UpdateMemoryPermVars,
|
||||
): Promise<m.UpdatePermResponse> {
|
||||
return request.put(endpoints.updateMemoryPermissions(variables.roleName), variables.updates);
|
||||
}
|
||||
|
||||
/* Tags */
|
||||
export function getConversationTags(): Promise<t.TConversationTagsResponse> {
|
||||
return request.get(endpoints.conversationTags());
|
||||
|
|
@ -814,3 +825,33 @@ export function verifyTwoFactorTemp(
|
|||
): Promise<t.TVerify2FATempResponse> {
|
||||
return request.post(endpoints.verifyTwoFactorTemp(), payload);
|
||||
}
|
||||
|
||||
/* Memories */
|
||||
export const getMemories = (): Promise<q.MemoriesResponse> => {
|
||||
return request.get(endpoints.memories());
|
||||
};
|
||||
|
||||
export const deleteMemory = (key: string): Promise<void> => {
|
||||
return request.delete(endpoints.memory(key));
|
||||
};
|
||||
|
||||
export const updateMemory = (
|
||||
key: string,
|
||||
value: string,
|
||||
originalKey?: string,
|
||||
): Promise<q.TUserMemory> => {
|
||||
return request.patch(endpoints.memory(originalKey || key), { key, value });
|
||||
};
|
||||
|
||||
export const updateMemoryPreferences = (preferences: {
|
||||
memories: boolean;
|
||||
}): Promise<{ updated: boolean; preferences: { memories: boolean } }> => {
|
||||
return request.patch(endpoints.memoryPreferences(), preferences);
|
||||
};
|
||||
|
||||
export const createMemory = (data: {
|
||||
key: string;
|
||||
value: string;
|
||||
}): Promise<{ created: boolean; memory: q.TUserMemory }> => {
|
||||
return request.post(endpoints.memories(), data);
|
||||
};
|
||||
|
|
|
|||
|
|
@ -1,6 +1,6 @@
|
|||
import { z } from 'zod';
|
||||
import { EModelEndpoint } from './schemas';
|
||||
import type { FileConfig, EndpointFileConfig } from './types/files';
|
||||
import type { EndpointFileConfig, FileConfig } from './types/files';
|
||||
|
||||
export const supportsFiles = {
|
||||
[EModelEndpoint.openAI]: true,
|
||||
|
|
@ -49,6 +49,8 @@ export const fullMimeTypesList = [
|
|||
'text/javascript',
|
||||
'image/gif',
|
||||
'image/png',
|
||||
'image/heic',
|
||||
'image/heif',
|
||||
'application/x-tar',
|
||||
'application/typescript',
|
||||
'application/xml',
|
||||
|
|
@ -80,6 +82,8 @@ export const codeInterpreterMimeTypesList = [
|
|||
'text/javascript',
|
||||
'image/gif',
|
||||
'image/png',
|
||||
'image/heic',
|
||||
'image/heif',
|
||||
'application/x-tar',
|
||||
'application/typescript',
|
||||
'application/xml',
|
||||
|
|
@ -105,18 +109,18 @@ export const retrievalMimeTypesList = [
|
|||
'text/plain',
|
||||
];
|
||||
|
||||
export const imageExtRegex = /\.(jpg|jpeg|png|gif|webp)$/i;
|
||||
export const imageExtRegex = /\.(jpg|jpeg|png|gif|webp|heic|heif)$/i;
|
||||
|
||||
export const excelMimeTypes =
|
||||
/^application\/(vnd\.ms-excel|msexcel|x-msexcel|x-ms-excel|x-excel|x-dos_ms_excel|xls|x-xls|vnd\.openxmlformats-officedocument\.spreadsheetml\.sheet)$/;
|
||||
|
||||
export const textMimeTypes =
|
||||
/^(text\/(x-c|x-csharp|tab-separated-values|x-c\+\+|x-java|html|markdown|x-php|x-python|x-script\.python|x-ruby|x-tex|plain|css|vtt|javascript|csv))$/;
|
||||
/^(text\/(x-c|x-csharp|tab-separated-values|x-c\+\+|x-h|x-java|html|markdown|x-php|x-python|x-script\.python|x-ruby|x-tex|plain|css|vtt|javascript|csv))$/;
|
||||
|
||||
export const applicationMimeTypes =
|
||||
/^(application\/(epub\+zip|csv|json|pdf|x-tar|typescript|vnd\.openxmlformats-officedocument\.(wordprocessingml\.document|presentationml\.presentation|spreadsheetml\.sheet)|xml|zip))$/;
|
||||
|
||||
export const imageMimeTypes = /^image\/(jpeg|gif|png|webp)$/;
|
||||
export const imageMimeTypes = /^image\/(jpeg|gif|png|webp|heic|heif)$/;
|
||||
|
||||
export const supportedMimeTypes = [
|
||||
textMimeTypes,
|
||||
|
|
@ -138,6 +142,7 @@ export const codeTypeMapping: { [key: string]: string } = {
|
|||
c: 'text/x-c',
|
||||
cs: 'text/x-csharp',
|
||||
cpp: 'text/x-c++',
|
||||
h: 'text/x-h',
|
||||
md: 'text/markdown',
|
||||
php: 'text/x-php',
|
||||
py: 'text/x-python',
|
||||
|
|
@ -155,7 +160,7 @@ export const codeTypeMapping: { [key: string]: string } = {
|
|||
};
|
||||
|
||||
export const retrievalMimeTypes = [
|
||||
/^(text\/(x-c|x-c\+\+|html|x-java|markdown|x-php|x-python|x-script\.python|x-ruby|x-tex|plain|vtt|xml))$/,
|
||||
/^(text\/(x-c|x-c\+\+|x-h|html|x-java|markdown|x-php|x-python|x-script\.python|x-ruby|x-tex|plain|vtt|xml))$/,
|
||||
/^(application\/(json|pdf|vnd\.openxmlformats-officedocument\.(wordprocessingml\.document|presentationml\.presentation)))$/,
|
||||
];
|
||||
|
||||
|
|
|
|||
|
|
@ -16,6 +16,8 @@ export * from './models';
|
|||
export * from './mcp';
|
||||
/* web search */
|
||||
export * from './web';
|
||||
/* memory */
|
||||
export * from './memory';
|
||||
/* RBAC */
|
||||
export * from './permissions';
|
||||
export * from './roles';
|
||||
|
|
|
|||
|
|
@ -47,6 +47,8 @@ export enum QueryKeys {
|
|||
health = 'health',
|
||||
userTerms = 'userTerms',
|
||||
banner = 'banner',
|
||||
/* Memories */
|
||||
memories = 'memories',
|
||||
}
|
||||
|
||||
export enum MutationKeys {
|
||||
|
|
@ -71,4 +73,5 @@ export enum MutationKeys {
|
|||
updateRole = 'updateRole',
|
||||
enableTwoFactor = 'enableTwoFactor',
|
||||
verifyTwoFactor = 'verifyTwoFactor',
|
||||
updateMemoryPreferences = 'updateMemoryPreferences',
|
||||
}
|
||||
|
|
|
|||
|
|
@ -1,5 +1,7 @@
|
|||
import { z } from 'zod';
|
||||
import type { TUser } from './types';
|
||||
import { extractEnvVariable } from './utils';
|
||||
import { TokenExchangeMethodEnum } from './types/agents';
|
||||
|
||||
const BaseOptionsSchema = z.object({
|
||||
iconPath: z.string().optional(),
|
||||
|
|
@ -7,6 +9,45 @@ const BaseOptionsSchema = z.object({
|
|||
initTimeout: z.number().optional(),
|
||||
/** Controls visibility in chat dropdown menu (MCPSelect) */
|
||||
chatMenu: z.boolean().optional(),
|
||||
/**
|
||||
* Controls server instruction behavior:
|
||||
* - undefined/not set: No instructions included (default)
|
||||
* - true: Use server-provided instructions
|
||||
* - string: Use custom instructions (overrides server-provided)
|
||||
*/
|
||||
serverInstructions: z.union([z.boolean(), z.string()]).optional(),
|
||||
/**
|
||||
* OAuth configuration for SSE and Streamable HTTP transports
|
||||
* - Optional: OAuth can be auto-discovered on 401 responses
|
||||
* - Pre-configured values will skip discovery steps
|
||||
*/
|
||||
oauth: z
|
||||
.object({
|
||||
/** OAuth authorization endpoint (optional - can be auto-discovered) */
|
||||
authorization_url: z.string().url().optional(),
|
||||
/** OAuth token endpoint (optional - can be auto-discovered) */
|
||||
token_url: z.string().url().optional(),
|
||||
/** OAuth client ID (optional - can use dynamic registration) */
|
||||
client_id: z.string().optional(),
|
||||
/** OAuth client secret (optional - can use dynamic registration) */
|
||||
client_secret: z.string().optional(),
|
||||
/** OAuth scopes to request */
|
||||
scope: z.string().optional(),
|
||||
/** OAuth redirect URI (defaults to /api/mcp/{serverName}/oauth/callback) */
|
||||
redirect_uri: z.string().url().optional(),
|
||||
/** Token exchange method */
|
||||
token_exchange_method: z.nativeEnum(TokenExchangeMethodEnum).optional(),
|
||||
})
|
||||
.optional(),
|
||||
customUserVars: z
|
||||
.record(
|
||||
z.string(),
|
||||
z.object({
|
||||
title: z.string(),
|
||||
description: z.string(),
|
||||
}),
|
||||
)
|
||||
.optional(),
|
||||
});
|
||||
|
||||
export const StdioOptionsSchema = BaseOptionsSchema.extend({
|
||||
|
|
@ -114,12 +155,100 @@ export const MCPServersSchema = z.record(z.string(), MCPOptionsSchema);
|
|||
export type MCPOptions = z.infer<typeof MCPOptionsSchema>;
|
||||
|
||||
/**
|
||||
* Recursively processes an object to replace environment variables in string values
|
||||
* @param {MCPOptions} obj - The object to process
|
||||
* @param {string} [userId] - The user ID
|
||||
* @returns {MCPOptions} - The processed object with environment variables replaced
|
||||
* List of allowed user fields that can be used in MCP environment variables.
|
||||
* These are non-sensitive string/boolean fields from the IUser interface.
|
||||
*/
|
||||
export function processMCPEnv(obj: Readonly<MCPOptions>, userId?: string): MCPOptions {
|
||||
const ALLOWED_USER_FIELDS = [
|
||||
'name',
|
||||
'username',
|
||||
'email',
|
||||
'provider',
|
||||
'role',
|
||||
'googleId',
|
||||
'facebookId',
|
||||
'openidId',
|
||||
'samlId',
|
||||
'ldapId',
|
||||
'githubId',
|
||||
'discordId',
|
||||
'appleId',
|
||||
'emailVerified',
|
||||
'twoFactorEnabled',
|
||||
'termsAccepted',
|
||||
] as const;
|
||||
|
||||
/**
|
||||
* Processes a string value to replace user field placeholders
|
||||
* @param value - The string value to process
|
||||
* @param user - The user object
|
||||
* @returns The processed string with placeholders replaced
|
||||
*/
|
||||
function processUserPlaceholders(value: string, user?: TUser): string {
|
||||
if (!user || typeof value !== 'string') {
|
||||
return value;
|
||||
}
|
||||
|
||||
for (const field of ALLOWED_USER_FIELDS) {
|
||||
const placeholder = `{{LIBRECHAT_USER_${field.toUpperCase()}}}`;
|
||||
if (value.includes(placeholder)) {
|
||||
const fieldValue = user[field as keyof TUser];
|
||||
const replacementValue = fieldValue != null ? String(fieldValue) : '';
|
||||
value = value.replace(new RegExp(placeholder, 'g'), replacementValue);
|
||||
}
|
||||
}
|
||||
|
||||
return value;
|
||||
}
|
||||
|
||||
function processSingleValue({
|
||||
originalValue,
|
||||
customUserVars,
|
||||
user,
|
||||
}: {
|
||||
originalValue: string;
|
||||
customUserVars?: Record<string, string>;
|
||||
user?: TUser;
|
||||
}): string {
|
||||
let value = originalValue;
|
||||
|
||||
// 1. Replace custom user variables
|
||||
if (customUserVars) {
|
||||
for (const [varName, varVal] of Object.entries(customUserVars)) {
|
||||
/** Escaped varName for use in regex to avoid issues with special characters */
|
||||
const escapedVarName = varName.replace(/[.*+?^${}()|[\]\\]/g, '\\$&');
|
||||
const placeholderRegex = new RegExp(`\\{\\{${escapedVarName}\\}\\}`, 'g');
|
||||
value = value.replace(placeholderRegex, varVal);
|
||||
}
|
||||
}
|
||||
|
||||
// 2.A. Special handling for LIBRECHAT_USER_ID placeholder
|
||||
// This ensures {{LIBRECHAT_USER_ID}} is replaced only if user.id is available.
|
||||
// If user.id is null/undefined, the placeholder remains
|
||||
if (user && user.id != null && value.includes('{{LIBRECHAT_USER_ID}}')) {
|
||||
value = value.replace(/\{\{LIBRECHAT_USER_ID\}\}/g, String(user.id));
|
||||
}
|
||||
|
||||
// 2.B. Replace other standard user field placeholders (e.g., {{LIBRECHAT_USER_EMAIL}})
|
||||
value = processUserPlaceholders(value, user);
|
||||
|
||||
// 3. Replace system environment variables
|
||||
value = extractEnvVariable(value);
|
||||
|
||||
return value;
|
||||
}
|
||||
|
||||
/**
|
||||
* Recursively processes an object to replace environment variables in string values
|
||||
* @param obj - The object to process
|
||||
* @param user - The user object containing all user fields
|
||||
* @param customUserVars - vars that user set in settings
|
||||
* @returns - The processed object with environment variables replaced
|
||||
*/
|
||||
export function processMCPEnv(
|
||||
obj: Readonly<MCPOptions>,
|
||||
user?: TUser,
|
||||
customUserVars?: Record<string, string>,
|
||||
): MCPOptions {
|
||||
if (obj === null || obj === undefined) {
|
||||
return obj;
|
||||
}
|
||||
|
|
@ -128,24 +257,25 @@ export function processMCPEnv(obj: Readonly<MCPOptions>, userId?: string): MCPOp
|
|||
|
||||
if ('env' in newObj && newObj.env) {
|
||||
const processedEnv: Record<string, string> = {};
|
||||
for (const [key, value] of Object.entries(newObj.env)) {
|
||||
processedEnv[key] = extractEnvVariable(value);
|
||||
for (const [key, originalValue] of Object.entries(newObj.env)) {
|
||||
processedEnv[key] = processSingleValue({ originalValue, customUserVars, user });
|
||||
}
|
||||
newObj.env = processedEnv;
|
||||
} else if ('headers' in newObj && newObj.headers) {
|
||||
}
|
||||
|
||||
// Process headers if they exist (for WebSocket, SSE, StreamableHTTP types)
|
||||
// Note: `env` and `headers` are on different branches of the MCPOptions union type.
|
||||
if ('headers' in newObj && newObj.headers) {
|
||||
const processedHeaders: Record<string, string> = {};
|
||||
for (const [key, value] of Object.entries(newObj.headers)) {
|
||||
if (value === '{{LIBRECHAT_USER_ID}}' && userId != null && userId) {
|
||||
processedHeaders[key] = userId;
|
||||
continue;
|
||||
}
|
||||
processedHeaders[key] = extractEnvVariable(value);
|
||||
for (const [key, originalValue] of Object.entries(newObj.headers)) {
|
||||
processedHeaders[key] = processSingleValue({ originalValue, customUserVars, user });
|
||||
}
|
||||
newObj.headers = processedHeaders;
|
||||
}
|
||||
|
||||
// Process URL if it exists (for WebSocket, SSE, StreamableHTTP types)
|
||||
if ('url' in newObj && newObj.url) {
|
||||
newObj.url = extractEnvVariable(newObj.url);
|
||||
newObj.url = processSingleValue({ originalValue: newObj.url, customUserVars, user });
|
||||
}
|
||||
|
||||
return newObj;
|
||||
|
|
|
|||
62
packages/data-provider/src/memory.ts
Normal file
62
packages/data-provider/src/memory.ts
Normal file
|
|
@ -0,0 +1,62 @@
|
|||
import type { TCustomConfig, TMemoryConfig } from './config';
|
||||
|
||||
/**
|
||||
* Loads the memory configuration and validates it
|
||||
* @param config - The memory configuration from librechat.yaml
|
||||
* @returns The validated memory configuration
|
||||
*/
|
||||
export function loadMemoryConfig(config: TCustomConfig['memory']): TMemoryConfig | undefined {
|
||||
if (!config) {
|
||||
return undefined;
|
||||
}
|
||||
|
||||
// If disabled is explicitly true, return the config as-is
|
||||
if (config.disabled === true) {
|
||||
return config;
|
||||
}
|
||||
|
||||
// Check if the agent configuration is valid
|
||||
const hasValidAgent =
|
||||
config.agent &&
|
||||
(('id' in config.agent && !!config.agent.id) ||
|
||||
('provider' in config.agent &&
|
||||
'model' in config.agent &&
|
||||
!!config.agent.provider &&
|
||||
!!config.agent.model));
|
||||
|
||||
// If agent config is invalid, treat as disabled
|
||||
if (!hasValidAgent) {
|
||||
return {
|
||||
...config,
|
||||
disabled: true,
|
||||
};
|
||||
}
|
||||
|
||||
return config;
|
||||
}
|
||||
|
||||
/**
|
||||
* Checks if memory feature is enabled based on the configuration
|
||||
* @param config - The memory configuration
|
||||
* @returns True if memory is enabled, false otherwise
|
||||
*/
|
||||
export function isMemoryEnabled(config: TMemoryConfig | undefined): boolean {
|
||||
if (!config) {
|
||||
return false;
|
||||
}
|
||||
|
||||
if (config.disabled === true) {
|
||||
return false;
|
||||
}
|
||||
|
||||
// Check if agent configuration is valid
|
||||
const hasValidAgent =
|
||||
config.agent &&
|
||||
(('id' in config.agent && !!config.agent.id) ||
|
||||
('provider' in config.agent &&
|
||||
'model' in config.agent &&
|
||||
!!config.agent.provider &&
|
||||
!!config.agent.model));
|
||||
|
||||
return !!hasValidAgent;
|
||||
}
|
||||
|
|
@ -225,13 +225,15 @@ const extractOmniVersion = (modelStr: string): string => {
|
|||
export const getResponseSender = (endpointOption: t.TEndpointOption): string => {
|
||||
const {
|
||||
model: _m,
|
||||
endpoint,
|
||||
endpoint: _e,
|
||||
endpointType,
|
||||
modelDisplayLabel: _mdl,
|
||||
chatGptLabel: _cgl,
|
||||
modelLabel: _ml,
|
||||
} = endpointOption;
|
||||
|
||||
const endpoint = _e as EModelEndpoint;
|
||||
|
||||
const model = _m ?? '';
|
||||
const modelDisplayLabel = _mdl ?? '';
|
||||
const chatGptLabel = _cgl ?? '';
|
||||
|
|
|
|||
|
|
@ -16,6 +16,10 @@ export enum PermissionTypes {
|
|||
* Type for Agent Permissions
|
||||
*/
|
||||
AGENTS = 'AGENTS',
|
||||
/**
|
||||
* Type for Memory Permissions
|
||||
*/
|
||||
MEMORIES = 'MEMORIES',
|
||||
/**
|
||||
* Type for Multi-Conversation Permissions
|
||||
*/
|
||||
|
|
@ -45,6 +49,8 @@ export enum Permissions {
|
|||
READ = 'READ',
|
||||
READ_AUTHOR = 'READ_AUTHOR',
|
||||
SHARE = 'SHARE',
|
||||
/** Can disable if desired */
|
||||
OPT_OUT = 'OPT_OUT',
|
||||
}
|
||||
|
||||
export const promptPermissionsSchema = z.object({
|
||||
|
|
@ -60,6 +66,15 @@ export const bookmarkPermissionsSchema = z.object({
|
|||
});
|
||||
export type TBookmarkPermissions = z.infer<typeof bookmarkPermissionsSchema>;
|
||||
|
||||
export const memoryPermissionsSchema = z.object({
|
||||
[Permissions.USE]: z.boolean().default(true),
|
||||
[Permissions.CREATE]: z.boolean().default(true),
|
||||
[Permissions.UPDATE]: z.boolean().default(true),
|
||||
[Permissions.READ]: z.boolean().default(true),
|
||||
[Permissions.OPT_OUT]: z.boolean().default(true),
|
||||
});
|
||||
export type TMemoryPermissions = z.infer<typeof memoryPermissionsSchema>;
|
||||
|
||||
export const agentPermissionsSchema = z.object({
|
||||
[Permissions.SHARED_GLOBAL]: z.boolean().default(false),
|
||||
[Permissions.USE]: z.boolean().default(true),
|
||||
|
|
@ -92,6 +107,7 @@ export type TWebSearchPermissions = z.infer<typeof webSearchPermissionsSchema>;
|
|||
export const permissionsSchema = z.object({
|
||||
[PermissionTypes.PROMPTS]: promptPermissionsSchema,
|
||||
[PermissionTypes.BOOKMARKS]: bookmarkPermissionsSchema,
|
||||
[PermissionTypes.MEMORIES]: memoryPermissionsSchema,
|
||||
[PermissionTypes.AGENTS]: agentPermissionsSchema,
|
||||
[PermissionTypes.MULTI_CONVO]: multiConvoPermissionsSchema,
|
||||
[PermissionTypes.TEMPORARY_CHAT]: temporaryChatPermissionsSchema,
|
||||
|
|
|
|||
|
|
@ -5,6 +5,7 @@ import {
|
|||
permissionsSchema,
|
||||
agentPermissionsSchema,
|
||||
promptPermissionsSchema,
|
||||
memoryPermissionsSchema,
|
||||
runCodePermissionsSchema,
|
||||
webSearchPermissionsSchema,
|
||||
bookmarkPermissionsSchema,
|
||||
|
|
@ -48,6 +49,13 @@ const defaultRolesSchema = z.object({
|
|||
[PermissionTypes.BOOKMARKS]: bookmarkPermissionsSchema.extend({
|
||||
[Permissions.USE]: z.boolean().default(true),
|
||||
}),
|
||||
[PermissionTypes.MEMORIES]: memoryPermissionsSchema.extend({
|
||||
[Permissions.USE]: z.boolean().default(true),
|
||||
[Permissions.CREATE]: z.boolean().default(true),
|
||||
[Permissions.UPDATE]: z.boolean().default(true),
|
||||
[Permissions.READ]: z.boolean().default(true),
|
||||
[Permissions.OPT_OUT]: z.boolean().default(true),
|
||||
}),
|
||||
[PermissionTypes.AGENTS]: agentPermissionsSchema.extend({
|
||||
[Permissions.SHARED_GLOBAL]: z.boolean().default(true),
|
||||
[Permissions.USE]: z.boolean().default(true),
|
||||
|
|
@ -86,6 +94,13 @@ export const roleDefaults = defaultRolesSchema.parse({
|
|||
[PermissionTypes.BOOKMARKS]: {
|
||||
[Permissions.USE]: true,
|
||||
},
|
||||
[PermissionTypes.MEMORIES]: {
|
||||
[Permissions.USE]: true,
|
||||
[Permissions.CREATE]: true,
|
||||
[Permissions.UPDATE]: true,
|
||||
[Permissions.READ]: true,
|
||||
[Permissions.OPT_OUT]: true,
|
||||
},
|
||||
[PermissionTypes.AGENTS]: {
|
||||
[Permissions.SHARED_GLOBAL]: true,
|
||||
[Permissions.USE]: true,
|
||||
|
|
@ -110,6 +125,7 @@ export const roleDefaults = defaultRolesSchema.parse({
|
|||
permissions: {
|
||||
[PermissionTypes.PROMPTS]: {},
|
||||
[PermissionTypes.BOOKMARKS]: {},
|
||||
[PermissionTypes.MEMORIES]: {},
|
||||
[PermissionTypes.AGENTS]: {},
|
||||
[PermissionTypes.MULTI_CONVO]: {},
|
||||
[PermissionTypes.TEMPORARY_CHAT]: {},
|
||||
|
|
|
|||
|
|
@ -522,11 +522,19 @@ export const tMessageSchema = z.object({
|
|||
feedback: feedbackSchema.optional(),
|
||||
});
|
||||
|
||||
export type MemoryArtifact = {
|
||||
key: string;
|
||||
value?: string;
|
||||
tokenCount?: number;
|
||||
type: 'update' | 'delete';
|
||||
};
|
||||
|
||||
export type TAttachmentMetadata = {
|
||||
type?: Tools;
|
||||
messageId: string;
|
||||
toolCallId: string;
|
||||
[Tools.web_search]?: SearchResultData;
|
||||
[Tools.memory]?: MemoryArtifact;
|
||||
};
|
||||
|
||||
export type TAttachment =
|
||||
|
|
|
|||
|
|
@ -1,17 +1,18 @@
|
|||
import type OpenAI from 'openai';
|
||||
import type { InfiniteData } from '@tanstack/react-query';
|
||||
import type {
|
||||
TBanner,
|
||||
TMessage,
|
||||
TResPlugin,
|
||||
ImageDetail,
|
||||
TSharedLink,
|
||||
TConversation,
|
||||
EModelEndpoint,
|
||||
TConversationTag,
|
||||
TBanner,
|
||||
TAttachment,
|
||||
} from './schemas';
|
||||
import { TMinimalFeedback } from './feedback';
|
||||
import { SettingDefinition } from './generate';
|
||||
import type { SettingDefinition } from './generate';
|
||||
import type { TMinimalFeedback } from './feedback';
|
||||
import type { Agent } from './types/assistants';
|
||||
|
||||
export type TOpenAIMessage = OpenAI.Chat.ChatCompletionMessageParam;
|
||||
|
||||
|
|
@ -20,28 +21,78 @@ export * from './schemas';
|
|||
export type TMessages = TMessage[];
|
||||
|
||||
/* TODO: Cleanup EndpointOption types */
|
||||
export type TEndpointOption = {
|
||||
spec?: string | null;
|
||||
iconURL?: string | null;
|
||||
endpoint: EModelEndpoint;
|
||||
endpointType?: EModelEndpoint;
|
||||
export type TEndpointOption = Pick<
|
||||
TConversation,
|
||||
// Core conversation fields
|
||||
| 'endpoint'
|
||||
| 'endpointType'
|
||||
| 'model'
|
||||
| 'modelLabel'
|
||||
| 'chatGptLabel'
|
||||
| 'promptPrefix'
|
||||
| 'temperature'
|
||||
| 'topP'
|
||||
| 'topK'
|
||||
| 'top_p'
|
||||
| 'frequency_penalty'
|
||||
| 'presence_penalty'
|
||||
| 'maxOutputTokens'
|
||||
| 'maxContextTokens'
|
||||
| 'max_tokens'
|
||||
| 'maxTokens'
|
||||
| 'resendFiles'
|
||||
| 'imageDetail'
|
||||
| 'reasoning_effort'
|
||||
| 'instructions'
|
||||
| 'additional_instructions'
|
||||
| 'append_current_datetime'
|
||||
| 'tools'
|
||||
| 'stop'
|
||||
| 'region'
|
||||
| 'additionalModelRequestFields'
|
||||
// Anthropic-specific
|
||||
| 'promptCache'
|
||||
| 'thinking'
|
||||
| 'thinkingBudget'
|
||||
// Assistant/Agent fields
|
||||
| 'assistant_id'
|
||||
| 'agent_id'
|
||||
// UI/Display fields
|
||||
| 'iconURL'
|
||||
| 'greeting'
|
||||
| 'spec'
|
||||
// Artifacts
|
||||
| 'artifacts'
|
||||
// Files
|
||||
| 'file_ids'
|
||||
// System field
|
||||
| 'system'
|
||||
// Google examples
|
||||
| 'examples'
|
||||
// Context
|
||||
| 'context'
|
||||
> & {
|
||||
// Fields specific to endpoint options that don't exist on TConversation
|
||||
modelDisplayLabel?: string;
|
||||
resendFiles?: boolean;
|
||||
promptCache?: boolean;
|
||||
maxContextTokens?: number;
|
||||
imageDetail?: ImageDetail;
|
||||
model?: string | null;
|
||||
promptPrefix?: string;
|
||||
temperature?: number;
|
||||
chatGptLabel?: string | null;
|
||||
modelLabel?: string | null;
|
||||
jailbreak?: boolean;
|
||||
key?: string | null;
|
||||
/* assistant */
|
||||
/** @deprecated Assistants API */
|
||||
thread_id?: string;
|
||||
/* multi-response stream */
|
||||
// Conversation identifiers for multi-response streams
|
||||
overrideConvoId?: string;
|
||||
overrideUserMessageId?: string;
|
||||
// Model parameters (used by different endpoints)
|
||||
modelOptions?: Record<string, unknown>;
|
||||
model_parameters?: Record<string, unknown>;
|
||||
// Configuration data (added by middleware)
|
||||
modelsConfig?: TModelsConfig;
|
||||
// File attachments (processed by middleware)
|
||||
attachments?: TAttachment[];
|
||||
// Generated prompts
|
||||
artifactsPrompt?: string;
|
||||
// Agent-specific fields
|
||||
agent?: Promise<Agent>;
|
||||
// Client-specific options
|
||||
clientOptions?: Record<string, unknown>;
|
||||
};
|
||||
|
||||
export type TEphemeralAgent = {
|
||||
|
|
@ -130,6 +181,9 @@ export type TUser = {
|
|||
plugins?: string[];
|
||||
twoFactorEnabled?: boolean;
|
||||
backupCodes?: TBackupCode[];
|
||||
personalization?: {
|
||||
memories?: boolean;
|
||||
};
|
||||
createdAt: string;
|
||||
updatedAt: string;
|
||||
};
|
||||
|
|
|
|||
|
|
@ -1,7 +1,7 @@
|
|||
/* eslint-disable @typescript-eslint/no-namespace */
|
||||
import { StepTypes, ContentTypes, ToolCallTypes } from './runs';
|
||||
import type { TAttachment, TPlugin } from 'src/schemas';
|
||||
import type { FunctionToolCall } from './assistants';
|
||||
import type { TAttachment } from 'src/schemas';
|
||||
|
||||
export namespace Agents {
|
||||
export type MessageType = 'human' | 'ai' | 'generic' | 'system' | 'function' | 'tool' | 'remove';
|
||||
|
|
@ -279,3 +279,79 @@ export type ToolCallResult = {
|
|||
conversationId: string;
|
||||
attachments?: TAttachment[];
|
||||
};
|
||||
|
||||
export enum AuthTypeEnum {
|
||||
ServiceHttp = 'service_http',
|
||||
OAuth = 'oauth',
|
||||
None = 'none',
|
||||
}
|
||||
|
||||
export enum AuthorizationTypeEnum {
|
||||
Bearer = 'bearer',
|
||||
Basic = 'basic',
|
||||
Custom = 'custom',
|
||||
}
|
||||
|
||||
export enum TokenExchangeMethodEnum {
|
||||
DefaultPost = 'default_post',
|
||||
BasicAuthHeader = 'basic_auth_header',
|
||||
}
|
||||
|
||||
export type Action = {
|
||||
action_id: string;
|
||||
type?: string;
|
||||
settings?: Record<string, unknown>;
|
||||
metadata: ActionMetadata;
|
||||
version: number | string;
|
||||
} & ({ assistant_id: string; agent_id?: never } | { assistant_id?: never; agent_id: string });
|
||||
|
||||
export type ActionMetadata = {
|
||||
api_key?: string;
|
||||
auth?: ActionAuth;
|
||||
domain?: string;
|
||||
privacy_policy_url?: string;
|
||||
raw_spec?: string;
|
||||
oauth_client_id?: string;
|
||||
oauth_client_secret?: string;
|
||||
};
|
||||
|
||||
export type ActionAuth = {
|
||||
authorization_type?: AuthorizationTypeEnum;
|
||||
custom_auth_header?: string;
|
||||
type?: AuthTypeEnum;
|
||||
authorization_content_type?: string;
|
||||
authorization_url?: string;
|
||||
client_url?: string;
|
||||
scope?: string;
|
||||
token_exchange_method?: TokenExchangeMethodEnum;
|
||||
};
|
||||
|
||||
export type ActionMetadataRuntime = ActionMetadata & {
|
||||
oauth_access_token?: string;
|
||||
oauth_refresh_token?: string;
|
||||
oauth_token_expires_at?: Date;
|
||||
};
|
||||
|
||||
export type MCP = {
|
||||
mcp_id: string;
|
||||
metadata: MCPMetadata;
|
||||
} & ({ assistant_id: string; agent_id?: never } | { assistant_id?: never; agent_id: string });
|
||||
|
||||
export type MCPMetadata = Omit<ActionMetadata, 'auth'> & {
|
||||
name?: string;
|
||||
description?: string;
|
||||
url?: string;
|
||||
tools?: string[];
|
||||
auth?: MCPAuth;
|
||||
icon?: string;
|
||||
trust?: boolean;
|
||||
};
|
||||
|
||||
export type MCPAuth = ActionAuth;
|
||||
|
||||
export type AgentToolType = {
|
||||
tool_id: string;
|
||||
metadata: ToolMetadata;
|
||||
} & ({ assistant_id: string; agent_id?: never } | { assistant_id?: never; agent_id: string });
|
||||
|
||||
export type ToolMetadata = TPlugin;
|
||||
|
|
|
|||
|
|
@ -22,6 +22,7 @@ export enum Tools {
|
|||
web_search = 'web_search',
|
||||
retrieval = 'retrieval',
|
||||
function = 'function',
|
||||
memory = 'memory',
|
||||
}
|
||||
|
||||
export enum EToolResources {
|
||||
|
|
@ -486,60 +487,6 @@ export const actionDomainSeparator = '---';
|
|||
export const hostImageIdSuffix = '_host_copy';
|
||||
export const hostImageNamePrefix = 'host_copy_';
|
||||
|
||||
export enum AuthTypeEnum {
|
||||
ServiceHttp = 'service_http',
|
||||
OAuth = 'oauth',
|
||||
None = 'none',
|
||||
}
|
||||
|
||||
export enum AuthorizationTypeEnum {
|
||||
Bearer = 'bearer',
|
||||
Basic = 'basic',
|
||||
Custom = 'custom',
|
||||
}
|
||||
|
||||
export enum TokenExchangeMethodEnum {
|
||||
DefaultPost = 'default_post',
|
||||
BasicAuthHeader = 'basic_auth_header',
|
||||
}
|
||||
|
||||
export type ActionAuth = {
|
||||
authorization_type?: AuthorizationTypeEnum;
|
||||
custom_auth_header?: string;
|
||||
type?: AuthTypeEnum;
|
||||
authorization_content_type?: string;
|
||||
authorization_url?: string;
|
||||
client_url?: string;
|
||||
scope?: string;
|
||||
token_exchange_method?: TokenExchangeMethodEnum;
|
||||
};
|
||||
|
||||
export type ActionMetadata = {
|
||||
api_key?: string;
|
||||
auth?: ActionAuth;
|
||||
domain?: string;
|
||||
privacy_policy_url?: string;
|
||||
raw_spec?: string;
|
||||
oauth_client_id?: string;
|
||||
oauth_client_secret?: string;
|
||||
};
|
||||
|
||||
export type ActionMetadataRuntime = ActionMetadata & {
|
||||
oauth_access_token?: string;
|
||||
oauth_refresh_token?: string;
|
||||
oauth_token_expires_at?: Date;
|
||||
};
|
||||
|
||||
/* Assistant types */
|
||||
|
||||
export type Action = {
|
||||
action_id: string;
|
||||
type?: string;
|
||||
settings?: Record<string, unknown>;
|
||||
metadata: ActionMetadata;
|
||||
version: number | string;
|
||||
} & ({ assistant_id: string; agent_id?: never } | { assistant_id?: never; agent_id: string });
|
||||
|
||||
export type AssistantAvatar = {
|
||||
filepath: string;
|
||||
source: string;
|
||||
|
|
|
|||
|
|
@ -10,6 +10,7 @@ export enum FileSources {
|
|||
vectordb = 'vectordb',
|
||||
execute_code = 'execute_code',
|
||||
mistral_ocr = 'mistral_ocr',
|
||||
azure_mistral_ocr = 'azure_mistral_ocr',
|
||||
text = 'text',
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -6,14 +6,13 @@ import {
|
|||
Assistant,
|
||||
AssistantCreateParams,
|
||||
AssistantUpdateParams,
|
||||
ActionMetadata,
|
||||
FunctionTool,
|
||||
AssistantDocument,
|
||||
Action,
|
||||
Agent,
|
||||
AgentCreateParams,
|
||||
AgentUpdateParams,
|
||||
} from './assistants';
|
||||
import { Action, ActionMetadata } from './agents';
|
||||
|
||||
export type MutationOptions<
|
||||
Response,
|
||||
|
|
@ -278,7 +277,7 @@ export type UpdatePermVars<T> = {
|
|||
};
|
||||
|
||||
export type UpdatePromptPermVars = UpdatePermVars<p.TPromptPermissions>;
|
||||
|
||||
export type UpdateMemoryPermVars = UpdatePermVars<p.TMemoryPermissions>;
|
||||
export type UpdateAgentPermVars = UpdatePermVars<p.TAgentPermissions>;
|
||||
|
||||
export type UpdatePermResponse = r.TRole;
|
||||
|
|
@ -290,6 +289,13 @@ export type UpdatePromptPermOptions = MutationOptions<
|
|||
types.TError | null | undefined
|
||||
>;
|
||||
|
||||
export type UpdateMemoryPermOptions = MutationOptions<
|
||||
UpdatePermResponse,
|
||||
UpdateMemoryPermVars,
|
||||
unknown,
|
||||
types.TError | null | undefined
|
||||
>;
|
||||
|
||||
export type UpdateAgentPermOptions = MutationOptions<
|
||||
UpdatePermResponse,
|
||||
UpdateAgentPermVars,
|
||||
|
|
|
|||
|
|
@ -109,3 +109,18 @@ export type VerifyToolAuthResponse = {
|
|||
|
||||
export type GetToolCallParams = { conversationId: string };
|
||||
export type ToolCallResults = a.ToolCallResult[];
|
||||
|
||||
/* Memories */
|
||||
export type TUserMemory = {
|
||||
key: string;
|
||||
value: string;
|
||||
updated_at: string;
|
||||
tokenCount?: number;
|
||||
};
|
||||
|
||||
export type MemoriesResponse = {
|
||||
memories: TUserMemory[];
|
||||
totalTokens: number;
|
||||
tokenLimit: number | null;
|
||||
usagePercentage: number | null;
|
||||
};
|
||||
|
|
|
|||
|
|
@ -5,8 +5,8 @@ import type {
|
|||
SearchProviders,
|
||||
TWebSearchConfig,
|
||||
} from './config';
|
||||
import { extractVariableName } from './utils';
|
||||
import { SearchCategories, SafeSearchTypes } from './config';
|
||||
import { extractVariableName } from './utils';
|
||||
import { AuthType } from './schemas';
|
||||
|
||||
export function loadWebSearchConfig(
|
||||
|
|
@ -64,23 +64,29 @@ export const webSearchAuth = {
|
|||
/**
|
||||
* Extracts all API keys from the webSearchAuth configuration object
|
||||
*/
|
||||
export const webSearchKeys: TWebSearchKeys[] = [];
|
||||
export function getWebSearchKeys(): TWebSearchKeys[] {
|
||||
const keys: TWebSearchKeys[] = [];
|
||||
|
||||
// Iterate through each category (providers, scrapers, rerankers)
|
||||
for (const category of Object.keys(webSearchAuth)) {
|
||||
const categoryObj = webSearchAuth[category as TWebSearchCategories];
|
||||
// Iterate through each category (providers, scrapers, rerankers)
|
||||
for (const category of Object.keys(webSearchAuth)) {
|
||||
const categoryObj = webSearchAuth[category as TWebSearchCategories];
|
||||
|
||||
// Iterate through each service within the category
|
||||
for (const service of Object.keys(categoryObj)) {
|
||||
const serviceObj = categoryObj[service as keyof typeof categoryObj];
|
||||
// Iterate through each service within the category
|
||||
for (const service of Object.keys(categoryObj)) {
|
||||
const serviceObj = categoryObj[service as keyof typeof categoryObj];
|
||||
|
||||
// Extract the API keys from the service
|
||||
for (const key of Object.keys(serviceObj)) {
|
||||
webSearchKeys.push(key as TWebSearchKeys);
|
||||
// Extract the API keys from the service
|
||||
for (const key of Object.keys(serviceObj)) {
|
||||
keys.push(key as TWebSearchKeys);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return keys;
|
||||
}
|
||||
|
||||
export const webSearchKeys: TWebSearchKeys[] = getWebSearchKeys();
|
||||
|
||||
export function extractWebSearchEnvVars({
|
||||
keys,
|
||||
config,
|
||||
|
|
|
|||
|
|
@ -1,114 +0,0 @@
|
|||
# `@librechat/data-schemas`
|
||||
|
||||
Mongoose schemas and models for LibreChat. This package provides a comprehensive collection of Mongoose schemas used across the LibreChat project, enabling robust data modeling and validation for various entities such as actions, agents, messages, users, and more.
|
||||
|
||||
|
||||
## Features
|
||||
|
||||
- **Modular Schemas:** Includes schemas for actions, agents, assistants, balance, banners, categories, conversation tags, conversations, files, keys, messages, plugin authentication, presets, projects, prompts, prompt groups, roles, sessions, shared links, tokens, tool calls, transactions, and users.
|
||||
- **TypeScript Support:** Provides TypeScript definitions for type-safe development.
|
||||
- **Ready for Mongoose Integration:** Easily integrate with Mongoose to create models and interact with your MongoDB database.
|
||||
- **Flexible & Extensible:** Designed to support the evolving needs of LibreChat while being adaptable to other projects.
|
||||
|
||||
|
||||
## Installation
|
||||
|
||||
Install the package via npm or yarn:
|
||||
|
||||
```bash
|
||||
npm install @librechat/data-schemas
|
||||
```
|
||||
|
||||
Or with yarn:
|
||||
|
||||
```bash
|
||||
yarn add @librechat/data-schemas
|
||||
```
|
||||
|
||||
|
||||
## Usage
|
||||
|
||||
After installation, you can import and use the schemas in your project. For example, to create a Mongoose model for a user:
|
||||
|
||||
```js
|
||||
import mongoose from 'mongoose';
|
||||
import { userSchema } from '@librechat/data-schemas';
|
||||
|
||||
const UserModel = mongoose.model('User', userSchema);
|
||||
|
||||
// Now you can use UserModel to create, read, update, and delete user documents.
|
||||
```
|
||||
|
||||
You can also import other schemas as needed:
|
||||
|
||||
```js
|
||||
import { actionSchema, agentSchema, messageSchema } from '@librechat/data-schemas';
|
||||
```
|
||||
|
||||
Each schema is designed to integrate seamlessly with Mongoose and provides indexes, timestamps, and validations tailored for LibreChat’s use cases.
|
||||
|
||||
|
||||
## Development
|
||||
|
||||
This package uses Rollup and TypeScript for building and bundling.
|
||||
|
||||
### Available Scripts
|
||||
|
||||
- **Build:**
|
||||
Cleans the `dist` directory and builds the package.
|
||||
```bash
|
||||
npm run build
|
||||
```
|
||||
|
||||
- **Build Watch:**
|
||||
Rebuilds automatically on file changes.
|
||||
```bash
|
||||
npm run build:watch
|
||||
```
|
||||
|
||||
- **Test:**
|
||||
Runs tests with coverage in watch mode.
|
||||
```bash
|
||||
npm run test
|
||||
```
|
||||
|
||||
- **Test (CI):**
|
||||
Runs tests with coverage for CI environments.
|
||||
```bash
|
||||
npm run test:ci
|
||||
```
|
||||
|
||||
- **Verify:**
|
||||
Runs tests in CI mode to verify code integrity.
|
||||
```bash
|
||||
npm run verify
|
||||
```
|
||||
|
||||
- **Clean:**
|
||||
Removes the `dist` directory.
|
||||
```bash
|
||||
npm run clean
|
||||
```
|
||||
|
||||
For those using Bun, equivalent scripts are available:
|
||||
- **Bun Clean:** `bun run b:clean`
|
||||
- **Bun Build:** `bun run b:build`
|
||||
|
||||
|
||||
## Repository & Issues
|
||||
|
||||
The source code is maintained on GitHub.
|
||||
- **Repository:** [LibreChat Repository](https://github.com/danny-avila/LibreChat.git)
|
||||
- **Issues & Bug Reports:** [LibreChat Issues](https://github.com/danny-avila/LibreChat/issues)
|
||||
|
||||
|
||||
## License
|
||||
|
||||
This project is licensed under the [MIT License](LICENSE).
|
||||
|
||||
|
||||
## Contributing
|
||||
|
||||
Contributions to improve and expand the data schemas are welcome. If you have suggestions, improvements, or bug fixes, please open an issue or submit a pull request on the [GitHub repository](https://github.com/danny-avila/LibreChat/issues).
|
||||
|
||||
For more detailed documentation on each schema and model, please refer to the source code or visit the [LibreChat website](https://librechat.ai).
|
||||
|
|
@ -5,6 +5,7 @@ export default {
|
|||
testResultsProcessor: 'jest-junit',
|
||||
moduleNameMapper: {
|
||||
'^@src/(.*)$': '<rootDir>/src/$1',
|
||||
'^~/(.*)$': '<rootDir>/src/$1',
|
||||
},
|
||||
// coverageThreshold: {
|
||||
// global: {
|
||||
|
|
@ -16,4 +17,4 @@ export default {
|
|||
// },
|
||||
restoreMocks: true,
|
||||
testTimeout: 15000,
|
||||
};
|
||||
};
|
||||
|
|
|
|||
|
|
@ -1,6 +1,6 @@
|
|||
{
|
||||
"name": "@librechat/data-schemas",
|
||||
"version": "0.0.7",
|
||||
"version": "0.0.10",
|
||||
"description": "Mongoose schemas and models for LibreChat",
|
||||
"type": "module",
|
||||
"main": "dist/index.cjs",
|
||||
|
|
@ -51,6 +51,7 @@
|
|||
"@types/traverse": "^0.6.37",
|
||||
"jest": "^29.5.0",
|
||||
"jest-junit": "^16.0.0",
|
||||
"mongodb-memory-server": "^10.1.4",
|
||||
"rimraf": "^5.0.1",
|
||||
"rollup": "^4.22.4",
|
||||
"rollup-plugin-generate-package-json": "^3.2.0",
|
||||
|
|
@ -60,13 +61,14 @@
|
|||
"typescript": "^5.0.4"
|
||||
},
|
||||
"peerDependencies": {
|
||||
"keyv": "^5.3.2",
|
||||
"mongoose": "^8.12.1",
|
||||
"librechat-data-provider": "*",
|
||||
"jsonwebtoken": "^9.0.2",
|
||||
"keyv": "^5.3.2",
|
||||
"klona": "^2.0.6",
|
||||
"librechat-data-provider": "*",
|
||||
"lodash": "^4.17.21",
|
||||
"meilisearch": "^0.38.0",
|
||||
"mongoose": "^8.12.1",
|
||||
"nanoid": "^3.3.7",
|
||||
"traverse": "^0.6.11",
|
||||
"winston": "^3.17.0",
|
||||
"winston-daily-rotate-file": "^5.0.0"
|
||||
|
|
|
|||
|
|
@ -2,7 +2,7 @@ import path from 'path';
|
|||
import winston from 'winston';
|
||||
import 'winston-daily-rotate-file';
|
||||
|
||||
const logDir = path.join(__dirname, '..', 'logs');
|
||||
const logDir = path.join(__dirname, '..', '..', '..', 'api', 'logs');
|
||||
|
||||
const { NODE_ENV, DEBUG_LOGGING = 'false' } = process.env;
|
||||
|
||||
|
|
|
|||
|
|
@ -3,10 +3,8 @@ import winston from 'winston';
|
|||
import 'winston-daily-rotate-file';
|
||||
import { redactFormat, redactMessage, debugTraverse, jsonTruncateFormat } from './parsers';
|
||||
|
||||
// Define log directory
|
||||
const logDir = path.join(__dirname, '..', 'logs');
|
||||
const logDir = path.join(__dirname, '..', '..', '..', 'api', 'logs');
|
||||
|
||||
// Type-safe environment variables
|
||||
const { NODE_ENV, DEBUG_LOGGING, CONSOLE_JSON, DEBUG_CONSOLE } = process.env;
|
||||
|
||||
const useConsoleJson = typeof CONSOLE_JSON === 'string' && CONSOLE_JSON.toLowerCase() === 'true';
|
||||
|
|
@ -15,7 +13,6 @@ const useDebugConsole = typeof DEBUG_CONSOLE === 'string' && DEBUG_CONSOLE.toLow
|
|||
|
||||
const useDebugLogging = typeof DEBUG_LOGGING === 'string' && DEBUG_LOGGING.toLowerCase() === 'true';
|
||||
|
||||
// Define custom log levels
|
||||
const levels: winston.config.AbstractConfigSetLevels = {
|
||||
error: 0,
|
||||
warn: 1,
|
||||
|
|
|
|||
|
|
@ -3,5 +3,6 @@ export * from './schema';
|
|||
export { createModels } from './models';
|
||||
export { createMethods } from './methods';
|
||||
export type * from './types';
|
||||
export type * from './methods';
|
||||
export { default as logger } from './config/winston';
|
||||
export { default as meiliLogger } from './config/meiliLogger';
|
||||
|
|
|
|||
|
|
@ -2,6 +2,10 @@ import { createUserMethods, type UserMethods } from './user';
|
|||
import { createSessionMethods, type SessionMethods } from './session';
|
||||
import { createTokenMethods, type TokenMethods } from './token';
|
||||
import { createRoleMethods, type RoleMethods } from './role';
|
||||
/* Memories */
|
||||
import { createMemoryMethods, type MemoryMethods } from './memory';
|
||||
import { createShareMethods, type ShareMethods } from './share';
|
||||
import { createPluginAuthMethods, type PluginAuthMethods } from './pluginAuth';
|
||||
|
||||
/**
|
||||
* Creates all database methods for all collections
|
||||
|
|
@ -12,7 +16,17 @@ export function createMethods(mongoose: typeof import('mongoose')) {
|
|||
...createSessionMethods(mongoose),
|
||||
...createTokenMethods(mongoose),
|
||||
...createRoleMethods(mongoose),
|
||||
...createMemoryMethods(mongoose),
|
||||
...createShareMethods(mongoose),
|
||||
...createPluginAuthMethods(mongoose),
|
||||
};
|
||||
}
|
||||
|
||||
export type AllMethods = UserMethods & SessionMethods & TokenMethods & RoleMethods;
|
||||
export type { MemoryMethods, ShareMethods, TokenMethods, PluginAuthMethods };
|
||||
export type AllMethods = UserMethods &
|
||||
SessionMethods &
|
||||
TokenMethods &
|
||||
RoleMethods &
|
||||
MemoryMethods &
|
||||
ShareMethods &
|
||||
PluginAuthMethods;
|
||||
|
|
|
|||
168
packages/data-schemas/src/methods/memory.ts
Normal file
168
packages/data-schemas/src/methods/memory.ts
Normal file
|
|
@ -0,0 +1,168 @@
|
|||
import { Types } from 'mongoose';
|
||||
import logger from '~/config/winston';
|
||||
import type * as t from '~/types';
|
||||
|
||||
/**
|
||||
* Formats a date in YYYY-MM-DD format
|
||||
*/
|
||||
const formatDate = (date: Date): string => {
|
||||
return date.toISOString().split('T')[0];
|
||||
};
|
||||
|
||||
// Factory function that takes mongoose instance and returns the methods
|
||||
export function createMemoryMethods(mongoose: typeof import('mongoose')) {
|
||||
const MemoryEntry = mongoose.models.MemoryEntry;
|
||||
|
||||
/**
|
||||
* Creates a new memory entry for a user
|
||||
* Throws an error if a memory with the same key already exists
|
||||
*/
|
||||
async function createMemory({
|
||||
userId,
|
||||
key,
|
||||
value,
|
||||
tokenCount = 0,
|
||||
}: t.SetMemoryParams): Promise<t.MemoryResult> {
|
||||
try {
|
||||
if (key?.toLowerCase() === 'nothing') {
|
||||
return { ok: false };
|
||||
}
|
||||
|
||||
const existingMemory = await MemoryEntry.findOne({ userId, key });
|
||||
if (existingMemory) {
|
||||
throw new Error('Memory with this key already exists');
|
||||
}
|
||||
|
||||
await MemoryEntry.create({
|
||||
userId,
|
||||
key,
|
||||
value,
|
||||
tokenCount,
|
||||
updated_at: new Date(),
|
||||
});
|
||||
|
||||
return { ok: true };
|
||||
} catch (error) {
|
||||
throw new Error(
|
||||
`Failed to create memory: ${error instanceof Error ? error.message : 'Unknown error'}`,
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Sets or updates a memory entry for a user
|
||||
*/
|
||||
async function setMemory({
|
||||
userId,
|
||||
key,
|
||||
value,
|
||||
tokenCount = 0,
|
||||
}: t.SetMemoryParams): Promise<t.MemoryResult> {
|
||||
try {
|
||||
if (key?.toLowerCase() === 'nothing') {
|
||||
return { ok: false };
|
||||
}
|
||||
|
||||
await MemoryEntry.findOneAndUpdate(
|
||||
{ userId, key },
|
||||
{
|
||||
value,
|
||||
tokenCount,
|
||||
updated_at: new Date(),
|
||||
},
|
||||
{
|
||||
upsert: true,
|
||||
new: true,
|
||||
},
|
||||
);
|
||||
|
||||
return { ok: true };
|
||||
} catch (error) {
|
||||
throw new Error(
|
||||
`Failed to set memory: ${error instanceof Error ? error.message : 'Unknown error'}`,
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Deletes a specific memory entry for a user
|
||||
*/
|
||||
async function deleteMemory({ userId, key }: t.DeleteMemoryParams): Promise<t.MemoryResult> {
|
||||
try {
|
||||
const result = await MemoryEntry.findOneAndDelete({ userId, key });
|
||||
return { ok: !!result };
|
||||
} catch (error) {
|
||||
throw new Error(
|
||||
`Failed to delete memory: ${error instanceof Error ? error.message : 'Unknown error'}`,
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Gets all memory entries for a user
|
||||
*/
|
||||
async function getAllUserMemories(
|
||||
userId: string | Types.ObjectId,
|
||||
): Promise<t.IMemoryEntryLean[]> {
|
||||
try {
|
||||
return (await MemoryEntry.find({ userId }).lean()) as t.IMemoryEntryLean[];
|
||||
} catch (error) {
|
||||
throw new Error(
|
||||
`Failed to get all memories: ${error instanceof Error ? error.message : 'Unknown error'}`,
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Gets and formats all memories for a user in two different formats
|
||||
*/
|
||||
async function getFormattedMemories({
|
||||
userId,
|
||||
}: t.GetFormattedMemoriesParams): Promise<t.FormattedMemoriesResult> {
|
||||
try {
|
||||
const memories = await getAllUserMemories(userId);
|
||||
|
||||
if (!memories || memories.length === 0) {
|
||||
return { withKeys: '', withoutKeys: '', totalTokens: 0 };
|
||||
}
|
||||
|
||||
const sortedMemories = memories.sort(
|
||||
(a, b) => new Date(a.updated_at!).getTime() - new Date(b.updated_at!).getTime(),
|
||||
);
|
||||
|
||||
const totalTokens = sortedMemories.reduce((sum, memory) => {
|
||||
return sum + (memory.tokenCount || 0);
|
||||
}, 0);
|
||||
|
||||
const withKeys = sortedMemories
|
||||
.map((memory, index) => {
|
||||
const date = formatDate(new Date(memory.updated_at!));
|
||||
const tokenInfo = memory.tokenCount ? ` [${memory.tokenCount} tokens]` : '';
|
||||
return `${index + 1}. [${date}]. ["key": "${memory.key}"]${tokenInfo}. ["value": "${memory.value}"]`;
|
||||
})
|
||||
.join('\n\n');
|
||||
|
||||
const withoutKeys = sortedMemories
|
||||
.map((memory, index) => {
|
||||
const date = formatDate(new Date(memory.updated_at!));
|
||||
return `${index + 1}. [${date}]. ${memory.value}`;
|
||||
})
|
||||
.join('\n\n');
|
||||
|
||||
return { withKeys, withoutKeys, totalTokens };
|
||||
} catch (error) {
|
||||
logger.error('Failed to get formatted memories:', error);
|
||||
return { withKeys: '', withoutKeys: '', totalTokens: 0 };
|
||||
}
|
||||
}
|
||||
|
||||
return {
|
||||
setMemory,
|
||||
createMemory,
|
||||
deleteMemory,
|
||||
getAllUserMemories,
|
||||
getFormattedMemories,
|
||||
};
|
||||
}
|
||||
|
||||
export type MemoryMethods = ReturnType<typeof createMemoryMethods>;
|
||||
140
packages/data-schemas/src/methods/pluginAuth.ts
Normal file
140
packages/data-schemas/src/methods/pluginAuth.ts
Normal file
|
|
@ -0,0 +1,140 @@
|
|||
import type { DeleteResult, Model } from 'mongoose';
|
||||
import type { IPluginAuth } from '~/schema/pluginAuth';
|
||||
import type {
|
||||
FindPluginAuthsByKeysParams,
|
||||
UpdatePluginAuthParams,
|
||||
DeletePluginAuthParams,
|
||||
FindPluginAuthParams,
|
||||
} from '~/types';
|
||||
|
||||
// Factory function that takes mongoose instance and returns the methods
|
||||
export function createPluginAuthMethods(mongoose: typeof import('mongoose')) {
|
||||
const PluginAuth: Model<IPluginAuth> = mongoose.models.PluginAuth;
|
||||
|
||||
/**
|
||||
* Finds a single plugin auth entry by userId and authField
|
||||
*/
|
||||
async function findOnePluginAuth({
|
||||
userId,
|
||||
authField,
|
||||
}: FindPluginAuthParams): Promise<IPluginAuth | null> {
|
||||
try {
|
||||
return await PluginAuth.findOne({ userId, authField }).lean();
|
||||
} catch (error) {
|
||||
throw new Error(
|
||||
`Failed to find plugin auth: ${error instanceof Error ? error.message : 'Unknown error'}`,
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Finds multiple plugin auth entries by userId and pluginKeys
|
||||
*/
|
||||
async function findPluginAuthsByKeys({
|
||||
userId,
|
||||
pluginKeys,
|
||||
}: FindPluginAuthsByKeysParams): Promise<IPluginAuth[]> {
|
||||
try {
|
||||
if (!pluginKeys || pluginKeys.length === 0) {
|
||||
return [];
|
||||
}
|
||||
|
||||
return await PluginAuth.find({
|
||||
userId,
|
||||
pluginKey: { $in: pluginKeys },
|
||||
}).lean();
|
||||
} catch (error) {
|
||||
throw new Error(
|
||||
`Failed to find plugin auths: ${error instanceof Error ? error.message : 'Unknown error'}`,
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Updates or creates a plugin auth entry
|
||||
*/
|
||||
async function updatePluginAuth({
|
||||
userId,
|
||||
authField,
|
||||
pluginKey,
|
||||
value,
|
||||
}: UpdatePluginAuthParams): Promise<IPluginAuth> {
|
||||
try {
|
||||
const existingAuth = await PluginAuth.findOne({ userId, pluginKey, authField }).lean();
|
||||
|
||||
if (existingAuth) {
|
||||
return await PluginAuth.findOneAndUpdate(
|
||||
{ userId, pluginKey, authField },
|
||||
{ $set: { value } },
|
||||
{ new: true, upsert: true },
|
||||
).lean();
|
||||
} else {
|
||||
const newPluginAuth = await new PluginAuth({
|
||||
userId,
|
||||
authField,
|
||||
value,
|
||||
pluginKey,
|
||||
});
|
||||
await newPluginAuth.save();
|
||||
return newPluginAuth.toObject();
|
||||
}
|
||||
} catch (error) {
|
||||
throw new Error(
|
||||
`Failed to update plugin auth: ${error instanceof Error ? error.message : 'Unknown error'}`,
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Deletes plugin auth entries based on provided parameters
|
||||
*/
|
||||
async function deletePluginAuth({
|
||||
userId,
|
||||
authField,
|
||||
pluginKey,
|
||||
all = false,
|
||||
}: DeletePluginAuthParams): Promise<DeleteResult> {
|
||||
try {
|
||||
if (all) {
|
||||
const filter: DeletePluginAuthParams = { userId };
|
||||
if (pluginKey) {
|
||||
filter.pluginKey = pluginKey;
|
||||
}
|
||||
return await PluginAuth.deleteMany(filter);
|
||||
}
|
||||
|
||||
if (!authField) {
|
||||
throw new Error('authField is required when all is false');
|
||||
}
|
||||
|
||||
return await PluginAuth.deleteOne({ userId, authField });
|
||||
} catch (error) {
|
||||
throw new Error(
|
||||
`Failed to delete plugin auth: ${error instanceof Error ? error.message : 'Unknown error'}`,
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Deletes all plugin auth entries for a user
|
||||
*/
|
||||
async function deleteAllUserPluginAuths(userId: string): Promise<DeleteResult> {
|
||||
try {
|
||||
return await PluginAuth.deleteMany({ userId });
|
||||
} catch (error) {
|
||||
throw new Error(
|
||||
`Failed to delete all user plugin auths: ${error instanceof Error ? error.message : 'Unknown error'}`,
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
return {
|
||||
findOnePluginAuth,
|
||||
findPluginAuthsByKeys,
|
||||
updatePluginAuth,
|
||||
deletePluginAuth,
|
||||
deleteAllUserPluginAuths,
|
||||
};
|
||||
}
|
||||
|
||||
export type PluginAuthMethods = ReturnType<typeof createPluginAuthMethods>;
|
||||
|
|
@ -13,7 +13,9 @@ export class SessionError extends Error {
|
|||
}
|
||||
|
||||
const { REFRESH_TOKEN_EXPIRY } = process.env ?? {};
|
||||
const expires = eval(REFRESH_TOKEN_EXPIRY ?? '0') ?? 1000 * 60 * 60 * 24 * 7; // 7 days default
|
||||
const expires = REFRESH_TOKEN_EXPIRY
|
||||
? eval(REFRESH_TOKEN_EXPIRY)
|
||||
: 1000 * 60 * 60 * 24 * 7; // 7 days default
|
||||
|
||||
// Factory function that takes mongoose instance and returns the methods
|
||||
export function createSessionMethods(mongoose: typeof import('mongoose')) {
|
||||
|
|
|
|||
1043
packages/data-schemas/src/methods/share.test.ts
Normal file
1043
packages/data-schemas/src/methods/share.test.ts
Normal file
File diff suppressed because it is too large
Load diff
442
packages/data-schemas/src/methods/share.ts
Normal file
442
packages/data-schemas/src/methods/share.ts
Normal file
|
|
@ -0,0 +1,442 @@
|
|||
import { nanoid } from 'nanoid';
|
||||
import { Constants } from 'librechat-data-provider';
|
||||
import type { FilterQuery, Model } from 'mongoose';
|
||||
import type { SchemaWithMeiliMethods } from '~/models/plugins/mongoMeili';
|
||||
import type * as t from '~/types';
|
||||
import logger from '~/config/winston';
|
||||
|
||||
class ShareServiceError extends Error {
|
||||
code: string;
|
||||
constructor(message: string, code: string) {
|
||||
super(message);
|
||||
this.name = 'ShareServiceError';
|
||||
this.code = code;
|
||||
}
|
||||
}
|
||||
|
||||
function memoizedAnonymizeId(prefix: string) {
|
||||
const memo = new Map<string, string>();
|
||||
return (id: string) => {
|
||||
if (!memo.has(id)) {
|
||||
memo.set(id, `${prefix}_${nanoid()}`);
|
||||
}
|
||||
return memo.get(id) as string;
|
||||
};
|
||||
}
|
||||
|
||||
const anonymizeConvoId = memoizedAnonymizeId('convo');
|
||||
const anonymizeAssistantId = memoizedAnonymizeId('a');
|
||||
const anonymizeMessageId = (id: string) =>
|
||||
id === Constants.NO_PARENT ? id : memoizedAnonymizeId('msg')(id);
|
||||
|
||||
function anonymizeConvo(conversation: Partial<t.IConversation> & Partial<t.ISharedLink>) {
|
||||
if (!conversation) {
|
||||
return null;
|
||||
}
|
||||
|
||||
const newConvo = { ...conversation };
|
||||
if (newConvo.assistant_id) {
|
||||
newConvo.assistant_id = anonymizeAssistantId(newConvo.assistant_id);
|
||||
}
|
||||
return newConvo;
|
||||
}
|
||||
|
||||
function anonymizeMessages(messages: t.IMessage[], newConvoId: string): t.IMessage[] {
|
||||
if (!Array.isArray(messages)) {
|
||||
return [];
|
||||
}
|
||||
|
||||
const idMap = new Map<string, string>();
|
||||
return messages.map((message) => {
|
||||
const newMessageId = anonymizeMessageId(message.messageId);
|
||||
idMap.set(message.messageId, newMessageId);
|
||||
|
||||
type MessageAttachment = {
|
||||
messageId?: string;
|
||||
conversationId?: string;
|
||||
[key: string]: unknown;
|
||||
};
|
||||
|
||||
const anonymizedAttachments = (message.attachments as MessageAttachment[])?.map(
|
||||
(attachment) => {
|
||||
return {
|
||||
...attachment,
|
||||
messageId: newMessageId,
|
||||
conversationId: newConvoId,
|
||||
};
|
||||
},
|
||||
);
|
||||
|
||||
return {
|
||||
...message,
|
||||
messageId: newMessageId,
|
||||
parentMessageId:
|
||||
idMap.get(message.parentMessageId || '') ||
|
||||
anonymizeMessageId(message.parentMessageId || ''),
|
||||
conversationId: newConvoId,
|
||||
model: message.model?.startsWith('asst_')
|
||||
? anonymizeAssistantId(message.model)
|
||||
: message.model,
|
||||
attachments: anonymizedAttachments,
|
||||
} as t.IMessage;
|
||||
});
|
||||
}
|
||||
|
||||
/** Factory function that takes mongoose instance and returns the methods */
|
||||
export function createShareMethods(mongoose: typeof import('mongoose')) {
|
||||
/**
|
||||
* Get shared messages for a public share link
|
||||
*/
|
||||
async function getSharedMessages(shareId: string): Promise<t.SharedMessagesResult | null> {
|
||||
try {
|
||||
const SharedLink = mongoose.models.SharedLink as Model<t.ISharedLink>;
|
||||
const share = (await SharedLink.findOne({ shareId, isPublic: true })
|
||||
.populate({
|
||||
path: 'messages',
|
||||
select: '-_id -__v -user',
|
||||
})
|
||||
.select('-_id -__v -user')
|
||||
.lean()) as (t.ISharedLink & { messages: t.IMessage[] }) | null;
|
||||
|
||||
if (!share?.conversationId || !share.isPublic) {
|
||||
return null;
|
||||
}
|
||||
|
||||
const newConvoId = anonymizeConvoId(share.conversationId);
|
||||
const result: t.SharedMessagesResult = {
|
||||
shareId: share.shareId || shareId,
|
||||
title: share.title,
|
||||
isPublic: share.isPublic,
|
||||
createdAt: share.createdAt,
|
||||
updatedAt: share.updatedAt,
|
||||
conversationId: newConvoId,
|
||||
messages: anonymizeMessages(share.messages, newConvoId),
|
||||
};
|
||||
|
||||
return result;
|
||||
} catch (error) {
|
||||
logger.error('[getSharedMessages] Error getting share link', {
|
||||
error: error instanceof Error ? error.message : 'Unknown error',
|
||||
shareId,
|
||||
});
|
||||
throw new ShareServiceError('Error getting share link', 'SHARE_FETCH_ERROR');
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Get shared links for a specific user with pagination and search
|
||||
*/
|
||||
async function getSharedLinks(
|
||||
user: string,
|
||||
pageParam?: Date,
|
||||
pageSize: number = 10,
|
||||
isPublic: boolean = true,
|
||||
sortBy: string = 'createdAt',
|
||||
sortDirection: string = 'desc',
|
||||
search?: string,
|
||||
): Promise<t.SharedLinksResult> {
|
||||
try {
|
||||
const SharedLink = mongoose.models.SharedLink as Model<t.ISharedLink>;
|
||||
const Conversation = mongoose.models.Conversation as SchemaWithMeiliMethods;
|
||||
const query: FilterQuery<t.ISharedLink> = { user, isPublic };
|
||||
|
||||
if (pageParam) {
|
||||
if (sortDirection === 'desc') {
|
||||
query[sortBy] = { $lt: pageParam };
|
||||
} else {
|
||||
query[sortBy] = { $gt: pageParam };
|
||||
}
|
||||
}
|
||||
|
||||
if (search && search.trim()) {
|
||||
try {
|
||||
const searchResults = await Conversation.meiliSearch(search);
|
||||
|
||||
if (!searchResults?.hits?.length) {
|
||||
return {
|
||||
links: [],
|
||||
nextCursor: undefined,
|
||||
hasNextPage: false,
|
||||
};
|
||||
}
|
||||
|
||||
const conversationIds = searchResults.hits.map((hit) => hit.conversationId);
|
||||
query['conversationId'] = { $in: conversationIds };
|
||||
} catch (searchError) {
|
||||
logger.error('[getSharedLinks] Meilisearch error', {
|
||||
error: searchError instanceof Error ? searchError.message : 'Unknown error',
|
||||
user,
|
||||
});
|
||||
return {
|
||||
links: [],
|
||||
nextCursor: undefined,
|
||||
hasNextPage: false,
|
||||
};
|
||||
}
|
||||
}
|
||||
|
||||
const sort: Record<string, 1 | -1> = {};
|
||||
sort[sortBy] = sortDirection === 'desc' ? -1 : 1;
|
||||
|
||||
const sharedLinks = await SharedLink.find(query)
|
||||
.sort(sort)
|
||||
.limit(pageSize + 1)
|
||||
.select('-__v -user')
|
||||
.lean();
|
||||
|
||||
const hasNextPage = sharedLinks.length > pageSize;
|
||||
const links = sharedLinks.slice(0, pageSize);
|
||||
|
||||
const nextCursor = hasNextPage
|
||||
? (links[links.length - 1][sortBy as keyof t.ISharedLink] as Date)
|
||||
: undefined;
|
||||
|
||||
return {
|
||||
links: links.map((link) => ({
|
||||
shareId: link.shareId || '',
|
||||
title: link?.title || 'Untitled',
|
||||
isPublic: link.isPublic,
|
||||
createdAt: link.createdAt || new Date(),
|
||||
conversationId: link.conversationId,
|
||||
})),
|
||||
nextCursor,
|
||||
hasNextPage,
|
||||
};
|
||||
} catch (error) {
|
||||
logger.error('[getSharedLinks] Error getting shares', {
|
||||
error: error instanceof Error ? error.message : 'Unknown error',
|
||||
user,
|
||||
});
|
||||
throw new ShareServiceError('Error getting shares', 'SHARES_FETCH_ERROR');
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Delete all shared links for a user
|
||||
*/
|
||||
async function deleteAllSharedLinks(user: string): Promise<t.DeleteAllSharesResult> {
|
||||
try {
|
||||
const SharedLink = mongoose.models.SharedLink as Model<t.ISharedLink>;
|
||||
const result = await SharedLink.deleteMany({ user });
|
||||
return {
|
||||
message: 'All shared links deleted successfully',
|
||||
deletedCount: result.deletedCount,
|
||||
};
|
||||
} catch (error) {
|
||||
logger.error('[deleteAllSharedLinks] Error deleting shared links', {
|
||||
error: error instanceof Error ? error.message : 'Unknown error',
|
||||
user,
|
||||
});
|
||||
throw new ShareServiceError('Error deleting shared links', 'BULK_DELETE_ERROR');
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Create a new shared link for a conversation
|
||||
*/
|
||||
async function createSharedLink(
|
||||
user: string,
|
||||
conversationId: string,
|
||||
): Promise<t.CreateShareResult> {
|
||||
if (!user || !conversationId) {
|
||||
throw new ShareServiceError('Missing required parameters', 'INVALID_PARAMS');
|
||||
}
|
||||
try {
|
||||
const Message = mongoose.models.Message as SchemaWithMeiliMethods;
|
||||
const SharedLink = mongoose.models.SharedLink as Model<t.ISharedLink>;
|
||||
const Conversation = mongoose.models.Conversation as SchemaWithMeiliMethods;
|
||||
|
||||
const [existingShare, conversationMessages] = await Promise.all([
|
||||
SharedLink.findOne({ conversationId, user, isPublic: true })
|
||||
.select('-_id -__v -user')
|
||||
.lean() as Promise<t.ISharedLink | null>,
|
||||
Message.find({ conversationId, user }).sort({ createdAt: 1 }).lean(),
|
||||
]);
|
||||
|
||||
if (existingShare && existingShare.isPublic) {
|
||||
logger.error('[createSharedLink] Share already exists', {
|
||||
user,
|
||||
conversationId,
|
||||
});
|
||||
throw new ShareServiceError('Share already exists', 'SHARE_EXISTS');
|
||||
} else if (existingShare) {
|
||||
await SharedLink.deleteOne({ conversationId, user });
|
||||
}
|
||||
|
||||
const conversation = (await Conversation.findOne({ conversationId, user }).lean()) as {
|
||||
title?: string;
|
||||
} | null;
|
||||
|
||||
// Check if user owns the conversation
|
||||
if (!conversation) {
|
||||
throw new ShareServiceError(
|
||||
'Conversation not found or access denied',
|
||||
'CONVERSATION_NOT_FOUND',
|
||||
);
|
||||
}
|
||||
|
||||
// Check if there are any messages to share
|
||||
if (!conversationMessages || conversationMessages.length === 0) {
|
||||
throw new ShareServiceError('No messages to share', 'NO_MESSAGES');
|
||||
}
|
||||
|
||||
const title = conversation.title || 'Untitled';
|
||||
|
||||
const shareId = nanoid();
|
||||
await SharedLink.create({
|
||||
shareId,
|
||||
conversationId,
|
||||
messages: conversationMessages,
|
||||
title,
|
||||
user,
|
||||
});
|
||||
|
||||
return { shareId, conversationId };
|
||||
} catch (error) {
|
||||
if (error instanceof ShareServiceError) {
|
||||
throw error;
|
||||
}
|
||||
logger.error('[createSharedLink] Error creating shared link', {
|
||||
error: error instanceof Error ? error.message : 'Unknown error',
|
||||
user,
|
||||
conversationId,
|
||||
});
|
||||
throw new ShareServiceError('Error creating shared link', 'SHARE_CREATE_ERROR');
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Get a shared link for a conversation
|
||||
*/
|
||||
async function getSharedLink(
|
||||
user: string,
|
||||
conversationId: string,
|
||||
): Promise<t.GetShareLinkResult> {
|
||||
if (!user || !conversationId) {
|
||||
throw new ShareServiceError('Missing required parameters', 'INVALID_PARAMS');
|
||||
}
|
||||
|
||||
try {
|
||||
const SharedLink = mongoose.models.SharedLink as Model<t.ISharedLink>;
|
||||
const share = (await SharedLink.findOne({ conversationId, user, isPublic: true })
|
||||
.select('shareId -_id')
|
||||
.lean()) as { shareId?: string } | null;
|
||||
|
||||
if (!share) {
|
||||
return { shareId: null, success: false };
|
||||
}
|
||||
|
||||
return { shareId: share.shareId || null, success: true };
|
||||
} catch (error) {
|
||||
logger.error('[getSharedLink] Error getting shared link', {
|
||||
error: error instanceof Error ? error.message : 'Unknown error',
|
||||
user,
|
||||
conversationId,
|
||||
});
|
||||
throw new ShareServiceError('Error getting shared link', 'SHARE_FETCH_ERROR');
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Update a shared link with new messages
|
||||
*/
|
||||
async function updateSharedLink(user: string, shareId: string): Promise<t.UpdateShareResult> {
|
||||
if (!user || !shareId) {
|
||||
throw new ShareServiceError('Missing required parameters', 'INVALID_PARAMS');
|
||||
}
|
||||
|
||||
try {
|
||||
const SharedLink = mongoose.models.SharedLink as Model<t.ISharedLink>;
|
||||
const Message = mongoose.models.Message as SchemaWithMeiliMethods;
|
||||
const share = (await SharedLink.findOne({ shareId, user })
|
||||
.select('-_id -__v -user')
|
||||
.lean()) as t.ISharedLink | null;
|
||||
|
||||
if (!share) {
|
||||
throw new ShareServiceError('Share not found', 'SHARE_NOT_FOUND');
|
||||
}
|
||||
|
||||
const updatedMessages = await Message.find({ conversationId: share.conversationId, user })
|
||||
.sort({ createdAt: 1 })
|
||||
.lean();
|
||||
|
||||
const newShareId = nanoid();
|
||||
const update = {
|
||||
messages: updatedMessages,
|
||||
user,
|
||||
shareId: newShareId,
|
||||
};
|
||||
|
||||
const updatedShare = (await SharedLink.findOneAndUpdate({ shareId, user }, update, {
|
||||
new: true,
|
||||
upsert: false,
|
||||
runValidators: true,
|
||||
}).lean()) as t.ISharedLink | null;
|
||||
|
||||
if (!updatedShare) {
|
||||
throw new ShareServiceError('Share update failed', 'SHARE_UPDATE_ERROR');
|
||||
}
|
||||
|
||||
anonymizeConvo(updatedShare);
|
||||
|
||||
return { shareId: newShareId, conversationId: updatedShare.conversationId };
|
||||
} catch (error) {
|
||||
logger.error('[updateSharedLink] Error updating shared link', {
|
||||
error: error instanceof Error ? error.message : 'Unknown error',
|
||||
user,
|
||||
shareId,
|
||||
});
|
||||
throw new ShareServiceError(
|
||||
error instanceof ShareServiceError ? error.message : 'Error updating shared link',
|
||||
error instanceof ShareServiceError ? error.code : 'SHARE_UPDATE_ERROR',
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Delete a shared link
|
||||
*/
|
||||
async function deleteSharedLink(
|
||||
user: string,
|
||||
shareId: string,
|
||||
): Promise<t.DeleteShareResult | null> {
|
||||
if (!user || !shareId) {
|
||||
throw new ShareServiceError('Missing required parameters', 'INVALID_PARAMS');
|
||||
}
|
||||
|
||||
try {
|
||||
const SharedLink = mongoose.models.SharedLink as Model<t.ISharedLink>;
|
||||
const result = await SharedLink.findOneAndDelete({ shareId, user }).lean();
|
||||
|
||||
if (!result) {
|
||||
return null;
|
||||
}
|
||||
|
||||
return {
|
||||
success: true,
|
||||
shareId,
|
||||
message: 'Share deleted successfully',
|
||||
};
|
||||
} catch (error) {
|
||||
logger.error('[deleteSharedLink] Error deleting shared link', {
|
||||
error: error instanceof Error ? error.message : 'Unknown error',
|
||||
user,
|
||||
shareId,
|
||||
});
|
||||
throw new ShareServiceError('Error deleting shared link', 'SHARE_DELETE_ERROR');
|
||||
}
|
||||
}
|
||||
|
||||
// Return all methods
|
||||
return {
|
||||
getSharedLink,
|
||||
getSharedLinks,
|
||||
createSharedLink,
|
||||
updateSharedLink,
|
||||
deleteSharedLink,
|
||||
getSharedMessages,
|
||||
deleteAllSharedLinks,
|
||||
};
|
||||
}
|
||||
|
||||
export type ShareMethods = ReturnType<typeof createShareMethods>;
|
||||
163
packages/data-schemas/src/methods/user.test.ts
Normal file
163
packages/data-schemas/src/methods/user.test.ts
Normal file
|
|
@ -0,0 +1,163 @@
|
|||
import mongoose from 'mongoose';
|
||||
import { createUserMethods } from './user';
|
||||
import { signPayload } from '~/crypto';
|
||||
import type { IUser } from '~/types';
|
||||
|
||||
jest.mock('~/crypto', () => ({
|
||||
signPayload: jest.fn(),
|
||||
}));
|
||||
|
||||
describe('User Methods', () => {
|
||||
const mockSignPayload = signPayload as jest.MockedFunction<typeof signPayload>;
|
||||
let userMethods: ReturnType<typeof createUserMethods>;
|
||||
|
||||
beforeEach(() => {
|
||||
jest.clearAllMocks();
|
||||
userMethods = createUserMethods(mongoose);
|
||||
});
|
||||
|
||||
describe('generateToken', () => {
|
||||
const mockUser = {
|
||||
_id: 'user123',
|
||||
username: 'testuser',
|
||||
provider: 'local',
|
||||
email: 'test@example.com',
|
||||
name: 'Test User',
|
||||
avatar: '',
|
||||
role: 'user',
|
||||
emailVerified: false,
|
||||
createdAt: new Date(),
|
||||
updatedAt: new Date(),
|
||||
} as IUser;
|
||||
|
||||
afterEach(() => {
|
||||
delete process.env.SESSION_EXPIRY;
|
||||
delete process.env.JWT_SECRET;
|
||||
});
|
||||
|
||||
it('should default to 15 minutes when SESSION_EXPIRY is not set', async () => {
|
||||
process.env.JWT_SECRET = 'test-secret';
|
||||
mockSignPayload.mockResolvedValue('mocked-token');
|
||||
|
||||
await userMethods.generateToken(mockUser);
|
||||
|
||||
expect(mockSignPayload).toHaveBeenCalledWith({
|
||||
payload: {
|
||||
id: mockUser._id,
|
||||
username: mockUser.username,
|
||||
provider: mockUser.provider,
|
||||
email: mockUser.email,
|
||||
},
|
||||
secret: 'test-secret',
|
||||
expirationTime: 900, // 15 minutes in seconds
|
||||
});
|
||||
});
|
||||
|
||||
it('should default to 15 minutes when SESSION_EXPIRY is empty string', async () => {
|
||||
process.env.SESSION_EXPIRY = '';
|
||||
process.env.JWT_SECRET = 'test-secret';
|
||||
mockSignPayload.mockResolvedValue('mocked-token');
|
||||
|
||||
await userMethods.generateToken(mockUser);
|
||||
|
||||
expect(mockSignPayload).toHaveBeenCalledWith({
|
||||
payload: {
|
||||
id: mockUser._id,
|
||||
username: mockUser.username,
|
||||
provider: mockUser.provider,
|
||||
email: mockUser.email,
|
||||
},
|
||||
secret: 'test-secret',
|
||||
expirationTime: 900, // 15 minutes in seconds
|
||||
});
|
||||
});
|
||||
|
||||
it('should use custom expiry when SESSION_EXPIRY is set to a valid expression', async () => {
|
||||
process.env.SESSION_EXPIRY = '1000 * 60 * 30'; // 30 minutes
|
||||
process.env.JWT_SECRET = 'test-secret';
|
||||
mockSignPayload.mockResolvedValue('mocked-token');
|
||||
|
||||
await userMethods.generateToken(mockUser);
|
||||
|
||||
expect(mockSignPayload).toHaveBeenCalledWith({
|
||||
payload: {
|
||||
id: mockUser._id,
|
||||
username: mockUser.username,
|
||||
provider: mockUser.provider,
|
||||
email: mockUser.email,
|
||||
},
|
||||
secret: 'test-secret',
|
||||
expirationTime: 1800, // 30 minutes in seconds
|
||||
});
|
||||
});
|
||||
|
||||
it('should default to 15 minutes when SESSION_EXPIRY evaluates to falsy value', async () => {
|
||||
process.env.SESSION_EXPIRY = '0'; // This will evaluate to 0, which is falsy
|
||||
process.env.JWT_SECRET = 'test-secret';
|
||||
mockSignPayload.mockResolvedValue('mocked-token');
|
||||
|
||||
await userMethods.generateToken(mockUser);
|
||||
|
||||
expect(mockSignPayload).toHaveBeenCalledWith({
|
||||
payload: {
|
||||
id: mockUser._id,
|
||||
username: mockUser.username,
|
||||
provider: mockUser.provider,
|
||||
email: mockUser.email,
|
||||
},
|
||||
secret: 'test-secret',
|
||||
expirationTime: 900, // 15 minutes in seconds
|
||||
});
|
||||
});
|
||||
|
||||
it('should throw error when no user is provided', async () => {
|
||||
process.env.JWT_SECRET = 'test-secret';
|
||||
|
||||
await expect(userMethods.generateToken(null as unknown as IUser)).rejects.toThrow(
|
||||
'No user provided',
|
||||
);
|
||||
});
|
||||
|
||||
it('should return the token from signPayload', async () => {
|
||||
process.env.SESSION_EXPIRY = '1000 * 60 * 60'; // 1 hour
|
||||
process.env.JWT_SECRET = 'test-secret';
|
||||
const expectedToken = 'generated-jwt-token';
|
||||
mockSignPayload.mockResolvedValue(expectedToken);
|
||||
|
||||
const token = await userMethods.generateToken(mockUser);
|
||||
|
||||
expect(token).toBe(expectedToken);
|
||||
});
|
||||
|
||||
it('should handle invalid SESSION_EXPIRY expressions gracefully', async () => {
|
||||
process.env.SESSION_EXPIRY = 'invalid expression';
|
||||
process.env.JWT_SECRET = 'test-secret';
|
||||
mockSignPayload.mockResolvedValue('mocked-token');
|
||||
|
||||
// Mock console.warn to verify it's called
|
||||
const consoleWarnSpy = jest.spyOn(console, 'warn').mockImplementation();
|
||||
|
||||
await userMethods.generateToken(mockUser);
|
||||
|
||||
// Should use default value when eval fails
|
||||
expect(mockSignPayload).toHaveBeenCalledWith({
|
||||
payload: {
|
||||
id: mockUser._id,
|
||||
username: mockUser.username,
|
||||
provider: mockUser.provider,
|
||||
email: mockUser.email,
|
||||
},
|
||||
secret: 'test-secret',
|
||||
expirationTime: 900, // 15 minutes in seconds (default)
|
||||
});
|
||||
|
||||
// Verify warning was logged
|
||||
expect(consoleWarnSpy).toHaveBeenCalledWith(
|
||||
'Invalid SESSION_EXPIRY expression, using default:',
|
||||
expect.any(SyntaxError),
|
||||
);
|
||||
|
||||
consoleWarnSpy.mockRestore();
|
||||
});
|
||||
});
|
||||
});
|
||||
|
|
@ -145,7 +145,18 @@ export function createUserMethods(mongoose: typeof import('mongoose')) {
|
|||
throw new Error('No user provided');
|
||||
}
|
||||
|
||||
const expires = eval(process.env.SESSION_EXPIRY ?? '0') ?? 1000 * 60 * 15;
|
||||
let expires = 1000 * 60 * 15;
|
||||
|
||||
if (process.env.SESSION_EXPIRY !== undefined && process.env.SESSION_EXPIRY !== '') {
|
||||
try {
|
||||
const evaluated = eval(process.env.SESSION_EXPIRY);
|
||||
if (evaluated) {
|
||||
expires = evaluated;
|
||||
}
|
||||
} catch (error) {
|
||||
console.warn('Invalid SESSION_EXPIRY expression, using default:', error);
|
||||
}
|
||||
}
|
||||
|
||||
return await signPayload({
|
||||
payload: {
|
||||
|
|
@ -159,6 +170,35 @@ export function createUserMethods(mongoose: typeof import('mongoose')) {
|
|||
});
|
||||
}
|
||||
|
||||
/**
|
||||
* Update a user's personalization memories setting.
|
||||
* Handles the edge case where the personalization object doesn't exist.
|
||||
*/
|
||||
async function toggleUserMemories(
|
||||
userId: string,
|
||||
memoriesEnabled: boolean,
|
||||
): Promise<IUser | null> {
|
||||
const User = mongoose.models.User;
|
||||
|
||||
// First, ensure the personalization object exists
|
||||
const user = await User.findById(userId);
|
||||
if (!user) {
|
||||
return null;
|
||||
}
|
||||
|
||||
// Use $set to update the nested field, which will create the personalization object if it doesn't exist
|
||||
const updateOperation = {
|
||||
$set: {
|
||||
'personalization.memories': memoriesEnabled,
|
||||
},
|
||||
};
|
||||
|
||||
return (await User.findByIdAndUpdate(userId, updateOperation, {
|
||||
new: true,
|
||||
runValidators: true,
|
||||
}).lean()) as IUser | null;
|
||||
}
|
||||
|
||||
// Return all methods
|
||||
return {
|
||||
findUser,
|
||||
|
|
@ -168,6 +208,7 @@ export function createUserMethods(mongoose: typeof import('mongoose')) {
|
|||
getUserById,
|
||||
deleteUserById,
|
||||
generateToken,
|
||||
toggleUserMemories,
|
||||
};
|
||||
}
|
||||
|
||||
|
|
|
|||
Some files were not shown because too many files have changed in this diff Show more
Loading…
Add table
Add a link
Reference in a new issue