mirror of
https://github.com/danny-avila/LibreChat.git
synced 2026-02-19 08:58:09 +01:00
📦 refactor: Consolidate DB models, encapsulating Mongoose usage in data-schemas (#11830)
* chore: move database model methods to /packages/data-schemas * chore: add TypeScript ESLint rule to warn on unused variables * refactor: model imports to streamline access - Consolidated model imports across various files to improve code organization and reduce redundancy. - Updated imports for models such as Assistant, Message, Conversation, and others to a unified import path. - Adjusted middleware and service files to reflect the new import structure, ensuring functionality remains intact. - Enhanced test files to align with the new import paths, maintaining test coverage and integrity. * chore: migrate database models to packages/data-schemas and refactor all direct Mongoose Model usage outside of data-schemas * test: update agent model mocks in unit tests - Added `getAgent` mock to `client.test.js` to enhance test coverage for agent-related functionality. - Removed redundant `getAgent` and `getAgents` mocks from `openai.spec.js` and `responses.unit.spec.js` to streamline test setup and reduce duplication. - Ensured consistency in agent mock implementations across test files. * fix: update types in data-schemas * refactor: enhance type definitions in transaction and spending methods - Updated type definitions in `checkBalance.ts` to use specific request and response types. - Refined `spendTokens.ts` to utilize a new `SpendTxData` interface for better clarity and type safety. - Improved transaction handling in `transaction.ts` by introducing `TransactionResult` and `TxData` interfaces, ensuring consistent data structures across methods. - Adjusted unit tests in `transaction.spec.ts` to accommodate new type definitions and enhance robustness. * refactor: streamline model imports and enhance code organization - Consolidated model imports across various controllers and services to a unified import path, improving code clarity and reducing redundancy. - Updated multiple files to reflect the new import structure, ensuring all functionalities remain intact. - Enhanced overall code organization by removing duplicate import statements and optimizing the usage of model methods. * feat: implement loadAddedAgent and refactor agent loading logic - Introduced `loadAddedAgent` function to handle loading agents from added conversations, supporting multi-convo parallel execution. - Created a new `load.ts` file to encapsulate agent loading functionalities, including `loadEphemeralAgent` and `loadAgent`. - Updated the `index.ts` file to export the new `load` module instead of the deprecated `loadAgent`. - Enhanced type definitions and improved error handling in the agent loading process. - Adjusted unit tests to reflect changes in the agent loading structure and ensure comprehensive coverage. * refactor: enhance balance handling with new update interface - Introduced `IBalanceUpdate` interface to streamline balance update operations across the codebase. - Updated `upsertBalanceFields` method signatures in `balance.ts`, `transaction.ts`, and related tests to utilize the new interface for improved type safety. - Adjusted type imports in `balance.spec.ts` to include `IBalanceUpdate`, ensuring consistency in balance management functionalities. - Enhanced overall code clarity and maintainability by refining type definitions related to balance operations. * feat: add unit tests for loadAgent functionality and enhance agent loading logic - Introduced comprehensive unit tests for the `loadAgent` function, covering various scenarios including null and empty agent IDs, loading of ephemeral agents, and permission checks. - Enhanced the `initializeClient` function by moving `getConvoFiles` to the correct position in the database method exports, ensuring proper functionality. - Improved test coverage for agent loading, including handling of non-existent agents and user permissions. * chore: reorder memory method exports for consistency - Moved `deleteAllUserMemories` to the correct position in the exported memory methods, ensuring a consistent and logical order of method exports in `memory.ts`.
This commit is contained in:
parent
a85e99ff45
commit
a6fb257bcf
182 changed files with 8548 additions and 8105 deletions
397
packages/api/src/agents/__tests__/load.spec.ts
Normal file
397
packages/api/src/agents/__tests__/load.spec.ts
Normal file
|
|
@ -0,0 +1,397 @@
|
|||
import mongoose from 'mongoose';
|
||||
import { v4 as uuidv4 } from 'uuid';
|
||||
import { Constants } from 'librechat-data-provider';
|
||||
import { MongoMemoryServer } from 'mongodb-memory-server';
|
||||
import { agentSchema, createMethods } from '@librechat/data-schemas';
|
||||
import type { AgentModelParameters } from 'librechat-data-provider';
|
||||
import type { LoadAgentParams, LoadAgentDeps } from '../load';
|
||||
import { loadAgent } from '../load';
|
||||
|
||||
let Agent: mongoose.Model<unknown>;
|
||||
let createAgent: ReturnType<typeof createMethods>['createAgent'];
|
||||
let getAgent: ReturnType<typeof createMethods>['getAgent'];
|
||||
|
||||
const mockGetMCPServerTools = jest.fn();
|
||||
|
||||
const deps: LoadAgentDeps = {
|
||||
getAgent: (searchParameter) => getAgent(searchParameter),
|
||||
getMCPServerTools: mockGetMCPServerTools,
|
||||
};
|
||||
|
||||
describe('loadAgent', () => {
|
||||
let mongoServer: MongoMemoryServer;
|
||||
|
||||
beforeAll(async () => {
|
||||
mongoServer = await MongoMemoryServer.create();
|
||||
const mongoUri = mongoServer.getUri();
|
||||
Agent = mongoose.models.Agent || mongoose.model('Agent', agentSchema);
|
||||
await mongoose.connect(mongoUri);
|
||||
const methods = createMethods(mongoose);
|
||||
createAgent = methods.createAgent;
|
||||
getAgent = methods.getAgent;
|
||||
}, 20000);
|
||||
|
||||
afterAll(async () => {
|
||||
await mongoose.disconnect();
|
||||
await mongoServer.stop();
|
||||
});
|
||||
|
||||
beforeEach(async () => {
|
||||
await Agent.deleteMany({});
|
||||
jest.clearAllMocks();
|
||||
});
|
||||
|
||||
test('should return null when agent_id is not provided', async () => {
|
||||
const mockReq = { user: { id: 'user123' } };
|
||||
const result = await loadAgent(
|
||||
{
|
||||
req: mockReq,
|
||||
agent_id: null as unknown as string,
|
||||
endpoint: 'openai',
|
||||
model_parameters: { model: 'gpt-4' } as unknown as AgentModelParameters,
|
||||
},
|
||||
deps,
|
||||
);
|
||||
|
||||
expect(result).toBeNull();
|
||||
});
|
||||
|
||||
test('should return null when agent_id is empty string', async () => {
|
||||
const mockReq = { user: { id: 'user123' } };
|
||||
const result = await loadAgent(
|
||||
{
|
||||
req: mockReq,
|
||||
agent_id: '',
|
||||
endpoint: 'openai',
|
||||
model_parameters: { model: 'gpt-4' } as unknown as AgentModelParameters,
|
||||
},
|
||||
deps,
|
||||
);
|
||||
|
||||
expect(result).toBeNull();
|
||||
});
|
||||
|
||||
test('should test ephemeral agent loading logic', async () => {
|
||||
const { EPHEMERAL_AGENT_ID } = Constants;
|
||||
|
||||
// Mock getMCPServerTools to return tools for each server
|
||||
mockGetMCPServerTools.mockImplementation(async (_userId: string, server: string) => {
|
||||
if (server === 'server1') {
|
||||
return { tool1_mcp_server1: {} };
|
||||
} else if (server === 'server2') {
|
||||
return { tool2_mcp_server2: {} };
|
||||
}
|
||||
return null;
|
||||
});
|
||||
|
||||
const mockReq = {
|
||||
user: { id: 'user123' },
|
||||
body: {
|
||||
promptPrefix: 'Test instructions',
|
||||
ephemeralAgent: {
|
||||
execute_code: true,
|
||||
web_search: true,
|
||||
mcp: ['server1', 'server2'],
|
||||
},
|
||||
},
|
||||
};
|
||||
|
||||
const result = await loadAgent(
|
||||
{
|
||||
req: mockReq,
|
||||
agent_id: EPHEMERAL_AGENT_ID as string,
|
||||
endpoint: 'openai',
|
||||
model_parameters: { model: 'gpt-4', temperature: 0.7 } as unknown as AgentModelParameters,
|
||||
},
|
||||
deps,
|
||||
);
|
||||
|
||||
if (result) {
|
||||
// Ephemeral agent ID is encoded with endpoint and model
|
||||
expect(result.id).toBe('openai__gpt-4');
|
||||
expect(result.instructions).toBe('Test instructions');
|
||||
expect(result.provider).toBe('openai');
|
||||
expect(result.model).toBe('gpt-4');
|
||||
expect(result.model_parameters.temperature).toBe(0.7);
|
||||
expect(result.tools).toContain('execute_code');
|
||||
expect(result.tools).toContain('web_search');
|
||||
expect(result.tools).toContain('tool1_mcp_server1');
|
||||
expect(result.tools).toContain('tool2_mcp_server2');
|
||||
} else {
|
||||
expect(result).toBeNull();
|
||||
}
|
||||
});
|
||||
|
||||
test('should return null for non-existent agent', async () => {
|
||||
const mockReq = { user: { id: 'user123' } };
|
||||
const result = await loadAgent(
|
||||
{
|
||||
req: mockReq,
|
||||
agent_id: 'agent_non_existent',
|
||||
endpoint: 'openai',
|
||||
model_parameters: { model: 'gpt-4' } as unknown as AgentModelParameters,
|
||||
},
|
||||
deps,
|
||||
);
|
||||
|
||||
expect(result).toBeNull();
|
||||
});
|
||||
|
||||
test('should load agent when user is the author', async () => {
|
||||
const userId = new mongoose.Types.ObjectId();
|
||||
const agentId = `agent_${uuidv4()}`;
|
||||
|
||||
await createAgent({
|
||||
id: agentId,
|
||||
name: 'Test Agent',
|
||||
provider: 'openai',
|
||||
model: 'gpt-4',
|
||||
author: userId,
|
||||
description: 'Test description',
|
||||
tools: ['web_search'],
|
||||
});
|
||||
|
||||
const mockReq = { user: { id: userId.toString() } };
|
||||
const result = await loadAgent(
|
||||
{
|
||||
req: mockReq,
|
||||
agent_id: agentId,
|
||||
endpoint: 'openai',
|
||||
model_parameters: { model: 'gpt-4' } as unknown as AgentModelParameters,
|
||||
},
|
||||
deps,
|
||||
);
|
||||
|
||||
expect(result).toBeDefined();
|
||||
expect(result!.id).toBe(agentId);
|
||||
expect(result!.name).toBe('Test Agent');
|
||||
expect(String(result!.author)).toBe(userId.toString());
|
||||
expect(result!.version).toBe(1);
|
||||
});
|
||||
|
||||
test('should return agent even when user is not author (permissions checked at route level)', async () => {
|
||||
const authorId = new mongoose.Types.ObjectId();
|
||||
const userId = new mongoose.Types.ObjectId();
|
||||
const agentId = `agent_${uuidv4()}`;
|
||||
|
||||
await createAgent({
|
||||
id: agentId,
|
||||
name: 'Test Agent',
|
||||
provider: 'openai',
|
||||
model: 'gpt-4',
|
||||
author: authorId,
|
||||
});
|
||||
|
||||
const mockReq = { user: { id: userId.toString() } };
|
||||
const result = await loadAgent(
|
||||
{
|
||||
req: mockReq,
|
||||
agent_id: agentId,
|
||||
endpoint: 'openai',
|
||||
model_parameters: { model: 'gpt-4' } as unknown as AgentModelParameters,
|
||||
},
|
||||
deps,
|
||||
);
|
||||
|
||||
// With the new permission system, loadAgent returns the agent regardless of permissions
|
||||
// Permission checks are handled at the route level via middleware
|
||||
expect(result).toBeTruthy();
|
||||
expect(result!.id).toBe(agentId);
|
||||
expect(result!.name).toBe('Test Agent');
|
||||
});
|
||||
|
||||
test('should handle ephemeral agent with no MCP servers', async () => {
|
||||
const { EPHEMERAL_AGENT_ID } = Constants;
|
||||
|
||||
const mockReq = {
|
||||
user: { id: 'user123' },
|
||||
body: {
|
||||
promptPrefix: 'Simple instructions',
|
||||
ephemeralAgent: {
|
||||
execute_code: false,
|
||||
web_search: false,
|
||||
mcp: [],
|
||||
},
|
||||
},
|
||||
};
|
||||
|
||||
const result = await loadAgent(
|
||||
{
|
||||
req: mockReq,
|
||||
agent_id: EPHEMERAL_AGENT_ID as string,
|
||||
endpoint: 'openai',
|
||||
model_parameters: { model: 'gpt-3.5-turbo' } as unknown as AgentModelParameters,
|
||||
},
|
||||
deps,
|
||||
);
|
||||
|
||||
if (result) {
|
||||
expect(result.tools).toEqual([]);
|
||||
expect(result.instructions).toBe('Simple instructions');
|
||||
} else {
|
||||
expect(result).toBeFalsy();
|
||||
}
|
||||
});
|
||||
|
||||
test('should handle ephemeral agent with undefined ephemeralAgent in body', async () => {
|
||||
const { EPHEMERAL_AGENT_ID } = Constants;
|
||||
|
||||
const mockReq = {
|
||||
user: { id: 'user123' },
|
||||
body: {
|
||||
promptPrefix: 'Basic instructions',
|
||||
},
|
||||
};
|
||||
|
||||
const result = await loadAgent(
|
||||
{
|
||||
req: mockReq,
|
||||
agent_id: EPHEMERAL_AGENT_ID as string,
|
||||
endpoint: 'openai',
|
||||
model_parameters: { model: 'gpt-4' } as unknown as AgentModelParameters,
|
||||
},
|
||||
deps,
|
||||
);
|
||||
|
||||
if (result) {
|
||||
expect(result.tools).toEqual([]);
|
||||
} else {
|
||||
expect(result).toBeFalsy();
|
||||
}
|
||||
});
|
||||
|
||||
describe('Edge Cases', () => {
|
||||
test('should handle loadAgent with malformed req object', async () => {
|
||||
const result = await loadAgent(
|
||||
{
|
||||
req: null as unknown as LoadAgentParams['req'],
|
||||
agent_id: 'agent_test',
|
||||
endpoint: 'openai',
|
||||
model_parameters: { model: 'gpt-4' } as unknown as AgentModelParameters,
|
||||
},
|
||||
deps,
|
||||
);
|
||||
|
||||
expect(result).toBeNull();
|
||||
});
|
||||
|
||||
test('should handle ephemeral agent with extremely large tool list', async () => {
|
||||
const { EPHEMERAL_AGENT_ID } = Constants;
|
||||
|
||||
const largeToolList = Array.from({ length: 100 }, (_, i) => `tool_${i}_mcp_server1`);
|
||||
const availableTools: Record<string, object> = {};
|
||||
for (const tool of largeToolList) {
|
||||
availableTools[tool] = {};
|
||||
}
|
||||
|
||||
// Mock getMCPServerTools to return all tools for server1
|
||||
mockGetMCPServerTools.mockImplementation(async (_userId: string, server: string) => {
|
||||
if (server === 'server1') {
|
||||
return availableTools; // All 100 tools belong to server1
|
||||
}
|
||||
return null;
|
||||
});
|
||||
|
||||
const mockReq = {
|
||||
user: { id: 'user123' },
|
||||
body: {
|
||||
promptPrefix: 'Test',
|
||||
ephemeralAgent: {
|
||||
execute_code: true,
|
||||
web_search: true,
|
||||
mcp: ['server1'],
|
||||
},
|
||||
},
|
||||
};
|
||||
|
||||
const result = await loadAgent(
|
||||
{
|
||||
req: mockReq,
|
||||
agent_id: EPHEMERAL_AGENT_ID as string,
|
||||
endpoint: 'openai',
|
||||
model_parameters: { model: 'gpt-4' } as unknown as AgentModelParameters,
|
||||
},
|
||||
deps,
|
||||
);
|
||||
|
||||
if (result) {
|
||||
expect(result.tools!.length).toBeGreaterThan(100);
|
||||
}
|
||||
});
|
||||
|
||||
test('should return agent from different project (permissions checked at route level)', async () => {
|
||||
const authorId = new mongoose.Types.ObjectId();
|
||||
const userId = new mongoose.Types.ObjectId();
|
||||
const agentId = `agent_${uuidv4()}`;
|
||||
|
||||
await createAgent({
|
||||
id: agentId,
|
||||
name: 'Project Agent',
|
||||
provider: 'openai',
|
||||
model: 'gpt-4',
|
||||
author: authorId,
|
||||
});
|
||||
|
||||
const mockReq = { user: { id: userId.toString() } };
|
||||
const result = await loadAgent(
|
||||
{
|
||||
req: mockReq,
|
||||
agent_id: agentId,
|
||||
endpoint: 'openai',
|
||||
model_parameters: { model: 'gpt-4' } as unknown as AgentModelParameters,
|
||||
},
|
||||
deps,
|
||||
);
|
||||
|
||||
// With the new permission system, loadAgent returns the agent regardless of permissions
|
||||
// Permission checks are handled at the route level via middleware
|
||||
expect(result).toBeTruthy();
|
||||
expect(result!.id).toBe(agentId);
|
||||
expect(result!.name).toBe('Project Agent');
|
||||
});
|
||||
|
||||
test('should handle loadEphemeralAgent with malformed MCP tool names', async () => {
|
||||
const { EPHEMERAL_AGENT_ID } = Constants;
|
||||
|
||||
// Mock getMCPServerTools to return only tools matching the server
|
||||
mockGetMCPServerTools.mockImplementation(async (_userId: string, server: string) => {
|
||||
if (server === 'server1') {
|
||||
// Only return tool that correctly matches server1 format
|
||||
return { tool_mcp_server1: {} };
|
||||
} else if (server === 'server2') {
|
||||
return { tool_mcp_server2: {} };
|
||||
}
|
||||
return null;
|
||||
});
|
||||
|
||||
const mockReq = {
|
||||
user: { id: 'user123' },
|
||||
body: {
|
||||
promptPrefix: 'Test instructions',
|
||||
ephemeralAgent: {
|
||||
execute_code: false,
|
||||
web_search: false,
|
||||
mcp: ['server1'],
|
||||
},
|
||||
},
|
||||
};
|
||||
|
||||
const result = await loadAgent(
|
||||
{
|
||||
req: mockReq,
|
||||
agent_id: EPHEMERAL_AGENT_ID as string,
|
||||
endpoint: 'openai',
|
||||
model_parameters: { model: 'gpt-4' } as unknown as AgentModelParameters,
|
||||
},
|
||||
deps,
|
||||
);
|
||||
|
||||
if (result) {
|
||||
expect(result.tools).toEqual(['tool_mcp_server1']);
|
||||
expect(result.tools).not.toContain('malformed_tool_name');
|
||||
expect(result.tools).not.toContain('tool__server1');
|
||||
expect(result.tools).not.toContain('tool_mcp_server2');
|
||||
}
|
||||
});
|
||||
});
|
||||
});
|
||||
226
packages/api/src/agents/added.ts
Normal file
226
packages/api/src/agents/added.ts
Normal file
|
|
@ -0,0 +1,226 @@
|
|||
import { logger } from '@librechat/data-schemas';
|
||||
import type { AppConfig } from '@librechat/data-schemas';
|
||||
import {
|
||||
Tools,
|
||||
Constants,
|
||||
isAgentsEndpoint,
|
||||
isEphemeralAgentId,
|
||||
appendAgentIdSuffix,
|
||||
encodeEphemeralAgentId,
|
||||
} from 'librechat-data-provider';
|
||||
import type { Agent, TConversation } from 'librechat-data-provider';
|
||||
import { getCustomEndpointConfig } from '~/app/config';
|
||||
|
||||
const { mcp_all, mcp_delimiter } = Constants;
|
||||
|
||||
export const ADDED_AGENT_ID = 'added_agent';
|
||||
|
||||
export interface LoadAddedAgentDeps {
|
||||
getAgent: (searchParameter: { id: string }) => Promise<Agent | null>;
|
||||
getMCPServerTools: (
|
||||
userId: string,
|
||||
serverName: string,
|
||||
) => Promise<Record<string, unknown> | null>;
|
||||
}
|
||||
|
||||
interface LoadAddedAgentParams {
|
||||
req: { user?: { id?: string }; config?: Record<string, unknown> };
|
||||
conversation: TConversation | null;
|
||||
primaryAgent?: Agent | null;
|
||||
}
|
||||
|
||||
/**
|
||||
* Loads an agent from an added conversation (for multi-convo parallel agent execution).
|
||||
* Returns the agent config as a plain object, or null if invalid.
|
||||
*/
|
||||
export async function loadAddedAgent(
|
||||
{ req, conversation, primaryAgent }: LoadAddedAgentParams,
|
||||
deps: LoadAddedAgentDeps,
|
||||
): Promise<Agent | null> {
|
||||
if (!conversation) {
|
||||
return null;
|
||||
}
|
||||
|
||||
if (conversation.agent_id && !isEphemeralAgentId(conversation.agent_id)) {
|
||||
const agent = await deps.getAgent({ id: conversation.agent_id });
|
||||
if (!agent) {
|
||||
logger.warn(`[loadAddedAgent] Agent ${conversation.agent_id} not found`);
|
||||
return null;
|
||||
}
|
||||
|
||||
const agentRecord = agent as Record<string, unknown>;
|
||||
const versions = agentRecord.versions as unknown[] | undefined;
|
||||
agentRecord.version = versions ? versions.length : 0;
|
||||
agent.id = appendAgentIdSuffix(agent.id, 1);
|
||||
return agent;
|
||||
}
|
||||
|
||||
const { model, endpoint, promptPrefix, spec, ...rest } = conversation as TConversation & {
|
||||
promptPrefix?: string;
|
||||
spec?: string;
|
||||
modelLabel?: string;
|
||||
ephemeralAgent?: {
|
||||
mcp?: string[];
|
||||
execute_code?: boolean;
|
||||
file_search?: boolean;
|
||||
web_search?: boolean;
|
||||
artifacts?: unknown;
|
||||
};
|
||||
[key: string]: unknown;
|
||||
};
|
||||
|
||||
if (!endpoint || !model) {
|
||||
logger.warn('[loadAddedAgent] Missing required endpoint or model for ephemeral agent');
|
||||
return null;
|
||||
}
|
||||
|
||||
const appConfig = req.config as AppConfig | undefined;
|
||||
|
||||
const primaryIsEphemeral = primaryAgent && isEphemeralAgentId(primaryAgent.id);
|
||||
if (primaryIsEphemeral && Array.isArray(primaryAgent.tools)) {
|
||||
let endpointConfig = (appConfig?.endpoints as Record<string, unknown> | undefined)?.[
|
||||
endpoint
|
||||
] as Record<string, unknown> | undefined;
|
||||
if (!isAgentsEndpoint(endpoint) && !endpointConfig) {
|
||||
try {
|
||||
endpointConfig = getCustomEndpointConfig({ endpoint, appConfig }) as
|
||||
| Record<string, unknown>
|
||||
| undefined;
|
||||
} catch (err) {
|
||||
logger.error('[loadAddedAgent] Error getting custom endpoint config', err);
|
||||
}
|
||||
}
|
||||
|
||||
const modelSpecs = (appConfig?.modelSpecs as { list?: Array<{ name: string; label?: string }> })
|
||||
?.list;
|
||||
const modelSpec = spec != null && spec !== '' ? modelSpecs?.find((s) => s.name === spec) : null;
|
||||
const sender =
|
||||
rest.modelLabel ??
|
||||
modelSpec?.label ??
|
||||
(endpointConfig?.modelDisplayLabel as string | undefined) ??
|
||||
'';
|
||||
const ephemeralId = encodeEphemeralAgentId({ endpoint, model, sender, index: 1 });
|
||||
|
||||
return {
|
||||
id: ephemeralId,
|
||||
instructions: promptPrefix || '',
|
||||
provider: endpoint,
|
||||
model_parameters: {},
|
||||
model,
|
||||
tools: [...primaryAgent.tools],
|
||||
} as unknown as Agent;
|
||||
}
|
||||
|
||||
const ephemeralAgent = rest.ephemeralAgent as
|
||||
| {
|
||||
mcp?: string[];
|
||||
execute_code?: boolean;
|
||||
file_search?: boolean;
|
||||
web_search?: boolean;
|
||||
artifacts?: unknown;
|
||||
}
|
||||
| undefined;
|
||||
const mcpServers = new Set<string>(ephemeralAgent?.mcp);
|
||||
const userId = req.user?.id ?? '';
|
||||
|
||||
const modelSpecs = (
|
||||
appConfig?.modelSpecs as {
|
||||
list?: Array<{
|
||||
name: string;
|
||||
label?: string;
|
||||
mcpServers?: string[];
|
||||
executeCode?: boolean;
|
||||
fileSearch?: boolean;
|
||||
webSearch?: boolean;
|
||||
}>;
|
||||
}
|
||||
)?.list;
|
||||
let modelSpec: (typeof modelSpecs extends Array<infer T> | undefined ? T : never) | null = null;
|
||||
if (spec != null && spec !== '') {
|
||||
modelSpec = modelSpecs?.find((s) => s.name === spec) ?? null;
|
||||
}
|
||||
if (modelSpec?.mcpServers) {
|
||||
for (const mcpServer of modelSpec.mcpServers) {
|
||||
mcpServers.add(mcpServer);
|
||||
}
|
||||
}
|
||||
|
||||
const tools: string[] = [];
|
||||
if (ephemeralAgent?.execute_code === true || modelSpec?.executeCode === true) {
|
||||
tools.push(Tools.execute_code);
|
||||
}
|
||||
if (ephemeralAgent?.file_search === true || modelSpec?.fileSearch === true) {
|
||||
tools.push(Tools.file_search);
|
||||
}
|
||||
if (ephemeralAgent?.web_search === true || modelSpec?.webSearch === true) {
|
||||
tools.push(Tools.web_search);
|
||||
}
|
||||
|
||||
const addedServers = new Set<string>();
|
||||
for (const mcpServer of mcpServers) {
|
||||
if (addedServers.has(mcpServer)) {
|
||||
continue;
|
||||
}
|
||||
const serverTools = await deps.getMCPServerTools(userId, mcpServer);
|
||||
if (!serverTools) {
|
||||
tools.push(`${mcp_all}${mcp_delimiter}${mcpServer}`);
|
||||
addedServers.add(mcpServer);
|
||||
continue;
|
||||
}
|
||||
tools.push(...Object.keys(serverTools));
|
||||
addedServers.add(mcpServer);
|
||||
}
|
||||
|
||||
const model_parameters: Record<string, unknown> = {};
|
||||
const paramKeys = [
|
||||
'temperature',
|
||||
'top_p',
|
||||
'topP',
|
||||
'topK',
|
||||
'presence_penalty',
|
||||
'frequency_penalty',
|
||||
'maxOutputTokens',
|
||||
'maxTokens',
|
||||
'max_tokens',
|
||||
];
|
||||
for (const key of paramKeys) {
|
||||
if ((rest as Record<string, unknown>)[key] != null) {
|
||||
model_parameters[key] = (rest as Record<string, unknown>)[key];
|
||||
}
|
||||
}
|
||||
|
||||
let endpointConfig = (appConfig?.endpoints as Record<string, unknown> | undefined)?.[endpoint] as
|
||||
| Record<string, unknown>
|
||||
| undefined;
|
||||
if (!isAgentsEndpoint(endpoint) && !endpointConfig) {
|
||||
try {
|
||||
endpointConfig = getCustomEndpointConfig({ endpoint, appConfig }) as
|
||||
| Record<string, unknown>
|
||||
| undefined;
|
||||
} catch (err) {
|
||||
logger.error('[loadAddedAgent] Error getting custom endpoint config', err);
|
||||
}
|
||||
}
|
||||
|
||||
const sender =
|
||||
rest.modelLabel ??
|
||||
modelSpec?.label ??
|
||||
(endpointConfig?.modelDisplayLabel as string | undefined) ??
|
||||
'';
|
||||
const ephemeralId = encodeEphemeralAgentId({ endpoint, model, sender, index: 1 });
|
||||
|
||||
const result: Record<string, unknown> = {
|
||||
id: ephemeralId,
|
||||
instructions: promptPrefix || '',
|
||||
provider: endpoint,
|
||||
model_parameters,
|
||||
model,
|
||||
tools,
|
||||
};
|
||||
|
||||
if (ephemeralAgent?.artifacts != null && ephemeralAgent.artifacts) {
|
||||
result.artifacts = ephemeralAgent.artifacts;
|
||||
}
|
||||
|
||||
return result as unknown as Agent;
|
||||
}
|
||||
|
|
@ -15,3 +15,5 @@ export * from './responses';
|
|||
export * from './run';
|
||||
export * from './tools';
|
||||
export * from './validation';
|
||||
export * from './added';
|
||||
export * from './load';
|
||||
|
|
|
|||
162
packages/api/src/agents/load.ts
Normal file
162
packages/api/src/agents/load.ts
Normal file
|
|
@ -0,0 +1,162 @@
|
|||
import { logger } from '@librechat/data-schemas';
|
||||
import type { AppConfig } from '@librechat/data-schemas';
|
||||
import {
|
||||
Tools,
|
||||
Constants,
|
||||
isAgentsEndpoint,
|
||||
isEphemeralAgentId,
|
||||
encodeEphemeralAgentId,
|
||||
} from 'librechat-data-provider';
|
||||
import type {
|
||||
AgentModelParameters,
|
||||
TEphemeralAgent,
|
||||
TModelSpec,
|
||||
Agent,
|
||||
} from 'librechat-data-provider';
|
||||
import { getCustomEndpointConfig } from '~/app/config';
|
||||
|
||||
const { mcp_all, mcp_delimiter } = Constants;
|
||||
|
||||
export interface LoadAgentDeps {
|
||||
getAgent: (searchParameter: { id: string }) => Promise<Agent | null>;
|
||||
getMCPServerTools: (
|
||||
userId: string,
|
||||
serverName: string,
|
||||
) => Promise<Record<string, unknown> | null>;
|
||||
}
|
||||
|
||||
export interface LoadAgentParams {
|
||||
req: {
|
||||
user?: { id?: string };
|
||||
config?: AppConfig;
|
||||
body?: {
|
||||
promptPrefix?: string;
|
||||
ephemeralAgent?: TEphemeralAgent;
|
||||
};
|
||||
};
|
||||
spec?: string;
|
||||
agent_id: string;
|
||||
endpoint: string;
|
||||
model_parameters?: AgentModelParameters & { model?: string };
|
||||
}
|
||||
|
||||
/**
|
||||
* Load an ephemeral agent based on the request parameters.
|
||||
*/
|
||||
export async function loadEphemeralAgent(
|
||||
{ req, spec, endpoint, model_parameters: _m }: Omit<LoadAgentParams, 'agent_id'>,
|
||||
deps: LoadAgentDeps,
|
||||
): Promise<Agent | null> {
|
||||
const { model, ...model_parameters } = _m ?? ({} as unknown as AgentModelParameters);
|
||||
const modelSpecs = req.config?.modelSpecs as { list?: TModelSpec[] } | undefined;
|
||||
let modelSpec: TModelSpec | null = null;
|
||||
if (spec != null && spec !== '') {
|
||||
modelSpec = modelSpecs?.list?.find((s) => s.name === spec) ?? null;
|
||||
}
|
||||
const ephemeralAgent: TEphemeralAgent | undefined = req.body?.ephemeralAgent;
|
||||
const mcpServers = new Set<string>(ephemeralAgent?.mcp);
|
||||
const userId = req.user?.id ?? '';
|
||||
if (modelSpec?.mcpServers) {
|
||||
for (const mcpServer of modelSpec.mcpServers) {
|
||||
mcpServers.add(mcpServer);
|
||||
}
|
||||
}
|
||||
const tools: string[] = [];
|
||||
if (ephemeralAgent?.execute_code === true || modelSpec?.executeCode === true) {
|
||||
tools.push(Tools.execute_code);
|
||||
}
|
||||
if (ephemeralAgent?.file_search === true || modelSpec?.fileSearch === true) {
|
||||
tools.push(Tools.file_search);
|
||||
}
|
||||
if (ephemeralAgent?.web_search === true || modelSpec?.webSearch === true) {
|
||||
tools.push(Tools.web_search);
|
||||
}
|
||||
|
||||
const addedServers = new Set<string>();
|
||||
if (mcpServers.size > 0) {
|
||||
for (const mcpServer of mcpServers) {
|
||||
if (addedServers.has(mcpServer)) {
|
||||
continue;
|
||||
}
|
||||
const serverTools = await deps.getMCPServerTools(userId, mcpServer);
|
||||
if (!serverTools) {
|
||||
tools.push(`${mcp_all}${mcp_delimiter}${mcpServer}`);
|
||||
addedServers.add(mcpServer);
|
||||
continue;
|
||||
}
|
||||
tools.push(...Object.keys(serverTools));
|
||||
addedServers.add(mcpServer);
|
||||
}
|
||||
}
|
||||
|
||||
const instructions = req.body?.promptPrefix;
|
||||
|
||||
// Get endpoint config for modelDisplayLabel fallback
|
||||
const appConfig = req.config;
|
||||
const endpoints = appConfig?.endpoints;
|
||||
let endpointConfig = endpoints?.[endpoint as keyof typeof endpoints];
|
||||
if (!isAgentsEndpoint(endpoint) && !endpointConfig) {
|
||||
try {
|
||||
endpointConfig = getCustomEndpointConfig({ endpoint, appConfig });
|
||||
} catch (err) {
|
||||
logger.error('[loadEphemeralAgent] Error getting custom endpoint config', err);
|
||||
}
|
||||
}
|
||||
|
||||
// For ephemeral agents, use modelLabel if provided, then model spec's label,
|
||||
// then modelDisplayLabel from endpoint config, otherwise empty string to show model name
|
||||
const sender =
|
||||
(model_parameters as AgentModelParameters & { modelLabel?: string })?.modelLabel ??
|
||||
modelSpec?.label ??
|
||||
(endpointConfig as { modelDisplayLabel?: string } | undefined)?.modelDisplayLabel ??
|
||||
'';
|
||||
|
||||
// Encode ephemeral agent ID with endpoint, model, and computed sender for display
|
||||
const ephemeralId = encodeEphemeralAgentId({
|
||||
endpoint,
|
||||
model: model as string,
|
||||
sender: sender as string,
|
||||
});
|
||||
|
||||
const result: Partial<Agent> = {
|
||||
id: ephemeralId,
|
||||
instructions,
|
||||
provider: endpoint,
|
||||
model_parameters,
|
||||
model,
|
||||
tools,
|
||||
};
|
||||
|
||||
if (ephemeralAgent?.artifacts) {
|
||||
result.artifacts = ephemeralAgent.artifacts;
|
||||
}
|
||||
return result as Agent;
|
||||
}
|
||||
|
||||
/**
|
||||
* Load an agent based on the provided ID.
|
||||
* For ephemeral agents, builds a synthetic agent from request parameters.
|
||||
* For persistent agents, fetches from the database.
|
||||
*/
|
||||
export async function loadAgent(
|
||||
params: LoadAgentParams,
|
||||
deps: LoadAgentDeps,
|
||||
): Promise<Agent | null> {
|
||||
const { req, spec, agent_id, endpoint, model_parameters } = params;
|
||||
if (!agent_id) {
|
||||
return null;
|
||||
}
|
||||
if (isEphemeralAgentId(agent_id)) {
|
||||
return loadEphemeralAgent({ req, spec, endpoint, model_parameters }, deps);
|
||||
}
|
||||
const agent = await deps.getAgent({ id: agent_id });
|
||||
|
||||
if (!agent) {
|
||||
return null;
|
||||
}
|
||||
|
||||
// Set version count from versions array length
|
||||
const agentWithVersion = agent as Agent & { versions?: unknown[]; version?: number };
|
||||
agentWithVersion.version = agentWithVersion.versions ? agentWithVersion.versions.length : 0;
|
||||
return agent;
|
||||
}
|
||||
|
|
@ -1,10 +1,11 @@
|
|||
import { Types } from 'mongoose';
|
||||
import {
|
||||
ResourceType,
|
||||
PrincipalType,
|
||||
PermissionBits,
|
||||
AccessRoleIds,
|
||||
} from 'librechat-data-provider';
|
||||
import type { Types, Model } from 'mongoose';
|
||||
import type { PipelineStage, AnyBulkWriteOperation } from 'mongoose';
|
||||
|
||||
export interface Principal {
|
||||
type: string;
|
||||
|
|
@ -19,20 +20,14 @@ export interface Principal {
|
|||
}
|
||||
|
||||
export interface EnricherDependencies {
|
||||
AclEntry: Model<{
|
||||
principalType: string;
|
||||
principalId: Types.ObjectId;
|
||||
resourceType: string;
|
||||
resourceId: Types.ObjectId;
|
||||
permBits: number;
|
||||
roleId: Types.ObjectId;
|
||||
grantedBy: Types.ObjectId;
|
||||
grantedAt: Date;
|
||||
}>;
|
||||
AccessRole: Model<{
|
||||
accessRoleId: string;
|
||||
permBits: number;
|
||||
}>;
|
||||
aggregateAclEntries: (pipeline: PipelineStage[]) => Promise<Record<string, unknown>[]>;
|
||||
bulkWriteAclEntries: (
|
||||
ops: AnyBulkWriteOperation<unknown>[],
|
||||
options?: Record<string, unknown>,
|
||||
) => Promise<unknown>;
|
||||
findRoleByIdentifier: (
|
||||
accessRoleId: string,
|
||||
) => Promise<{ _id: Types.ObjectId; permBits: number } | null>;
|
||||
logger: { error: (msg: string, ...args: unknown[]) => void };
|
||||
}
|
||||
|
||||
|
|
@ -47,14 +42,12 @@ export async function enrichRemoteAgentPrincipals(
|
|||
resourceId: string | Types.ObjectId,
|
||||
principals: Principal[],
|
||||
): Promise<EnrichResult> {
|
||||
const { AclEntry } = deps;
|
||||
|
||||
const resourceObjectId =
|
||||
typeof resourceId === 'string' && /^[a-f\d]{24}$/i.test(resourceId)
|
||||
? deps.AclEntry.base.Types.ObjectId.createFromHexString(resourceId)
|
||||
? Types.ObjectId.createFromHexString(resourceId)
|
||||
: resourceId;
|
||||
|
||||
const agentOwnerEntries = await AclEntry.aggregate([
|
||||
const agentOwnerEntries = await deps.aggregateAclEntries([
|
||||
{
|
||||
$match: {
|
||||
resourceType: ResourceType.AGENT,
|
||||
|
|
@ -87,24 +80,28 @@ export async function enrichRemoteAgentPrincipals(
|
|||
continue;
|
||||
}
|
||||
|
||||
const userInfo = entry.userInfo as Record<string, unknown>;
|
||||
const principalId = entry.principalId as Types.ObjectId;
|
||||
|
||||
const alreadyIncluded = enrichedPrincipals.some(
|
||||
(p) => p.type === PrincipalType.USER && p.id === entry.principalId.toString(),
|
||||
(p) => p.type === PrincipalType.USER && p.id === principalId.toString(),
|
||||
);
|
||||
|
||||
if (!alreadyIncluded) {
|
||||
enrichedPrincipals.unshift({
|
||||
type: PrincipalType.USER,
|
||||
id: entry.userInfo._id.toString(),
|
||||
name: entry.userInfo.name || entry.userInfo.username,
|
||||
email: entry.userInfo.email,
|
||||
avatar: entry.userInfo.avatar,
|
||||
id: (userInfo._id as Types.ObjectId).toString(),
|
||||
name: (userInfo.name || userInfo.username) as string,
|
||||
email: userInfo.email as string,
|
||||
avatar: userInfo.avatar as string,
|
||||
source: 'local',
|
||||
idOnTheSource: entry.userInfo.idOnTheSource || entry.userInfo._id.toString(),
|
||||
idOnTheSource:
|
||||
(userInfo.idOnTheSource as string) || (userInfo._id as Types.ObjectId).toString(),
|
||||
accessRoleId: AccessRoleIds.REMOTE_AGENT_OWNER,
|
||||
isImplicit: true,
|
||||
});
|
||||
|
||||
entriesToBackfill.push(entry.principalId);
|
||||
entriesToBackfill.push(principalId);
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -121,15 +118,15 @@ export function backfillRemoteAgentPermissions(
|
|||
return;
|
||||
}
|
||||
|
||||
const { AclEntry, AccessRole, logger } = deps;
|
||||
const { logger } = deps;
|
||||
|
||||
const resourceObjectId =
|
||||
typeof resourceId === 'string' && /^[a-f\d]{24}$/i.test(resourceId)
|
||||
? AclEntry.base.Types.ObjectId.createFromHexString(resourceId)
|
||||
? Types.ObjectId.createFromHexString(resourceId)
|
||||
: resourceId;
|
||||
|
||||
AccessRole.findOne({ accessRoleId: AccessRoleIds.REMOTE_AGENT_OWNER })
|
||||
.lean()
|
||||
deps
|
||||
.findRoleByIdentifier(AccessRoleIds.REMOTE_AGENT_OWNER)
|
||||
.then((role) => {
|
||||
if (!role) {
|
||||
logger.error('[backfillRemoteAgentPermissions] REMOTE_AGENT_OWNER role not found');
|
||||
|
|
@ -161,9 +158,9 @@ export function backfillRemoteAgentPermissions(
|
|||
},
|
||||
}));
|
||||
|
||||
return AclEntry.bulkWrite(bulkOps, { ordered: false });
|
||||
return deps.bulkWriteAclEntries(bulkOps, { ordered: false });
|
||||
})
|
||||
.catch((err) => {
|
||||
.catch((err: unknown) => {
|
||||
logger.error('[backfillRemoteAgentPermissions] Failed to backfill:', err);
|
||||
});
|
||||
}
|
||||
|
|
|
|||
|
|
@ -2,3 +2,5 @@ export * from './domain';
|
|||
export * from './openid';
|
||||
export * from './exchange';
|
||||
export * from './agent';
|
||||
export * from './password';
|
||||
export * from './invite';
|
||||
|
|
|
|||
61
packages/api/src/auth/invite.ts
Normal file
61
packages/api/src/auth/invite.ts
Normal file
|
|
@ -0,0 +1,61 @@
|
|||
import { Types } from 'mongoose';
|
||||
import { logger, hashToken, getRandomValues } from '@librechat/data-schemas';
|
||||
|
||||
export interface InviteDeps {
|
||||
createToken: (data: {
|
||||
userId: Types.ObjectId;
|
||||
email: string;
|
||||
token: string;
|
||||
createdAt: number;
|
||||
expiresIn: number;
|
||||
}) => Promise<unknown>;
|
||||
findToken: (filter: { token: string; email: string }) => Promise<unknown>;
|
||||
}
|
||||
|
||||
/** Creates a new user invite and returns the encoded token. */
|
||||
export async function createInvite(
|
||||
email: string,
|
||||
deps: InviteDeps,
|
||||
): Promise<string | { message: string }> {
|
||||
try {
|
||||
const token = await getRandomValues(32);
|
||||
const hash = await hashToken(token);
|
||||
const encodedToken = encodeURIComponent(token);
|
||||
const fakeUserId = new Types.ObjectId();
|
||||
|
||||
await deps.createToken({
|
||||
userId: fakeUserId,
|
||||
email,
|
||||
token: hash,
|
||||
createdAt: Date.now(),
|
||||
expiresIn: 604800,
|
||||
});
|
||||
|
||||
return encodedToken;
|
||||
} catch (error) {
|
||||
logger.error('[createInvite] Error creating invite', error);
|
||||
return { message: 'Error creating invite' };
|
||||
}
|
||||
}
|
||||
|
||||
/** Retrieves and validates a user invite by encoded token and email. */
|
||||
export async function getInvite(
|
||||
encodedToken: string,
|
||||
email: string,
|
||||
deps: InviteDeps,
|
||||
): Promise<unknown> {
|
||||
try {
|
||||
const token = decodeURIComponent(encodedToken);
|
||||
const hash = await hashToken(token);
|
||||
const invite = await deps.findToken({ token: hash, email });
|
||||
|
||||
if (!invite) {
|
||||
throw new Error('Invite not found or email does not match');
|
||||
}
|
||||
|
||||
return invite;
|
||||
} catch (error) {
|
||||
logger.error('[getInvite] Error getting invite:', error);
|
||||
return { error: true, message: (error as Error).message };
|
||||
}
|
||||
}
|
||||
25
packages/api/src/auth/password.ts
Normal file
25
packages/api/src/auth/password.ts
Normal file
|
|
@ -0,0 +1,25 @@
|
|||
interface UserWithPassword {
|
||||
password?: string;
|
||||
[key: string]: unknown;
|
||||
}
|
||||
|
||||
export interface ComparePasswordDeps {
|
||||
compare: (candidatePassword: string, hash: string) => Promise<boolean>;
|
||||
}
|
||||
|
||||
/** Compares a candidate password against a user's hashed password. */
|
||||
export async function comparePassword(
|
||||
user: UserWithPassword,
|
||||
candidatePassword: string,
|
||||
deps: ComparePasswordDeps,
|
||||
): Promise<boolean> {
|
||||
if (!user) {
|
||||
throw new Error('No user provided');
|
||||
}
|
||||
|
||||
if (!user.password) {
|
||||
throw new Error('No password, likely an email first registered via Social/OIDC login');
|
||||
}
|
||||
|
||||
return deps.compare(candidatePassword, user.password);
|
||||
}
|
||||
|
|
@ -364,12 +364,12 @@ export class ServerConfigsDB implements IServerConfigsRepositoryInterface {
|
|||
|
||||
const parsedConfigs: Record<string, ParsedServerConfig> = {};
|
||||
const directData = directResults.data || [];
|
||||
const directServerNames = new Set(directData.map((s) => s.serverName));
|
||||
const directServerNames = new Set(directData.map((s: MCPServerDocument) => s.serverName));
|
||||
|
||||
const directParsed = await Promise.all(
|
||||
directData.map((s) => this.mapDBServerToParsedConfig(s)),
|
||||
directData.map((s: MCPServerDocument) => this.mapDBServerToParsedConfig(s)),
|
||||
);
|
||||
directData.forEach((s, i) => {
|
||||
directData.forEach((s: MCPServerDocument, i: number) => {
|
||||
parsedConfigs[s.serverName] = directParsed[i];
|
||||
});
|
||||
|
||||
|
|
@ -382,9 +382,9 @@ export class ServerConfigsDB implements IServerConfigsRepositoryInterface {
|
|||
|
||||
const agentData = agentServers.data || [];
|
||||
const agentParsed = await Promise.all(
|
||||
agentData.map((s) => this.mapDBServerToParsedConfig(s)),
|
||||
agentData.map((s: MCPServerDocument) => this.mapDBServerToParsedConfig(s)),
|
||||
);
|
||||
agentData.forEach((s, i) => {
|
||||
agentData.forEach((s: MCPServerDocument, i: number) => {
|
||||
parsedConfigs[s.serverName] = { ...agentParsed[i], consumeOnly: true };
|
||||
});
|
||||
}
|
||||
|
|
@ -457,7 +457,7 @@ export class ServerConfigsDB implements IServerConfigsRepositoryInterface {
|
|||
};
|
||||
|
||||
// Remove key field since it's user-provided (destructure to omit, not set to undefined)
|
||||
// eslint-disable-next-line @typescript-eslint/no-unused-vars
|
||||
|
||||
const { key: _removed, ...apiKeyWithoutKey } = result.apiKey!;
|
||||
result.apiKey = apiKeyWithoutKey;
|
||||
|
||||
|
|
@ -521,7 +521,7 @@ export class ServerConfigsDB implements IServerConfigsRepositoryInterface {
|
|||
'[ServerConfigsDB.decryptConfig] Failed to decrypt apiKey.key, returning config without key',
|
||||
error,
|
||||
);
|
||||
// eslint-disable-next-line @typescript-eslint/no-unused-vars
|
||||
|
||||
const { key: _removedKey, ...apiKeyWithoutKey } = result.apiKey;
|
||||
result.apiKey = apiKeyWithoutKey;
|
||||
}
|
||||
|
|
@ -542,7 +542,7 @@ export class ServerConfigsDB implements IServerConfigsRepositoryInterface {
|
|||
'[ServerConfigsDB.decryptConfig] Failed to decrypt client_secret, returning config without secret',
|
||||
error,
|
||||
);
|
||||
// eslint-disable-next-line @typescript-eslint/no-unused-vars
|
||||
|
||||
const { client_secret: _removed, ...oauthWithoutSecret } = oauthConfig;
|
||||
result = {
|
||||
...result,
|
||||
|
|
|
|||
|
|
@ -216,12 +216,12 @@ describe('access middleware', () => {
|
|||
|
||||
defaultParams.getRoleByName.mockResolvedValue(mockRole);
|
||||
|
||||
const checkObject = {};
|
||||
const checkObject = { id: 'agent123' };
|
||||
|
||||
const result = await checkAccess({
|
||||
...defaultParams,
|
||||
permissions: [Permissions.USE, Permissions.SHARE],
|
||||
bodyProps: {} as Record<Permissions, string[]>,
|
||||
bodyProps: { [Permissions.SHARE]: ['id'] } as Record<Permissions, string[]>,
|
||||
checkObject,
|
||||
});
|
||||
expect(result).toBe(true);
|
||||
|
|
@ -333,12 +333,12 @@ describe('access middleware', () => {
|
|||
} as unknown as IRole;
|
||||
|
||||
mockGetRoleByName.mockResolvedValue(mockRole);
|
||||
mockReq.body = {};
|
||||
mockReq.body = { id: 'agent123' };
|
||||
|
||||
const middleware = generateCheckAccess({
|
||||
permissionType: PermissionTypes.AGENTS,
|
||||
permissions: [Permissions.USE, Permissions.CREATE, Permissions.SHARE],
|
||||
bodyProps: {} as Record<Permissions, string[]>,
|
||||
bodyProps: { [Permissions.SHARE]: ['id'] } as Record<Permissions, string[]>,
|
||||
getRoleByName: mockGetRoleByName,
|
||||
});
|
||||
|
||||
|
|
|
|||
|
|
@ -2,7 +2,7 @@ import mongoose from 'mongoose';
|
|||
import { MongoMemoryServer } from 'mongodb-memory-server';
|
||||
import { logger, balanceSchema } from '@librechat/data-schemas';
|
||||
import type { NextFunction, Request as ServerRequest, Response as ServerResponse } from 'express';
|
||||
import type { IBalance } from '@librechat/data-schemas';
|
||||
import type { IBalance, IBalanceUpdate } from '@librechat/data-schemas';
|
||||
import { createSetBalanceConfig } from './balance';
|
||||
|
||||
jest.mock('@librechat/data-schemas', () => ({
|
||||
|
|
@ -15,6 +15,16 @@ jest.mock('@librechat/data-schemas', () => ({
|
|||
let mongoServer: MongoMemoryServer;
|
||||
let Balance: mongoose.Model<IBalance>;
|
||||
|
||||
const findBalanceByUser = (userId: string) =>
|
||||
Balance.findOne({ user: userId }).lean() as Promise<IBalance | null>;
|
||||
|
||||
const upsertBalanceFields = (userId: string, fields: IBalanceUpdate) =>
|
||||
Balance.findOneAndUpdate(
|
||||
{ user: userId },
|
||||
{ $set: fields },
|
||||
{ upsert: true, new: true },
|
||||
).lean() as Promise<IBalance | null>;
|
||||
|
||||
beforeAll(async () => {
|
||||
mongoServer = await MongoMemoryServer.create();
|
||||
const mongoUri = mongoServer.getUri();
|
||||
|
|
@ -64,7 +74,8 @@ describe('createSetBalanceConfig', () => {
|
|||
|
||||
const middleware = createSetBalanceConfig({
|
||||
getAppConfig,
|
||||
Balance,
|
||||
findBalanceByUser,
|
||||
upsertBalanceFields,
|
||||
});
|
||||
|
||||
const req = createMockRequest(userId);
|
||||
|
|
@ -95,7 +106,8 @@ describe('createSetBalanceConfig', () => {
|
|||
|
||||
const middleware = createSetBalanceConfig({
|
||||
getAppConfig,
|
||||
Balance,
|
||||
findBalanceByUser,
|
||||
upsertBalanceFields,
|
||||
});
|
||||
|
||||
const req = createMockRequest(userId);
|
||||
|
|
@ -120,7 +132,8 @@ describe('createSetBalanceConfig', () => {
|
|||
|
||||
const middleware = createSetBalanceConfig({
|
||||
getAppConfig,
|
||||
Balance,
|
||||
findBalanceByUser,
|
||||
upsertBalanceFields,
|
||||
});
|
||||
|
||||
const req = createMockRequest(userId);
|
||||
|
|
@ -149,7 +162,8 @@ describe('createSetBalanceConfig', () => {
|
|||
|
||||
const middleware = createSetBalanceConfig({
|
||||
getAppConfig,
|
||||
Balance,
|
||||
findBalanceByUser,
|
||||
upsertBalanceFields,
|
||||
});
|
||||
|
||||
const req = createMockRequest(userId);
|
||||
|
|
@ -178,7 +192,8 @@ describe('createSetBalanceConfig', () => {
|
|||
|
||||
const middleware = createSetBalanceConfig({
|
||||
getAppConfig,
|
||||
Balance,
|
||||
findBalanceByUser,
|
||||
upsertBalanceFields,
|
||||
});
|
||||
|
||||
const req = {} as ServerRequest;
|
||||
|
|
@ -219,7 +234,8 @@ describe('createSetBalanceConfig', () => {
|
|||
|
||||
const middleware = createSetBalanceConfig({
|
||||
getAppConfig,
|
||||
Balance,
|
||||
findBalanceByUser,
|
||||
upsertBalanceFields,
|
||||
});
|
||||
|
||||
const req = createMockRequest(userId);
|
||||
|
|
@ -271,7 +287,8 @@ describe('createSetBalanceConfig', () => {
|
|||
|
||||
const middleware = createSetBalanceConfig({
|
||||
getAppConfig,
|
||||
Balance,
|
||||
findBalanceByUser,
|
||||
upsertBalanceFields,
|
||||
});
|
||||
|
||||
const req = createMockRequest(userId);
|
||||
|
|
@ -315,7 +332,8 @@ describe('createSetBalanceConfig', () => {
|
|||
|
||||
const middleware = createSetBalanceConfig({
|
||||
getAppConfig,
|
||||
Balance,
|
||||
findBalanceByUser,
|
||||
upsertBalanceFields,
|
||||
});
|
||||
|
||||
const req = createMockRequest(userId);
|
||||
|
|
@ -346,7 +364,8 @@ describe('createSetBalanceConfig', () => {
|
|||
|
||||
const middleware = createSetBalanceConfig({
|
||||
getAppConfig,
|
||||
Balance,
|
||||
findBalanceByUser,
|
||||
upsertBalanceFields,
|
||||
});
|
||||
|
||||
const req = createMockRequest(userId);
|
||||
|
|
@ -392,7 +411,8 @@ describe('createSetBalanceConfig', () => {
|
|||
|
||||
const middleware = createSetBalanceConfig({
|
||||
getAppConfig,
|
||||
Balance,
|
||||
findBalanceByUser,
|
||||
upsertBalanceFields,
|
||||
});
|
||||
|
||||
const req = createMockRequest(userId);
|
||||
|
|
@ -434,21 +454,20 @@ describe('createSetBalanceConfig', () => {
|
|||
},
|
||||
});
|
||||
|
||||
const middleware = createSetBalanceConfig({
|
||||
getAppConfig,
|
||||
Balance,
|
||||
});
|
||||
|
||||
const req = createMockRequest(userId);
|
||||
const res = createMockResponse();
|
||||
|
||||
// Spy on Balance.findOneAndUpdate to verify it's not called
|
||||
const updateSpy = jest.spyOn(Balance, 'findOneAndUpdate');
|
||||
const upsertSpy = jest.fn();
|
||||
const spiedMiddleware = createSetBalanceConfig({
|
||||
getAppConfig,
|
||||
findBalanceByUser,
|
||||
upsertBalanceFields: upsertSpy,
|
||||
});
|
||||
|
||||
await middleware(req as ServerRequest, res as ServerResponse, mockNext);
|
||||
await spiedMiddleware(req as ServerRequest, res as ServerResponse, mockNext);
|
||||
|
||||
expect(mockNext).toHaveBeenCalled();
|
||||
expect(updateSpy).not.toHaveBeenCalled();
|
||||
expect(upsertSpy).not.toHaveBeenCalled();
|
||||
});
|
||||
|
||||
test('should set tokenCredits for user with null tokenCredits', async () => {
|
||||
|
|
@ -470,7 +489,8 @@ describe('createSetBalanceConfig', () => {
|
|||
|
||||
const middleware = createSetBalanceConfig({
|
||||
getAppConfig,
|
||||
Balance,
|
||||
findBalanceByUser,
|
||||
upsertBalanceFields,
|
||||
});
|
||||
|
||||
const req = createMockRequest(userId);
|
||||
|
|
@ -498,16 +518,12 @@ describe('createSetBalanceConfig', () => {
|
|||
});
|
||||
const dbError = new Error('Database error');
|
||||
|
||||
// Mock Balance.findOne to throw an error
|
||||
jest.spyOn(Balance, 'findOne').mockImplementationOnce((() => {
|
||||
return {
|
||||
lean: jest.fn().mockRejectedValue(dbError),
|
||||
};
|
||||
}) as unknown as mongoose.Model<IBalance>['findOne']);
|
||||
const failingFindBalance = () => Promise.reject(dbError);
|
||||
|
||||
const middleware = createSetBalanceConfig({
|
||||
getAppConfig,
|
||||
Balance,
|
||||
findBalanceByUser: failingFindBalance,
|
||||
upsertBalanceFields,
|
||||
});
|
||||
|
||||
const req = createMockRequest(userId);
|
||||
|
|
@ -526,7 +542,8 @@ describe('createSetBalanceConfig', () => {
|
|||
|
||||
const middleware = createSetBalanceConfig({
|
||||
getAppConfig,
|
||||
Balance,
|
||||
findBalanceByUser,
|
||||
upsertBalanceFields,
|
||||
});
|
||||
|
||||
const req = createMockRequest(userId);
|
||||
|
|
@ -556,7 +573,8 @@ describe('createSetBalanceConfig', () => {
|
|||
|
||||
const middleware = createSetBalanceConfig({
|
||||
getAppConfig,
|
||||
Balance,
|
||||
findBalanceByUser,
|
||||
upsertBalanceFields,
|
||||
});
|
||||
|
||||
const req = createMockRequest(userId);
|
||||
|
|
@ -590,7 +608,8 @@ describe('createSetBalanceConfig', () => {
|
|||
|
||||
const middleware = createSetBalanceConfig({
|
||||
getAppConfig,
|
||||
Balance,
|
||||
findBalanceByUser,
|
||||
upsertBalanceFields,
|
||||
});
|
||||
|
||||
const req = createMockRequest(userId);
|
||||
|
|
@ -635,7 +654,8 @@ describe('createSetBalanceConfig', () => {
|
|||
|
||||
const middleware = createSetBalanceConfig({
|
||||
getAppConfig,
|
||||
Balance,
|
||||
findBalanceByUser,
|
||||
upsertBalanceFields,
|
||||
});
|
||||
|
||||
const req = createMockRequest(userId);
|
||||
|
|
|
|||
|
|
@ -1,13 +1,20 @@
|
|||
import { logger } from '@librechat/data-schemas';
|
||||
import type {
|
||||
IBalanceUpdate,
|
||||
BalanceConfig,
|
||||
AppConfig,
|
||||
ObjectId,
|
||||
IBalance,
|
||||
IUser,
|
||||
} from '@librechat/data-schemas';
|
||||
import type { NextFunction, Request as ServerRequest, Response as ServerResponse } from 'express';
|
||||
import type { IBalance, IUser, BalanceConfig, ObjectId, AppConfig } from '@librechat/data-schemas';
|
||||
import type { Model } from 'mongoose';
|
||||
import type { BalanceUpdateFields } from '~/types';
|
||||
import { getBalanceConfig } from '~/app/config';
|
||||
|
||||
export interface BalanceMiddlewareOptions {
|
||||
getAppConfig: (options?: { role?: string; refresh?: boolean }) => Promise<AppConfig>;
|
||||
Balance: Model<IBalance>;
|
||||
findBalanceByUser: (userId: string) => Promise<IBalance | null>;
|
||||
upsertBalanceFields: (userId: string, fields: IBalanceUpdate) => Promise<IBalance | null>;
|
||||
}
|
||||
|
||||
/**
|
||||
|
|
@ -75,7 +82,8 @@ function buildUpdateFields(
|
|||
*/
|
||||
export function createSetBalanceConfig({
|
||||
getAppConfig,
|
||||
Balance,
|
||||
findBalanceByUser,
|
||||
upsertBalanceFields,
|
||||
}: BalanceMiddlewareOptions): (
|
||||
req: ServerRequest,
|
||||
res: ServerResponse,
|
||||
|
|
@ -97,18 +105,14 @@ export function createSetBalanceConfig({
|
|||
return next();
|
||||
}
|
||||
const userId = typeof user._id === 'string' ? user._id : user._id.toString();
|
||||
const userBalanceRecord = await Balance.findOne({ user: userId }).lean();
|
||||
const userBalanceRecord = await findBalanceByUser(userId);
|
||||
const updateFields = buildUpdateFields(balanceConfig, userBalanceRecord, userId);
|
||||
|
||||
if (Object.keys(updateFields).length === 0) {
|
||||
return next();
|
||||
}
|
||||
|
||||
await Balance.findOneAndUpdate(
|
||||
{ user: userId },
|
||||
{ $set: updateFields },
|
||||
{ upsert: true, new: true },
|
||||
);
|
||||
await upsertBalanceFields(userId, updateFields);
|
||||
|
||||
next();
|
||||
} catch (error) {
|
||||
|
|
|
|||
168
packages/api/src/middleware/checkBalance.ts
Normal file
168
packages/api/src/middleware/checkBalance.ts
Normal file
|
|
@ -0,0 +1,168 @@
|
|||
import { logger } from '@librechat/data-schemas';
|
||||
import { ViolationTypes } from 'librechat-data-provider';
|
||||
import type { ServerRequest } from '~/types/http';
|
||||
import type { Response } from 'express';
|
||||
|
||||
type TimeUnit = 'seconds' | 'minutes' | 'hours' | 'days' | 'weeks' | 'months';
|
||||
|
||||
interface BalanceRecord {
|
||||
tokenCredits: number;
|
||||
autoRefillEnabled?: boolean;
|
||||
refillAmount?: number;
|
||||
lastRefill?: Date;
|
||||
refillIntervalValue?: number;
|
||||
refillIntervalUnit?: TimeUnit;
|
||||
}
|
||||
|
||||
interface TxData {
|
||||
user: string;
|
||||
model?: string;
|
||||
endpoint?: string;
|
||||
valueKey?: string;
|
||||
tokenType?: string;
|
||||
amount: number;
|
||||
endpointTokenConfig?: unknown;
|
||||
generations?: unknown[];
|
||||
}
|
||||
|
||||
export interface CheckBalanceDeps {
|
||||
findBalanceByUser: (user: string) => Promise<BalanceRecord | null>;
|
||||
getMultiplier: (params: Record<string, unknown>) => number;
|
||||
createAutoRefillTransaction: (
|
||||
data: Record<string, unknown>,
|
||||
) => Promise<{ balance: number } | undefined>;
|
||||
logViolation: (
|
||||
req: unknown,
|
||||
res: unknown,
|
||||
type: string,
|
||||
errorMessage: Record<string, unknown>,
|
||||
score: number,
|
||||
) => Promise<void>;
|
||||
}
|
||||
|
||||
function addIntervalToDate(date: Date, value: number, unit: TimeUnit): Date {
|
||||
const result = new Date(date);
|
||||
switch (unit) {
|
||||
case 'seconds':
|
||||
result.setSeconds(result.getSeconds() + value);
|
||||
break;
|
||||
case 'minutes':
|
||||
result.setMinutes(result.getMinutes() + value);
|
||||
break;
|
||||
case 'hours':
|
||||
result.setHours(result.getHours() + value);
|
||||
break;
|
||||
case 'days':
|
||||
result.setDate(result.getDate() + value);
|
||||
break;
|
||||
case 'weeks':
|
||||
result.setDate(result.getDate() + value * 7);
|
||||
break;
|
||||
case 'months':
|
||||
result.setMonth(result.getMonth() + value);
|
||||
break;
|
||||
default:
|
||||
break;
|
||||
}
|
||||
return result;
|
||||
}
|
||||
|
||||
/** Checks a user's balance record and handles auto-refill if needed. */
|
||||
async function checkBalanceRecord(
|
||||
txData: TxData,
|
||||
deps: CheckBalanceDeps,
|
||||
): Promise<{ canSpend: boolean; balance: number; tokenCost: number }> {
|
||||
const { user, model, endpoint, valueKey, tokenType, amount, endpointTokenConfig } = txData;
|
||||
const multiplier = deps.getMultiplier({
|
||||
valueKey,
|
||||
tokenType,
|
||||
model,
|
||||
endpoint,
|
||||
endpointTokenConfig,
|
||||
});
|
||||
const tokenCost = amount * multiplier;
|
||||
|
||||
const record = await deps.findBalanceByUser(user);
|
||||
if (!record) {
|
||||
logger.debug('[Balance.check] No balance record found for user', { user });
|
||||
return { canSpend: false, balance: 0, tokenCost };
|
||||
}
|
||||
let balance = record.tokenCredits;
|
||||
|
||||
logger.debug('[Balance.check] Initial state', {
|
||||
user,
|
||||
model,
|
||||
endpoint,
|
||||
valueKey,
|
||||
tokenType,
|
||||
amount,
|
||||
balance,
|
||||
multiplier,
|
||||
endpointTokenConfig: !!endpointTokenConfig,
|
||||
});
|
||||
|
||||
if (
|
||||
balance - tokenCost <= 0 &&
|
||||
record.autoRefillEnabled &&
|
||||
record.refillAmount &&
|
||||
record.refillAmount > 0
|
||||
) {
|
||||
const lastRefillDate = new Date(record.lastRefill ?? 0);
|
||||
const now = new Date();
|
||||
if (
|
||||
isNaN(lastRefillDate.getTime()) ||
|
||||
now >=
|
||||
addIntervalToDate(
|
||||
lastRefillDate,
|
||||
record.refillIntervalValue ?? 0,
|
||||
record.refillIntervalUnit ?? 'days',
|
||||
)
|
||||
) {
|
||||
try {
|
||||
const result = await deps.createAutoRefillTransaction({
|
||||
user,
|
||||
tokenType: 'credits',
|
||||
context: 'autoRefill',
|
||||
rawAmount: record.refillAmount,
|
||||
});
|
||||
if (result) {
|
||||
balance = result.balance;
|
||||
}
|
||||
} catch (error) {
|
||||
logger.error('[Balance.check] Failed to record transaction for auto-refill', error);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
logger.debug('[Balance.check] Token cost', { tokenCost });
|
||||
return { canSpend: balance >= tokenCost, balance, tokenCost };
|
||||
}
|
||||
|
||||
/**
|
||||
* Checks balance for a user and logs a violation if they cannot spend.
|
||||
* Throws an error with the balance info if insufficient funds.
|
||||
*/
|
||||
export async function checkBalance(
|
||||
{ req, res, txData }: { req: ServerRequest; res: Response; txData: TxData },
|
||||
deps: CheckBalanceDeps,
|
||||
): Promise<boolean> {
|
||||
const { canSpend, balance, tokenCost } = await checkBalanceRecord(txData, deps);
|
||||
if (canSpend) {
|
||||
return true;
|
||||
}
|
||||
|
||||
const type = ViolationTypes.TOKEN_BALANCE;
|
||||
const errorMessage: Record<string, unknown> = {
|
||||
type,
|
||||
balance,
|
||||
tokenCost,
|
||||
promptTokens: txData.amount,
|
||||
};
|
||||
|
||||
if (txData.generations && txData.generations.length > 0) {
|
||||
errorMessage.generations = txData.generations;
|
||||
}
|
||||
|
||||
await deps.logViolation(req, res, type, errorMessage, 0);
|
||||
throw new Error(JSON.stringify(errorMessage));
|
||||
}
|
||||
|
|
@ -4,3 +4,4 @@ export * from './error';
|
|||
export * from './balance';
|
||||
export * from './json';
|
||||
export * from './concurrency';
|
||||
export * from './checkBalance';
|
||||
|
|
|
|||
|
|
@ -1,8 +1,8 @@
|
|||
import { escapeRegExp } from '@librechat/data-schemas';
|
||||
import { SystemCategories } from 'librechat-data-provider';
|
||||
import type { IPromptGroupDocument as IPromptGroup } from '@librechat/data-schemas';
|
||||
import type { Types } from 'mongoose';
|
||||
import type { PromptGroupsListResponse } from '~/types';
|
||||
import { escapeRegExp } from '~/utils/common';
|
||||
|
||||
/**
|
||||
* Formats prompt groups for the paginated /groups endpoint response
|
||||
|
|
|
|||
|
|
@ -48,12 +48,3 @@ export function optionalChainWithEmptyCheck(
|
|||
}
|
||||
return values[values.length - 1];
|
||||
}
|
||||
|
||||
/**
|
||||
* Escapes special characters in a string for use in a regular expression.
|
||||
* @param str - The string to escape.
|
||||
* @returns The escaped string safe for use in RegExp.
|
||||
*/
|
||||
export function escapeRegExp(str: string): string {
|
||||
return str.replace(/[.*+?^${}()|[\]\\]/g, '\\$&');
|
||||
}
|
||||
|
|
|
|||
|
|
@ -17,7 +17,6 @@ export * from './oidc';
|
|||
export * from './openid';
|
||||
export * from './promise';
|
||||
export * from './sanitizeTitle';
|
||||
export * from './tempChatRetention';
|
||||
export * from './text';
|
||||
export { default as Tokenizer, countTokens } from './tokenizer';
|
||||
export * from './yaml';
|
||||
|
|
|
|||
|
|
@ -1,137 +0,0 @@
|
|||
import type { AppConfig } from '@librechat/data-schemas';
|
||||
import {
|
||||
createTempChatExpirationDate,
|
||||
getTempChatRetentionHours,
|
||||
DEFAULT_RETENTION_HOURS,
|
||||
MIN_RETENTION_HOURS,
|
||||
MAX_RETENTION_HOURS,
|
||||
} from './tempChatRetention';
|
||||
|
||||
describe('tempChatRetention', () => {
|
||||
const originalEnv = process.env;
|
||||
|
||||
beforeEach(() => {
|
||||
jest.resetModules();
|
||||
process.env = { ...originalEnv };
|
||||
delete process.env.TEMP_CHAT_RETENTION_HOURS;
|
||||
});
|
||||
|
||||
afterAll(() => {
|
||||
process.env = originalEnv;
|
||||
});
|
||||
|
||||
describe('getTempChatRetentionHours', () => {
|
||||
it('should return default retention hours when no config or env var is set', () => {
|
||||
const result = getTempChatRetentionHours();
|
||||
expect(result).toBe(DEFAULT_RETENTION_HOURS);
|
||||
});
|
||||
|
||||
it('should use environment variable when set', () => {
|
||||
process.env.TEMP_CHAT_RETENTION_HOURS = '48';
|
||||
const result = getTempChatRetentionHours();
|
||||
expect(result).toBe(48);
|
||||
});
|
||||
|
||||
it('should use config value when set', () => {
|
||||
const config: Partial<AppConfig> = {
|
||||
interfaceConfig: {
|
||||
temporaryChatRetention: 12,
|
||||
},
|
||||
};
|
||||
const result = getTempChatRetentionHours(config?.interfaceConfig);
|
||||
expect(result).toBe(12);
|
||||
});
|
||||
|
||||
it('should prioritize config over environment variable', () => {
|
||||
process.env.TEMP_CHAT_RETENTION_HOURS = '48';
|
||||
const config: Partial<AppConfig> = {
|
||||
interfaceConfig: {
|
||||
temporaryChatRetention: 12,
|
||||
},
|
||||
};
|
||||
const result = getTempChatRetentionHours(config?.interfaceConfig);
|
||||
expect(result).toBe(12);
|
||||
});
|
||||
|
||||
it('should enforce minimum retention period', () => {
|
||||
const config: Partial<AppConfig> = {
|
||||
interfaceConfig: {
|
||||
temporaryChatRetention: 0,
|
||||
},
|
||||
};
|
||||
const result = getTempChatRetentionHours(config?.interfaceConfig);
|
||||
expect(result).toBe(MIN_RETENTION_HOURS);
|
||||
});
|
||||
|
||||
it('should enforce maximum retention period', () => {
|
||||
const config: Partial<AppConfig> = {
|
||||
interfaceConfig: {
|
||||
temporaryChatRetention: 10000,
|
||||
},
|
||||
};
|
||||
const result = getTempChatRetentionHours(config?.interfaceConfig);
|
||||
expect(result).toBe(MAX_RETENTION_HOURS);
|
||||
});
|
||||
|
||||
it('should handle invalid environment variable', () => {
|
||||
process.env.TEMP_CHAT_RETENTION_HOURS = 'invalid';
|
||||
const result = getTempChatRetentionHours();
|
||||
expect(result).toBe(DEFAULT_RETENTION_HOURS);
|
||||
});
|
||||
|
||||
it('should handle invalid config value', () => {
|
||||
const config: Partial<AppConfig> = {
|
||||
interfaceConfig: {
|
||||
temporaryChatRetention: 'invalid' as unknown as number,
|
||||
},
|
||||
};
|
||||
const result = getTempChatRetentionHours(config?.interfaceConfig);
|
||||
expect(result).toBe(DEFAULT_RETENTION_HOURS);
|
||||
});
|
||||
});
|
||||
|
||||
describe('createTempChatExpirationDate', () => {
|
||||
it('should create expiration date with default retention period', () => {
|
||||
const beforeCall = Date.now();
|
||||
const result = createTempChatExpirationDate();
|
||||
const afterCall = Date.now();
|
||||
|
||||
const expectedMin = beforeCall + DEFAULT_RETENTION_HOURS * 60 * 60 * 1000;
|
||||
const expectedMax = afterCall + DEFAULT_RETENTION_HOURS * 60 * 60 * 1000;
|
||||
|
||||
// Result should be between expectedMin and expectedMax
|
||||
expect(result.getTime()).toBeGreaterThanOrEqual(expectedMin);
|
||||
expect(result.getTime()).toBeLessThanOrEqual(expectedMax);
|
||||
});
|
||||
|
||||
it('should create expiration date with custom retention period', () => {
|
||||
const config: Partial<AppConfig> = {
|
||||
interfaceConfig: {
|
||||
temporaryChatRetention: 12,
|
||||
},
|
||||
};
|
||||
|
||||
const beforeCall = Date.now();
|
||||
const result = createTempChatExpirationDate(config?.interfaceConfig);
|
||||
const afterCall = Date.now();
|
||||
|
||||
const expectedMin = beforeCall + 12 * 60 * 60 * 1000;
|
||||
const expectedMax = afterCall + 12 * 60 * 60 * 1000;
|
||||
|
||||
// Result should be between expectedMin and expectedMax
|
||||
expect(result.getTime()).toBeGreaterThanOrEqual(expectedMin);
|
||||
expect(result.getTime()).toBeLessThanOrEqual(expectedMax);
|
||||
});
|
||||
|
||||
it('should return a Date object', () => {
|
||||
const result = createTempChatExpirationDate();
|
||||
expect(result).toBeInstanceOf(Date);
|
||||
});
|
||||
|
||||
it('should return a future date', () => {
|
||||
const now = new Date();
|
||||
const result = createTempChatExpirationDate();
|
||||
expect(result.getTime()).toBeGreaterThan(now.getTime());
|
||||
});
|
||||
});
|
||||
});
|
||||
|
|
@ -1,77 +0,0 @@
|
|||
import { logger } from '@librechat/data-schemas';
|
||||
import type { AppConfig } from '@librechat/data-schemas';
|
||||
|
||||
/**
|
||||
* Default retention period for temporary chats in hours
|
||||
*/
|
||||
export const DEFAULT_RETENTION_HOURS = 24 * 30; // 30 days
|
||||
|
||||
/**
|
||||
* Minimum allowed retention period in hours
|
||||
*/
|
||||
export const MIN_RETENTION_HOURS = 1;
|
||||
|
||||
/**
|
||||
* Maximum allowed retention period in hours (1 year = 8760 hours)
|
||||
*/
|
||||
export const MAX_RETENTION_HOURS = 8760;
|
||||
|
||||
/**
|
||||
* Gets the temporary chat retention period from environment variables or config
|
||||
* @param interfaceConfig - The custom configuration object
|
||||
* @returns The retention period in hours
|
||||
*/
|
||||
export function getTempChatRetentionHours(
|
||||
interfaceConfig?: AppConfig['interfaceConfig'] | null,
|
||||
): number {
|
||||
let retentionHours = DEFAULT_RETENTION_HOURS;
|
||||
|
||||
// Check environment variable first
|
||||
if (process.env.TEMP_CHAT_RETENTION_HOURS) {
|
||||
const envValue = parseInt(process.env.TEMP_CHAT_RETENTION_HOURS, 10);
|
||||
if (!isNaN(envValue)) {
|
||||
retentionHours = envValue;
|
||||
} else {
|
||||
logger.warn(
|
||||
`Invalid TEMP_CHAT_RETENTION_HOURS environment variable: ${process.env.TEMP_CHAT_RETENTION_HOURS}. Using default: ${DEFAULT_RETENTION_HOURS} hours.`,
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
// Check config file (takes precedence over environment variable)
|
||||
if (interfaceConfig?.temporaryChatRetention !== undefined) {
|
||||
const configValue = interfaceConfig.temporaryChatRetention;
|
||||
if (typeof configValue === 'number' && !isNaN(configValue)) {
|
||||
retentionHours = configValue;
|
||||
} else {
|
||||
logger.warn(
|
||||
`Invalid temporaryChatRetention in config: ${configValue}. Using ${retentionHours} hours.`,
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
// Validate the retention period
|
||||
if (retentionHours < MIN_RETENTION_HOURS) {
|
||||
logger.warn(
|
||||
`Temporary chat retention period ${retentionHours} is below minimum ${MIN_RETENTION_HOURS} hours. Using minimum value.`,
|
||||
);
|
||||
retentionHours = MIN_RETENTION_HOURS;
|
||||
} else if (retentionHours > MAX_RETENTION_HOURS) {
|
||||
logger.warn(
|
||||
`Temporary chat retention period ${retentionHours} exceeds maximum ${MAX_RETENTION_HOURS} hours. Using maximum value.`,
|
||||
);
|
||||
retentionHours = MAX_RETENTION_HOURS;
|
||||
}
|
||||
|
||||
return retentionHours;
|
||||
}
|
||||
|
||||
/**
|
||||
* Creates an expiration date for temporary chats
|
||||
* @param interfaceConfig - The custom configuration object
|
||||
* @returns The expiration date
|
||||
*/
|
||||
export function createTempChatExpirationDate(interfaceConfig?: AppConfig['interfaceConfig']): Date {
|
||||
const retentionHours = getTempChatRetentionHours(interfaceConfig);
|
||||
return new Date(Date.now() + retentionHours * 60 * 60 * 1000);
|
||||
}
|
||||
Loading…
Add table
Add a link
Reference in a new issue