mirror of
https://github.com/danny-avila/LibreChat.git
synced 2026-02-19 00:48:08 +01:00
📦 refactor: Consolidate DB models, encapsulating Mongoose usage in data-schemas (#11830)
* chore: move database model methods to /packages/data-schemas * chore: add TypeScript ESLint rule to warn on unused variables * refactor: model imports to streamline access - Consolidated model imports across various files to improve code organization and reduce redundancy. - Updated imports for models such as Assistant, Message, Conversation, and others to a unified import path. - Adjusted middleware and service files to reflect the new import structure, ensuring functionality remains intact. - Enhanced test files to align with the new import paths, maintaining test coverage and integrity. * chore: migrate database models to packages/data-schemas and refactor all direct Mongoose Model usage outside of data-schemas * test: update agent model mocks in unit tests - Added `getAgent` mock to `client.test.js` to enhance test coverage for agent-related functionality. - Removed redundant `getAgent` and `getAgents` mocks from `openai.spec.js` and `responses.unit.spec.js` to streamline test setup and reduce duplication. - Ensured consistency in agent mock implementations across test files. * fix: update types in data-schemas * refactor: enhance type definitions in transaction and spending methods - Updated type definitions in `checkBalance.ts` to use specific request and response types. - Refined `spendTokens.ts` to utilize a new `SpendTxData` interface for better clarity and type safety. - Improved transaction handling in `transaction.ts` by introducing `TransactionResult` and `TxData` interfaces, ensuring consistent data structures across methods. - Adjusted unit tests in `transaction.spec.ts` to accommodate new type definitions and enhance robustness. * refactor: streamline model imports and enhance code organization - Consolidated model imports across various controllers and services to a unified import path, improving code clarity and reducing redundancy. - Updated multiple files to reflect the new import structure, ensuring all functionalities remain intact. - Enhanced overall code organization by removing duplicate import statements and optimizing the usage of model methods. * feat: implement loadAddedAgent and refactor agent loading logic - Introduced `loadAddedAgent` function to handle loading agents from added conversations, supporting multi-convo parallel execution. - Created a new `load.ts` file to encapsulate agent loading functionalities, including `loadEphemeralAgent` and `loadAgent`. - Updated the `index.ts` file to export the new `load` module instead of the deprecated `loadAgent`. - Enhanced type definitions and improved error handling in the agent loading process. - Adjusted unit tests to reflect changes in the agent loading structure and ensure comprehensive coverage. * refactor: enhance balance handling with new update interface - Introduced `IBalanceUpdate` interface to streamline balance update operations across the codebase. - Updated `upsertBalanceFields` method signatures in `balance.ts`, `transaction.ts`, and related tests to utilize the new interface for improved type safety. - Adjusted type imports in `balance.spec.ts` to include `IBalanceUpdate`, ensuring consistency in balance management functionalities. - Enhanced overall code clarity and maintainability by refining type definitions related to balance operations. * feat: add unit tests for loadAgent functionality and enhance agent loading logic - Introduced comprehensive unit tests for the `loadAgent` function, covering various scenarios including null and empty agent IDs, loading of ephemeral agents, and permission checks. - Enhanced the `initializeClient` function by moving `getConvoFiles` to the correct position in the database method exports, ensuring proper functionality. - Improved test coverage for agent loading, including handling of non-existent agents and user permissions. * chore: reorder memory method exports for consistency - Moved `deleteAllUserMemories` to the correct position in the exported memory methods, ensuring a consistent and logical order of method exports in `memory.ts`.
This commit is contained in:
parent
a85e99ff45
commit
a6fb257bcf
182 changed files with 8548 additions and 8105 deletions
397
packages/api/src/agents/__tests__/load.spec.ts
Normal file
397
packages/api/src/agents/__tests__/load.spec.ts
Normal file
|
|
@ -0,0 +1,397 @@
|
|||
import mongoose from 'mongoose';
|
||||
import { v4 as uuidv4 } from 'uuid';
|
||||
import { Constants } from 'librechat-data-provider';
|
||||
import { MongoMemoryServer } from 'mongodb-memory-server';
|
||||
import { agentSchema, createMethods } from '@librechat/data-schemas';
|
||||
import type { AgentModelParameters } from 'librechat-data-provider';
|
||||
import type { LoadAgentParams, LoadAgentDeps } from '../load';
|
||||
import { loadAgent } from '../load';
|
||||
|
||||
let Agent: mongoose.Model<unknown>;
|
||||
let createAgent: ReturnType<typeof createMethods>['createAgent'];
|
||||
let getAgent: ReturnType<typeof createMethods>['getAgent'];
|
||||
|
||||
const mockGetMCPServerTools = jest.fn();
|
||||
|
||||
const deps: LoadAgentDeps = {
|
||||
getAgent: (searchParameter) => getAgent(searchParameter),
|
||||
getMCPServerTools: mockGetMCPServerTools,
|
||||
};
|
||||
|
||||
describe('loadAgent', () => {
|
||||
let mongoServer: MongoMemoryServer;
|
||||
|
||||
beforeAll(async () => {
|
||||
mongoServer = await MongoMemoryServer.create();
|
||||
const mongoUri = mongoServer.getUri();
|
||||
Agent = mongoose.models.Agent || mongoose.model('Agent', agentSchema);
|
||||
await mongoose.connect(mongoUri);
|
||||
const methods = createMethods(mongoose);
|
||||
createAgent = methods.createAgent;
|
||||
getAgent = methods.getAgent;
|
||||
}, 20000);
|
||||
|
||||
afterAll(async () => {
|
||||
await mongoose.disconnect();
|
||||
await mongoServer.stop();
|
||||
});
|
||||
|
||||
beforeEach(async () => {
|
||||
await Agent.deleteMany({});
|
||||
jest.clearAllMocks();
|
||||
});
|
||||
|
||||
test('should return null when agent_id is not provided', async () => {
|
||||
const mockReq = { user: { id: 'user123' } };
|
||||
const result = await loadAgent(
|
||||
{
|
||||
req: mockReq,
|
||||
agent_id: null as unknown as string,
|
||||
endpoint: 'openai',
|
||||
model_parameters: { model: 'gpt-4' } as unknown as AgentModelParameters,
|
||||
},
|
||||
deps,
|
||||
);
|
||||
|
||||
expect(result).toBeNull();
|
||||
});
|
||||
|
||||
test('should return null when agent_id is empty string', async () => {
|
||||
const mockReq = { user: { id: 'user123' } };
|
||||
const result = await loadAgent(
|
||||
{
|
||||
req: mockReq,
|
||||
agent_id: '',
|
||||
endpoint: 'openai',
|
||||
model_parameters: { model: 'gpt-4' } as unknown as AgentModelParameters,
|
||||
},
|
||||
deps,
|
||||
);
|
||||
|
||||
expect(result).toBeNull();
|
||||
});
|
||||
|
||||
test('should test ephemeral agent loading logic', async () => {
|
||||
const { EPHEMERAL_AGENT_ID } = Constants;
|
||||
|
||||
// Mock getMCPServerTools to return tools for each server
|
||||
mockGetMCPServerTools.mockImplementation(async (_userId: string, server: string) => {
|
||||
if (server === 'server1') {
|
||||
return { tool1_mcp_server1: {} };
|
||||
} else if (server === 'server2') {
|
||||
return { tool2_mcp_server2: {} };
|
||||
}
|
||||
return null;
|
||||
});
|
||||
|
||||
const mockReq = {
|
||||
user: { id: 'user123' },
|
||||
body: {
|
||||
promptPrefix: 'Test instructions',
|
||||
ephemeralAgent: {
|
||||
execute_code: true,
|
||||
web_search: true,
|
||||
mcp: ['server1', 'server2'],
|
||||
},
|
||||
},
|
||||
};
|
||||
|
||||
const result = await loadAgent(
|
||||
{
|
||||
req: mockReq,
|
||||
agent_id: EPHEMERAL_AGENT_ID as string,
|
||||
endpoint: 'openai',
|
||||
model_parameters: { model: 'gpt-4', temperature: 0.7 } as unknown as AgentModelParameters,
|
||||
},
|
||||
deps,
|
||||
);
|
||||
|
||||
if (result) {
|
||||
// Ephemeral agent ID is encoded with endpoint and model
|
||||
expect(result.id).toBe('openai__gpt-4');
|
||||
expect(result.instructions).toBe('Test instructions');
|
||||
expect(result.provider).toBe('openai');
|
||||
expect(result.model).toBe('gpt-4');
|
||||
expect(result.model_parameters.temperature).toBe(0.7);
|
||||
expect(result.tools).toContain('execute_code');
|
||||
expect(result.tools).toContain('web_search');
|
||||
expect(result.tools).toContain('tool1_mcp_server1');
|
||||
expect(result.tools).toContain('tool2_mcp_server2');
|
||||
} else {
|
||||
expect(result).toBeNull();
|
||||
}
|
||||
});
|
||||
|
||||
test('should return null for non-existent agent', async () => {
|
||||
const mockReq = { user: { id: 'user123' } };
|
||||
const result = await loadAgent(
|
||||
{
|
||||
req: mockReq,
|
||||
agent_id: 'agent_non_existent',
|
||||
endpoint: 'openai',
|
||||
model_parameters: { model: 'gpt-4' } as unknown as AgentModelParameters,
|
||||
},
|
||||
deps,
|
||||
);
|
||||
|
||||
expect(result).toBeNull();
|
||||
});
|
||||
|
||||
test('should load agent when user is the author', async () => {
|
||||
const userId = new mongoose.Types.ObjectId();
|
||||
const agentId = `agent_${uuidv4()}`;
|
||||
|
||||
await createAgent({
|
||||
id: agentId,
|
||||
name: 'Test Agent',
|
||||
provider: 'openai',
|
||||
model: 'gpt-4',
|
||||
author: userId,
|
||||
description: 'Test description',
|
||||
tools: ['web_search'],
|
||||
});
|
||||
|
||||
const mockReq = { user: { id: userId.toString() } };
|
||||
const result = await loadAgent(
|
||||
{
|
||||
req: mockReq,
|
||||
agent_id: agentId,
|
||||
endpoint: 'openai',
|
||||
model_parameters: { model: 'gpt-4' } as unknown as AgentModelParameters,
|
||||
},
|
||||
deps,
|
||||
);
|
||||
|
||||
expect(result).toBeDefined();
|
||||
expect(result!.id).toBe(agentId);
|
||||
expect(result!.name).toBe('Test Agent');
|
||||
expect(String(result!.author)).toBe(userId.toString());
|
||||
expect(result!.version).toBe(1);
|
||||
});
|
||||
|
||||
test('should return agent even when user is not author (permissions checked at route level)', async () => {
|
||||
const authorId = new mongoose.Types.ObjectId();
|
||||
const userId = new mongoose.Types.ObjectId();
|
||||
const agentId = `agent_${uuidv4()}`;
|
||||
|
||||
await createAgent({
|
||||
id: agentId,
|
||||
name: 'Test Agent',
|
||||
provider: 'openai',
|
||||
model: 'gpt-4',
|
||||
author: authorId,
|
||||
});
|
||||
|
||||
const mockReq = { user: { id: userId.toString() } };
|
||||
const result = await loadAgent(
|
||||
{
|
||||
req: mockReq,
|
||||
agent_id: agentId,
|
||||
endpoint: 'openai',
|
||||
model_parameters: { model: 'gpt-4' } as unknown as AgentModelParameters,
|
||||
},
|
||||
deps,
|
||||
);
|
||||
|
||||
// With the new permission system, loadAgent returns the agent regardless of permissions
|
||||
// Permission checks are handled at the route level via middleware
|
||||
expect(result).toBeTruthy();
|
||||
expect(result!.id).toBe(agentId);
|
||||
expect(result!.name).toBe('Test Agent');
|
||||
});
|
||||
|
||||
test('should handle ephemeral agent with no MCP servers', async () => {
|
||||
const { EPHEMERAL_AGENT_ID } = Constants;
|
||||
|
||||
const mockReq = {
|
||||
user: { id: 'user123' },
|
||||
body: {
|
||||
promptPrefix: 'Simple instructions',
|
||||
ephemeralAgent: {
|
||||
execute_code: false,
|
||||
web_search: false,
|
||||
mcp: [],
|
||||
},
|
||||
},
|
||||
};
|
||||
|
||||
const result = await loadAgent(
|
||||
{
|
||||
req: mockReq,
|
||||
agent_id: EPHEMERAL_AGENT_ID as string,
|
||||
endpoint: 'openai',
|
||||
model_parameters: { model: 'gpt-3.5-turbo' } as unknown as AgentModelParameters,
|
||||
},
|
||||
deps,
|
||||
);
|
||||
|
||||
if (result) {
|
||||
expect(result.tools).toEqual([]);
|
||||
expect(result.instructions).toBe('Simple instructions');
|
||||
} else {
|
||||
expect(result).toBeFalsy();
|
||||
}
|
||||
});
|
||||
|
||||
test('should handle ephemeral agent with undefined ephemeralAgent in body', async () => {
|
||||
const { EPHEMERAL_AGENT_ID } = Constants;
|
||||
|
||||
const mockReq = {
|
||||
user: { id: 'user123' },
|
||||
body: {
|
||||
promptPrefix: 'Basic instructions',
|
||||
},
|
||||
};
|
||||
|
||||
const result = await loadAgent(
|
||||
{
|
||||
req: mockReq,
|
||||
agent_id: EPHEMERAL_AGENT_ID as string,
|
||||
endpoint: 'openai',
|
||||
model_parameters: { model: 'gpt-4' } as unknown as AgentModelParameters,
|
||||
},
|
||||
deps,
|
||||
);
|
||||
|
||||
if (result) {
|
||||
expect(result.tools).toEqual([]);
|
||||
} else {
|
||||
expect(result).toBeFalsy();
|
||||
}
|
||||
});
|
||||
|
||||
describe('Edge Cases', () => {
|
||||
test('should handle loadAgent with malformed req object', async () => {
|
||||
const result = await loadAgent(
|
||||
{
|
||||
req: null as unknown as LoadAgentParams['req'],
|
||||
agent_id: 'agent_test',
|
||||
endpoint: 'openai',
|
||||
model_parameters: { model: 'gpt-4' } as unknown as AgentModelParameters,
|
||||
},
|
||||
deps,
|
||||
);
|
||||
|
||||
expect(result).toBeNull();
|
||||
});
|
||||
|
||||
test('should handle ephemeral agent with extremely large tool list', async () => {
|
||||
const { EPHEMERAL_AGENT_ID } = Constants;
|
||||
|
||||
const largeToolList = Array.from({ length: 100 }, (_, i) => `tool_${i}_mcp_server1`);
|
||||
const availableTools: Record<string, object> = {};
|
||||
for (const tool of largeToolList) {
|
||||
availableTools[tool] = {};
|
||||
}
|
||||
|
||||
// Mock getMCPServerTools to return all tools for server1
|
||||
mockGetMCPServerTools.mockImplementation(async (_userId: string, server: string) => {
|
||||
if (server === 'server1') {
|
||||
return availableTools; // All 100 tools belong to server1
|
||||
}
|
||||
return null;
|
||||
});
|
||||
|
||||
const mockReq = {
|
||||
user: { id: 'user123' },
|
||||
body: {
|
||||
promptPrefix: 'Test',
|
||||
ephemeralAgent: {
|
||||
execute_code: true,
|
||||
web_search: true,
|
||||
mcp: ['server1'],
|
||||
},
|
||||
},
|
||||
};
|
||||
|
||||
const result = await loadAgent(
|
||||
{
|
||||
req: mockReq,
|
||||
agent_id: EPHEMERAL_AGENT_ID as string,
|
||||
endpoint: 'openai',
|
||||
model_parameters: { model: 'gpt-4' } as unknown as AgentModelParameters,
|
||||
},
|
||||
deps,
|
||||
);
|
||||
|
||||
if (result) {
|
||||
expect(result.tools!.length).toBeGreaterThan(100);
|
||||
}
|
||||
});
|
||||
|
||||
test('should return agent from different project (permissions checked at route level)', async () => {
|
||||
const authorId = new mongoose.Types.ObjectId();
|
||||
const userId = new mongoose.Types.ObjectId();
|
||||
const agentId = `agent_${uuidv4()}`;
|
||||
|
||||
await createAgent({
|
||||
id: agentId,
|
||||
name: 'Project Agent',
|
||||
provider: 'openai',
|
||||
model: 'gpt-4',
|
||||
author: authorId,
|
||||
});
|
||||
|
||||
const mockReq = { user: { id: userId.toString() } };
|
||||
const result = await loadAgent(
|
||||
{
|
||||
req: mockReq,
|
||||
agent_id: agentId,
|
||||
endpoint: 'openai',
|
||||
model_parameters: { model: 'gpt-4' } as unknown as AgentModelParameters,
|
||||
},
|
||||
deps,
|
||||
);
|
||||
|
||||
// With the new permission system, loadAgent returns the agent regardless of permissions
|
||||
// Permission checks are handled at the route level via middleware
|
||||
expect(result).toBeTruthy();
|
||||
expect(result!.id).toBe(agentId);
|
||||
expect(result!.name).toBe('Project Agent');
|
||||
});
|
||||
|
||||
test('should handle loadEphemeralAgent with malformed MCP tool names', async () => {
|
||||
const { EPHEMERAL_AGENT_ID } = Constants;
|
||||
|
||||
// Mock getMCPServerTools to return only tools matching the server
|
||||
mockGetMCPServerTools.mockImplementation(async (_userId: string, server: string) => {
|
||||
if (server === 'server1') {
|
||||
// Only return tool that correctly matches server1 format
|
||||
return { tool_mcp_server1: {} };
|
||||
} else if (server === 'server2') {
|
||||
return { tool_mcp_server2: {} };
|
||||
}
|
||||
return null;
|
||||
});
|
||||
|
||||
const mockReq = {
|
||||
user: { id: 'user123' },
|
||||
body: {
|
||||
promptPrefix: 'Test instructions',
|
||||
ephemeralAgent: {
|
||||
execute_code: false,
|
||||
web_search: false,
|
||||
mcp: ['server1'],
|
||||
},
|
||||
},
|
||||
};
|
||||
|
||||
const result = await loadAgent(
|
||||
{
|
||||
req: mockReq,
|
||||
agent_id: EPHEMERAL_AGENT_ID as string,
|
||||
endpoint: 'openai',
|
||||
model_parameters: { model: 'gpt-4' } as unknown as AgentModelParameters,
|
||||
},
|
||||
deps,
|
||||
);
|
||||
|
||||
if (result) {
|
||||
expect(result.tools).toEqual(['tool_mcp_server1']);
|
||||
expect(result.tools).not.toContain('malformed_tool_name');
|
||||
expect(result.tools).not.toContain('tool__server1');
|
||||
expect(result.tools).not.toContain('tool_mcp_server2');
|
||||
}
|
||||
});
|
||||
});
|
||||
});
|
||||
226
packages/api/src/agents/added.ts
Normal file
226
packages/api/src/agents/added.ts
Normal file
|
|
@ -0,0 +1,226 @@
|
|||
import { logger } from '@librechat/data-schemas';
|
||||
import type { AppConfig } from '@librechat/data-schemas';
|
||||
import {
|
||||
Tools,
|
||||
Constants,
|
||||
isAgentsEndpoint,
|
||||
isEphemeralAgentId,
|
||||
appendAgentIdSuffix,
|
||||
encodeEphemeralAgentId,
|
||||
} from 'librechat-data-provider';
|
||||
import type { Agent, TConversation } from 'librechat-data-provider';
|
||||
import { getCustomEndpointConfig } from '~/app/config';
|
||||
|
||||
const { mcp_all, mcp_delimiter } = Constants;
|
||||
|
||||
export const ADDED_AGENT_ID = 'added_agent';
|
||||
|
||||
export interface LoadAddedAgentDeps {
|
||||
getAgent: (searchParameter: { id: string }) => Promise<Agent | null>;
|
||||
getMCPServerTools: (
|
||||
userId: string,
|
||||
serverName: string,
|
||||
) => Promise<Record<string, unknown> | null>;
|
||||
}
|
||||
|
||||
interface LoadAddedAgentParams {
|
||||
req: { user?: { id?: string }; config?: Record<string, unknown> };
|
||||
conversation: TConversation | null;
|
||||
primaryAgent?: Agent | null;
|
||||
}
|
||||
|
||||
/**
|
||||
* Loads an agent from an added conversation (for multi-convo parallel agent execution).
|
||||
* Returns the agent config as a plain object, or null if invalid.
|
||||
*/
|
||||
export async function loadAddedAgent(
|
||||
{ req, conversation, primaryAgent }: LoadAddedAgentParams,
|
||||
deps: LoadAddedAgentDeps,
|
||||
): Promise<Agent | null> {
|
||||
if (!conversation) {
|
||||
return null;
|
||||
}
|
||||
|
||||
if (conversation.agent_id && !isEphemeralAgentId(conversation.agent_id)) {
|
||||
const agent = await deps.getAgent({ id: conversation.agent_id });
|
||||
if (!agent) {
|
||||
logger.warn(`[loadAddedAgent] Agent ${conversation.agent_id} not found`);
|
||||
return null;
|
||||
}
|
||||
|
||||
const agentRecord = agent as Record<string, unknown>;
|
||||
const versions = agentRecord.versions as unknown[] | undefined;
|
||||
agentRecord.version = versions ? versions.length : 0;
|
||||
agent.id = appendAgentIdSuffix(agent.id, 1);
|
||||
return agent;
|
||||
}
|
||||
|
||||
const { model, endpoint, promptPrefix, spec, ...rest } = conversation as TConversation & {
|
||||
promptPrefix?: string;
|
||||
spec?: string;
|
||||
modelLabel?: string;
|
||||
ephemeralAgent?: {
|
||||
mcp?: string[];
|
||||
execute_code?: boolean;
|
||||
file_search?: boolean;
|
||||
web_search?: boolean;
|
||||
artifacts?: unknown;
|
||||
};
|
||||
[key: string]: unknown;
|
||||
};
|
||||
|
||||
if (!endpoint || !model) {
|
||||
logger.warn('[loadAddedAgent] Missing required endpoint or model for ephemeral agent');
|
||||
return null;
|
||||
}
|
||||
|
||||
const appConfig = req.config as AppConfig | undefined;
|
||||
|
||||
const primaryIsEphemeral = primaryAgent && isEphemeralAgentId(primaryAgent.id);
|
||||
if (primaryIsEphemeral && Array.isArray(primaryAgent.tools)) {
|
||||
let endpointConfig = (appConfig?.endpoints as Record<string, unknown> | undefined)?.[
|
||||
endpoint
|
||||
] as Record<string, unknown> | undefined;
|
||||
if (!isAgentsEndpoint(endpoint) && !endpointConfig) {
|
||||
try {
|
||||
endpointConfig = getCustomEndpointConfig({ endpoint, appConfig }) as
|
||||
| Record<string, unknown>
|
||||
| undefined;
|
||||
} catch (err) {
|
||||
logger.error('[loadAddedAgent] Error getting custom endpoint config', err);
|
||||
}
|
||||
}
|
||||
|
||||
const modelSpecs = (appConfig?.modelSpecs as { list?: Array<{ name: string; label?: string }> })
|
||||
?.list;
|
||||
const modelSpec = spec != null && spec !== '' ? modelSpecs?.find((s) => s.name === spec) : null;
|
||||
const sender =
|
||||
rest.modelLabel ??
|
||||
modelSpec?.label ??
|
||||
(endpointConfig?.modelDisplayLabel as string | undefined) ??
|
||||
'';
|
||||
const ephemeralId = encodeEphemeralAgentId({ endpoint, model, sender, index: 1 });
|
||||
|
||||
return {
|
||||
id: ephemeralId,
|
||||
instructions: promptPrefix || '',
|
||||
provider: endpoint,
|
||||
model_parameters: {},
|
||||
model,
|
||||
tools: [...primaryAgent.tools],
|
||||
} as unknown as Agent;
|
||||
}
|
||||
|
||||
const ephemeralAgent = rest.ephemeralAgent as
|
||||
| {
|
||||
mcp?: string[];
|
||||
execute_code?: boolean;
|
||||
file_search?: boolean;
|
||||
web_search?: boolean;
|
||||
artifacts?: unknown;
|
||||
}
|
||||
| undefined;
|
||||
const mcpServers = new Set<string>(ephemeralAgent?.mcp);
|
||||
const userId = req.user?.id ?? '';
|
||||
|
||||
const modelSpecs = (
|
||||
appConfig?.modelSpecs as {
|
||||
list?: Array<{
|
||||
name: string;
|
||||
label?: string;
|
||||
mcpServers?: string[];
|
||||
executeCode?: boolean;
|
||||
fileSearch?: boolean;
|
||||
webSearch?: boolean;
|
||||
}>;
|
||||
}
|
||||
)?.list;
|
||||
let modelSpec: (typeof modelSpecs extends Array<infer T> | undefined ? T : never) | null = null;
|
||||
if (spec != null && spec !== '') {
|
||||
modelSpec = modelSpecs?.find((s) => s.name === spec) ?? null;
|
||||
}
|
||||
if (modelSpec?.mcpServers) {
|
||||
for (const mcpServer of modelSpec.mcpServers) {
|
||||
mcpServers.add(mcpServer);
|
||||
}
|
||||
}
|
||||
|
||||
const tools: string[] = [];
|
||||
if (ephemeralAgent?.execute_code === true || modelSpec?.executeCode === true) {
|
||||
tools.push(Tools.execute_code);
|
||||
}
|
||||
if (ephemeralAgent?.file_search === true || modelSpec?.fileSearch === true) {
|
||||
tools.push(Tools.file_search);
|
||||
}
|
||||
if (ephemeralAgent?.web_search === true || modelSpec?.webSearch === true) {
|
||||
tools.push(Tools.web_search);
|
||||
}
|
||||
|
||||
const addedServers = new Set<string>();
|
||||
for (const mcpServer of mcpServers) {
|
||||
if (addedServers.has(mcpServer)) {
|
||||
continue;
|
||||
}
|
||||
const serverTools = await deps.getMCPServerTools(userId, mcpServer);
|
||||
if (!serverTools) {
|
||||
tools.push(`${mcp_all}${mcp_delimiter}${mcpServer}`);
|
||||
addedServers.add(mcpServer);
|
||||
continue;
|
||||
}
|
||||
tools.push(...Object.keys(serverTools));
|
||||
addedServers.add(mcpServer);
|
||||
}
|
||||
|
||||
const model_parameters: Record<string, unknown> = {};
|
||||
const paramKeys = [
|
||||
'temperature',
|
||||
'top_p',
|
||||
'topP',
|
||||
'topK',
|
||||
'presence_penalty',
|
||||
'frequency_penalty',
|
||||
'maxOutputTokens',
|
||||
'maxTokens',
|
||||
'max_tokens',
|
||||
];
|
||||
for (const key of paramKeys) {
|
||||
if ((rest as Record<string, unknown>)[key] != null) {
|
||||
model_parameters[key] = (rest as Record<string, unknown>)[key];
|
||||
}
|
||||
}
|
||||
|
||||
let endpointConfig = (appConfig?.endpoints as Record<string, unknown> | undefined)?.[endpoint] as
|
||||
| Record<string, unknown>
|
||||
| undefined;
|
||||
if (!isAgentsEndpoint(endpoint) && !endpointConfig) {
|
||||
try {
|
||||
endpointConfig = getCustomEndpointConfig({ endpoint, appConfig }) as
|
||||
| Record<string, unknown>
|
||||
| undefined;
|
||||
} catch (err) {
|
||||
logger.error('[loadAddedAgent] Error getting custom endpoint config', err);
|
||||
}
|
||||
}
|
||||
|
||||
const sender =
|
||||
rest.modelLabel ??
|
||||
modelSpec?.label ??
|
||||
(endpointConfig?.modelDisplayLabel as string | undefined) ??
|
||||
'';
|
||||
const ephemeralId = encodeEphemeralAgentId({ endpoint, model, sender, index: 1 });
|
||||
|
||||
const result: Record<string, unknown> = {
|
||||
id: ephemeralId,
|
||||
instructions: promptPrefix || '',
|
||||
provider: endpoint,
|
||||
model_parameters,
|
||||
model,
|
||||
tools,
|
||||
};
|
||||
|
||||
if (ephemeralAgent?.artifacts != null && ephemeralAgent.artifacts) {
|
||||
result.artifacts = ephemeralAgent.artifacts;
|
||||
}
|
||||
|
||||
return result as unknown as Agent;
|
||||
}
|
||||
|
|
@ -15,3 +15,5 @@ export * from './responses';
|
|||
export * from './run';
|
||||
export * from './tools';
|
||||
export * from './validation';
|
||||
export * from './added';
|
||||
export * from './load';
|
||||
|
|
|
|||
162
packages/api/src/agents/load.ts
Normal file
162
packages/api/src/agents/load.ts
Normal file
|
|
@ -0,0 +1,162 @@
|
|||
import { logger } from '@librechat/data-schemas';
|
||||
import type { AppConfig } from '@librechat/data-schemas';
|
||||
import {
|
||||
Tools,
|
||||
Constants,
|
||||
isAgentsEndpoint,
|
||||
isEphemeralAgentId,
|
||||
encodeEphemeralAgentId,
|
||||
} from 'librechat-data-provider';
|
||||
import type {
|
||||
AgentModelParameters,
|
||||
TEphemeralAgent,
|
||||
TModelSpec,
|
||||
Agent,
|
||||
} from 'librechat-data-provider';
|
||||
import { getCustomEndpointConfig } from '~/app/config';
|
||||
|
||||
const { mcp_all, mcp_delimiter } = Constants;
|
||||
|
||||
export interface LoadAgentDeps {
|
||||
getAgent: (searchParameter: { id: string }) => Promise<Agent | null>;
|
||||
getMCPServerTools: (
|
||||
userId: string,
|
||||
serverName: string,
|
||||
) => Promise<Record<string, unknown> | null>;
|
||||
}
|
||||
|
||||
export interface LoadAgentParams {
|
||||
req: {
|
||||
user?: { id?: string };
|
||||
config?: AppConfig;
|
||||
body?: {
|
||||
promptPrefix?: string;
|
||||
ephemeralAgent?: TEphemeralAgent;
|
||||
};
|
||||
};
|
||||
spec?: string;
|
||||
agent_id: string;
|
||||
endpoint: string;
|
||||
model_parameters?: AgentModelParameters & { model?: string };
|
||||
}
|
||||
|
||||
/**
|
||||
* Load an ephemeral agent based on the request parameters.
|
||||
*/
|
||||
export async function loadEphemeralAgent(
|
||||
{ req, spec, endpoint, model_parameters: _m }: Omit<LoadAgentParams, 'agent_id'>,
|
||||
deps: LoadAgentDeps,
|
||||
): Promise<Agent | null> {
|
||||
const { model, ...model_parameters } = _m ?? ({} as unknown as AgentModelParameters);
|
||||
const modelSpecs = req.config?.modelSpecs as { list?: TModelSpec[] } | undefined;
|
||||
let modelSpec: TModelSpec | null = null;
|
||||
if (spec != null && spec !== '') {
|
||||
modelSpec = modelSpecs?.list?.find((s) => s.name === spec) ?? null;
|
||||
}
|
||||
const ephemeralAgent: TEphemeralAgent | undefined = req.body?.ephemeralAgent;
|
||||
const mcpServers = new Set<string>(ephemeralAgent?.mcp);
|
||||
const userId = req.user?.id ?? '';
|
||||
if (modelSpec?.mcpServers) {
|
||||
for (const mcpServer of modelSpec.mcpServers) {
|
||||
mcpServers.add(mcpServer);
|
||||
}
|
||||
}
|
||||
const tools: string[] = [];
|
||||
if (ephemeralAgent?.execute_code === true || modelSpec?.executeCode === true) {
|
||||
tools.push(Tools.execute_code);
|
||||
}
|
||||
if (ephemeralAgent?.file_search === true || modelSpec?.fileSearch === true) {
|
||||
tools.push(Tools.file_search);
|
||||
}
|
||||
if (ephemeralAgent?.web_search === true || modelSpec?.webSearch === true) {
|
||||
tools.push(Tools.web_search);
|
||||
}
|
||||
|
||||
const addedServers = new Set<string>();
|
||||
if (mcpServers.size > 0) {
|
||||
for (const mcpServer of mcpServers) {
|
||||
if (addedServers.has(mcpServer)) {
|
||||
continue;
|
||||
}
|
||||
const serverTools = await deps.getMCPServerTools(userId, mcpServer);
|
||||
if (!serverTools) {
|
||||
tools.push(`${mcp_all}${mcp_delimiter}${mcpServer}`);
|
||||
addedServers.add(mcpServer);
|
||||
continue;
|
||||
}
|
||||
tools.push(...Object.keys(serverTools));
|
||||
addedServers.add(mcpServer);
|
||||
}
|
||||
}
|
||||
|
||||
const instructions = req.body?.promptPrefix;
|
||||
|
||||
// Get endpoint config for modelDisplayLabel fallback
|
||||
const appConfig = req.config;
|
||||
const endpoints = appConfig?.endpoints;
|
||||
let endpointConfig = endpoints?.[endpoint as keyof typeof endpoints];
|
||||
if (!isAgentsEndpoint(endpoint) && !endpointConfig) {
|
||||
try {
|
||||
endpointConfig = getCustomEndpointConfig({ endpoint, appConfig });
|
||||
} catch (err) {
|
||||
logger.error('[loadEphemeralAgent] Error getting custom endpoint config', err);
|
||||
}
|
||||
}
|
||||
|
||||
// For ephemeral agents, use modelLabel if provided, then model spec's label,
|
||||
// then modelDisplayLabel from endpoint config, otherwise empty string to show model name
|
||||
const sender =
|
||||
(model_parameters as AgentModelParameters & { modelLabel?: string })?.modelLabel ??
|
||||
modelSpec?.label ??
|
||||
(endpointConfig as { modelDisplayLabel?: string } | undefined)?.modelDisplayLabel ??
|
||||
'';
|
||||
|
||||
// Encode ephemeral agent ID with endpoint, model, and computed sender for display
|
||||
const ephemeralId = encodeEphemeralAgentId({
|
||||
endpoint,
|
||||
model: model as string,
|
||||
sender: sender as string,
|
||||
});
|
||||
|
||||
const result: Partial<Agent> = {
|
||||
id: ephemeralId,
|
||||
instructions,
|
||||
provider: endpoint,
|
||||
model_parameters,
|
||||
model,
|
||||
tools,
|
||||
};
|
||||
|
||||
if (ephemeralAgent?.artifacts) {
|
||||
result.artifacts = ephemeralAgent.artifacts;
|
||||
}
|
||||
return result as Agent;
|
||||
}
|
||||
|
||||
/**
|
||||
* Load an agent based on the provided ID.
|
||||
* For ephemeral agents, builds a synthetic agent from request parameters.
|
||||
* For persistent agents, fetches from the database.
|
||||
*/
|
||||
export async function loadAgent(
|
||||
params: LoadAgentParams,
|
||||
deps: LoadAgentDeps,
|
||||
): Promise<Agent | null> {
|
||||
const { req, spec, agent_id, endpoint, model_parameters } = params;
|
||||
if (!agent_id) {
|
||||
return null;
|
||||
}
|
||||
if (isEphemeralAgentId(agent_id)) {
|
||||
return loadEphemeralAgent({ req, spec, endpoint, model_parameters }, deps);
|
||||
}
|
||||
const agent = await deps.getAgent({ id: agent_id });
|
||||
|
||||
if (!agent) {
|
||||
return null;
|
||||
}
|
||||
|
||||
// Set version count from versions array length
|
||||
const agentWithVersion = agent as Agent & { versions?: unknown[]; version?: number };
|
||||
agentWithVersion.version = agentWithVersion.versions ? agentWithVersion.versions.length : 0;
|
||||
return agent;
|
||||
}
|
||||
|
|
@ -1,10 +1,11 @@
|
|||
import { Types } from 'mongoose';
|
||||
import {
|
||||
ResourceType,
|
||||
PrincipalType,
|
||||
PermissionBits,
|
||||
AccessRoleIds,
|
||||
} from 'librechat-data-provider';
|
||||
import type { Types, Model } from 'mongoose';
|
||||
import type { PipelineStage, AnyBulkWriteOperation } from 'mongoose';
|
||||
|
||||
export interface Principal {
|
||||
type: string;
|
||||
|
|
@ -19,20 +20,14 @@ export interface Principal {
|
|||
}
|
||||
|
||||
export interface EnricherDependencies {
|
||||
AclEntry: Model<{
|
||||
principalType: string;
|
||||
principalId: Types.ObjectId;
|
||||
resourceType: string;
|
||||
resourceId: Types.ObjectId;
|
||||
permBits: number;
|
||||
roleId: Types.ObjectId;
|
||||
grantedBy: Types.ObjectId;
|
||||
grantedAt: Date;
|
||||
}>;
|
||||
AccessRole: Model<{
|
||||
accessRoleId: string;
|
||||
permBits: number;
|
||||
}>;
|
||||
aggregateAclEntries: (pipeline: PipelineStage[]) => Promise<Record<string, unknown>[]>;
|
||||
bulkWriteAclEntries: (
|
||||
ops: AnyBulkWriteOperation<unknown>[],
|
||||
options?: Record<string, unknown>,
|
||||
) => Promise<unknown>;
|
||||
findRoleByIdentifier: (
|
||||
accessRoleId: string,
|
||||
) => Promise<{ _id: Types.ObjectId; permBits: number } | null>;
|
||||
logger: { error: (msg: string, ...args: unknown[]) => void };
|
||||
}
|
||||
|
||||
|
|
@ -47,14 +42,12 @@ export async function enrichRemoteAgentPrincipals(
|
|||
resourceId: string | Types.ObjectId,
|
||||
principals: Principal[],
|
||||
): Promise<EnrichResult> {
|
||||
const { AclEntry } = deps;
|
||||
|
||||
const resourceObjectId =
|
||||
typeof resourceId === 'string' && /^[a-f\d]{24}$/i.test(resourceId)
|
||||
? deps.AclEntry.base.Types.ObjectId.createFromHexString(resourceId)
|
||||
? Types.ObjectId.createFromHexString(resourceId)
|
||||
: resourceId;
|
||||
|
||||
const agentOwnerEntries = await AclEntry.aggregate([
|
||||
const agentOwnerEntries = await deps.aggregateAclEntries([
|
||||
{
|
||||
$match: {
|
||||
resourceType: ResourceType.AGENT,
|
||||
|
|
@ -87,24 +80,28 @@ export async function enrichRemoteAgentPrincipals(
|
|||
continue;
|
||||
}
|
||||
|
||||
const userInfo = entry.userInfo as Record<string, unknown>;
|
||||
const principalId = entry.principalId as Types.ObjectId;
|
||||
|
||||
const alreadyIncluded = enrichedPrincipals.some(
|
||||
(p) => p.type === PrincipalType.USER && p.id === entry.principalId.toString(),
|
||||
(p) => p.type === PrincipalType.USER && p.id === principalId.toString(),
|
||||
);
|
||||
|
||||
if (!alreadyIncluded) {
|
||||
enrichedPrincipals.unshift({
|
||||
type: PrincipalType.USER,
|
||||
id: entry.userInfo._id.toString(),
|
||||
name: entry.userInfo.name || entry.userInfo.username,
|
||||
email: entry.userInfo.email,
|
||||
avatar: entry.userInfo.avatar,
|
||||
id: (userInfo._id as Types.ObjectId).toString(),
|
||||
name: (userInfo.name || userInfo.username) as string,
|
||||
email: userInfo.email as string,
|
||||
avatar: userInfo.avatar as string,
|
||||
source: 'local',
|
||||
idOnTheSource: entry.userInfo.idOnTheSource || entry.userInfo._id.toString(),
|
||||
idOnTheSource:
|
||||
(userInfo.idOnTheSource as string) || (userInfo._id as Types.ObjectId).toString(),
|
||||
accessRoleId: AccessRoleIds.REMOTE_AGENT_OWNER,
|
||||
isImplicit: true,
|
||||
});
|
||||
|
||||
entriesToBackfill.push(entry.principalId);
|
||||
entriesToBackfill.push(principalId);
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -121,15 +118,15 @@ export function backfillRemoteAgentPermissions(
|
|||
return;
|
||||
}
|
||||
|
||||
const { AclEntry, AccessRole, logger } = deps;
|
||||
const { logger } = deps;
|
||||
|
||||
const resourceObjectId =
|
||||
typeof resourceId === 'string' && /^[a-f\d]{24}$/i.test(resourceId)
|
||||
? AclEntry.base.Types.ObjectId.createFromHexString(resourceId)
|
||||
? Types.ObjectId.createFromHexString(resourceId)
|
||||
: resourceId;
|
||||
|
||||
AccessRole.findOne({ accessRoleId: AccessRoleIds.REMOTE_AGENT_OWNER })
|
||||
.lean()
|
||||
deps
|
||||
.findRoleByIdentifier(AccessRoleIds.REMOTE_AGENT_OWNER)
|
||||
.then((role) => {
|
||||
if (!role) {
|
||||
logger.error('[backfillRemoteAgentPermissions] REMOTE_AGENT_OWNER role not found');
|
||||
|
|
@ -161,9 +158,9 @@ export function backfillRemoteAgentPermissions(
|
|||
},
|
||||
}));
|
||||
|
||||
return AclEntry.bulkWrite(bulkOps, { ordered: false });
|
||||
return deps.bulkWriteAclEntries(bulkOps, { ordered: false });
|
||||
})
|
||||
.catch((err) => {
|
||||
.catch((err: unknown) => {
|
||||
logger.error('[backfillRemoteAgentPermissions] Failed to backfill:', err);
|
||||
});
|
||||
}
|
||||
|
|
|
|||
|
|
@ -2,3 +2,5 @@ export * from './domain';
|
|||
export * from './openid';
|
||||
export * from './exchange';
|
||||
export * from './agent';
|
||||
export * from './password';
|
||||
export * from './invite';
|
||||
|
|
|
|||
61
packages/api/src/auth/invite.ts
Normal file
61
packages/api/src/auth/invite.ts
Normal file
|
|
@ -0,0 +1,61 @@
|
|||
import { Types } from 'mongoose';
|
||||
import { logger, hashToken, getRandomValues } from '@librechat/data-schemas';
|
||||
|
||||
export interface InviteDeps {
|
||||
createToken: (data: {
|
||||
userId: Types.ObjectId;
|
||||
email: string;
|
||||
token: string;
|
||||
createdAt: number;
|
||||
expiresIn: number;
|
||||
}) => Promise<unknown>;
|
||||
findToken: (filter: { token: string; email: string }) => Promise<unknown>;
|
||||
}
|
||||
|
||||
/** Creates a new user invite and returns the encoded token. */
|
||||
export async function createInvite(
|
||||
email: string,
|
||||
deps: InviteDeps,
|
||||
): Promise<string | { message: string }> {
|
||||
try {
|
||||
const token = await getRandomValues(32);
|
||||
const hash = await hashToken(token);
|
||||
const encodedToken = encodeURIComponent(token);
|
||||
const fakeUserId = new Types.ObjectId();
|
||||
|
||||
await deps.createToken({
|
||||
userId: fakeUserId,
|
||||
email,
|
||||
token: hash,
|
||||
createdAt: Date.now(),
|
||||
expiresIn: 604800,
|
||||
});
|
||||
|
||||
return encodedToken;
|
||||
} catch (error) {
|
||||
logger.error('[createInvite] Error creating invite', error);
|
||||
return { message: 'Error creating invite' };
|
||||
}
|
||||
}
|
||||
|
||||
/** Retrieves and validates a user invite by encoded token and email. */
|
||||
export async function getInvite(
|
||||
encodedToken: string,
|
||||
email: string,
|
||||
deps: InviteDeps,
|
||||
): Promise<unknown> {
|
||||
try {
|
||||
const token = decodeURIComponent(encodedToken);
|
||||
const hash = await hashToken(token);
|
||||
const invite = await deps.findToken({ token: hash, email });
|
||||
|
||||
if (!invite) {
|
||||
throw new Error('Invite not found or email does not match');
|
||||
}
|
||||
|
||||
return invite;
|
||||
} catch (error) {
|
||||
logger.error('[getInvite] Error getting invite:', error);
|
||||
return { error: true, message: (error as Error).message };
|
||||
}
|
||||
}
|
||||
25
packages/api/src/auth/password.ts
Normal file
25
packages/api/src/auth/password.ts
Normal file
|
|
@ -0,0 +1,25 @@
|
|||
interface UserWithPassword {
|
||||
password?: string;
|
||||
[key: string]: unknown;
|
||||
}
|
||||
|
||||
export interface ComparePasswordDeps {
|
||||
compare: (candidatePassword: string, hash: string) => Promise<boolean>;
|
||||
}
|
||||
|
||||
/** Compares a candidate password against a user's hashed password. */
|
||||
export async function comparePassword(
|
||||
user: UserWithPassword,
|
||||
candidatePassword: string,
|
||||
deps: ComparePasswordDeps,
|
||||
): Promise<boolean> {
|
||||
if (!user) {
|
||||
throw new Error('No user provided');
|
||||
}
|
||||
|
||||
if (!user.password) {
|
||||
throw new Error('No password, likely an email first registered via Social/OIDC login');
|
||||
}
|
||||
|
||||
return deps.compare(candidatePassword, user.password);
|
||||
}
|
||||
|
|
@ -364,12 +364,12 @@ export class ServerConfigsDB implements IServerConfigsRepositoryInterface {
|
|||
|
||||
const parsedConfigs: Record<string, ParsedServerConfig> = {};
|
||||
const directData = directResults.data || [];
|
||||
const directServerNames = new Set(directData.map((s) => s.serverName));
|
||||
const directServerNames = new Set(directData.map((s: MCPServerDocument) => s.serverName));
|
||||
|
||||
const directParsed = await Promise.all(
|
||||
directData.map((s) => this.mapDBServerToParsedConfig(s)),
|
||||
directData.map((s: MCPServerDocument) => this.mapDBServerToParsedConfig(s)),
|
||||
);
|
||||
directData.forEach((s, i) => {
|
||||
directData.forEach((s: MCPServerDocument, i: number) => {
|
||||
parsedConfigs[s.serverName] = directParsed[i];
|
||||
});
|
||||
|
||||
|
|
@ -382,9 +382,9 @@ export class ServerConfigsDB implements IServerConfigsRepositoryInterface {
|
|||
|
||||
const agentData = agentServers.data || [];
|
||||
const agentParsed = await Promise.all(
|
||||
agentData.map((s) => this.mapDBServerToParsedConfig(s)),
|
||||
agentData.map((s: MCPServerDocument) => this.mapDBServerToParsedConfig(s)),
|
||||
);
|
||||
agentData.forEach((s, i) => {
|
||||
agentData.forEach((s: MCPServerDocument, i: number) => {
|
||||
parsedConfigs[s.serverName] = { ...agentParsed[i], consumeOnly: true };
|
||||
});
|
||||
}
|
||||
|
|
@ -457,7 +457,7 @@ export class ServerConfigsDB implements IServerConfigsRepositoryInterface {
|
|||
};
|
||||
|
||||
// Remove key field since it's user-provided (destructure to omit, not set to undefined)
|
||||
// eslint-disable-next-line @typescript-eslint/no-unused-vars
|
||||
|
||||
const { key: _removed, ...apiKeyWithoutKey } = result.apiKey!;
|
||||
result.apiKey = apiKeyWithoutKey;
|
||||
|
||||
|
|
@ -521,7 +521,7 @@ export class ServerConfigsDB implements IServerConfigsRepositoryInterface {
|
|||
'[ServerConfigsDB.decryptConfig] Failed to decrypt apiKey.key, returning config without key',
|
||||
error,
|
||||
);
|
||||
// eslint-disable-next-line @typescript-eslint/no-unused-vars
|
||||
|
||||
const { key: _removedKey, ...apiKeyWithoutKey } = result.apiKey;
|
||||
result.apiKey = apiKeyWithoutKey;
|
||||
}
|
||||
|
|
@ -542,7 +542,7 @@ export class ServerConfigsDB implements IServerConfigsRepositoryInterface {
|
|||
'[ServerConfigsDB.decryptConfig] Failed to decrypt client_secret, returning config without secret',
|
||||
error,
|
||||
);
|
||||
// eslint-disable-next-line @typescript-eslint/no-unused-vars
|
||||
|
||||
const { client_secret: _removed, ...oauthWithoutSecret } = oauthConfig;
|
||||
result = {
|
||||
...result,
|
||||
|
|
|
|||
|
|
@ -216,12 +216,12 @@ describe('access middleware', () => {
|
|||
|
||||
defaultParams.getRoleByName.mockResolvedValue(mockRole);
|
||||
|
||||
const checkObject = {};
|
||||
const checkObject = { id: 'agent123' };
|
||||
|
||||
const result = await checkAccess({
|
||||
...defaultParams,
|
||||
permissions: [Permissions.USE, Permissions.SHARE],
|
||||
bodyProps: {} as Record<Permissions, string[]>,
|
||||
bodyProps: { [Permissions.SHARE]: ['id'] } as Record<Permissions, string[]>,
|
||||
checkObject,
|
||||
});
|
||||
expect(result).toBe(true);
|
||||
|
|
@ -333,12 +333,12 @@ describe('access middleware', () => {
|
|||
} as unknown as IRole;
|
||||
|
||||
mockGetRoleByName.mockResolvedValue(mockRole);
|
||||
mockReq.body = {};
|
||||
mockReq.body = { id: 'agent123' };
|
||||
|
||||
const middleware = generateCheckAccess({
|
||||
permissionType: PermissionTypes.AGENTS,
|
||||
permissions: [Permissions.USE, Permissions.CREATE, Permissions.SHARE],
|
||||
bodyProps: {} as Record<Permissions, string[]>,
|
||||
bodyProps: { [Permissions.SHARE]: ['id'] } as Record<Permissions, string[]>,
|
||||
getRoleByName: mockGetRoleByName,
|
||||
});
|
||||
|
||||
|
|
|
|||
|
|
@ -2,7 +2,7 @@ import mongoose from 'mongoose';
|
|||
import { MongoMemoryServer } from 'mongodb-memory-server';
|
||||
import { logger, balanceSchema } from '@librechat/data-schemas';
|
||||
import type { NextFunction, Request as ServerRequest, Response as ServerResponse } from 'express';
|
||||
import type { IBalance } from '@librechat/data-schemas';
|
||||
import type { IBalance, IBalanceUpdate } from '@librechat/data-schemas';
|
||||
import { createSetBalanceConfig } from './balance';
|
||||
|
||||
jest.mock('@librechat/data-schemas', () => ({
|
||||
|
|
@ -15,6 +15,16 @@ jest.mock('@librechat/data-schemas', () => ({
|
|||
let mongoServer: MongoMemoryServer;
|
||||
let Balance: mongoose.Model<IBalance>;
|
||||
|
||||
const findBalanceByUser = (userId: string) =>
|
||||
Balance.findOne({ user: userId }).lean() as Promise<IBalance | null>;
|
||||
|
||||
const upsertBalanceFields = (userId: string, fields: IBalanceUpdate) =>
|
||||
Balance.findOneAndUpdate(
|
||||
{ user: userId },
|
||||
{ $set: fields },
|
||||
{ upsert: true, new: true },
|
||||
).lean() as Promise<IBalance | null>;
|
||||
|
||||
beforeAll(async () => {
|
||||
mongoServer = await MongoMemoryServer.create();
|
||||
const mongoUri = mongoServer.getUri();
|
||||
|
|
@ -64,7 +74,8 @@ describe('createSetBalanceConfig', () => {
|
|||
|
||||
const middleware = createSetBalanceConfig({
|
||||
getAppConfig,
|
||||
Balance,
|
||||
findBalanceByUser,
|
||||
upsertBalanceFields,
|
||||
});
|
||||
|
||||
const req = createMockRequest(userId);
|
||||
|
|
@ -95,7 +106,8 @@ describe('createSetBalanceConfig', () => {
|
|||
|
||||
const middleware = createSetBalanceConfig({
|
||||
getAppConfig,
|
||||
Balance,
|
||||
findBalanceByUser,
|
||||
upsertBalanceFields,
|
||||
});
|
||||
|
||||
const req = createMockRequest(userId);
|
||||
|
|
@ -120,7 +132,8 @@ describe('createSetBalanceConfig', () => {
|
|||
|
||||
const middleware = createSetBalanceConfig({
|
||||
getAppConfig,
|
||||
Balance,
|
||||
findBalanceByUser,
|
||||
upsertBalanceFields,
|
||||
});
|
||||
|
||||
const req = createMockRequest(userId);
|
||||
|
|
@ -149,7 +162,8 @@ describe('createSetBalanceConfig', () => {
|
|||
|
||||
const middleware = createSetBalanceConfig({
|
||||
getAppConfig,
|
||||
Balance,
|
||||
findBalanceByUser,
|
||||
upsertBalanceFields,
|
||||
});
|
||||
|
||||
const req = createMockRequest(userId);
|
||||
|
|
@ -178,7 +192,8 @@ describe('createSetBalanceConfig', () => {
|
|||
|
||||
const middleware = createSetBalanceConfig({
|
||||
getAppConfig,
|
||||
Balance,
|
||||
findBalanceByUser,
|
||||
upsertBalanceFields,
|
||||
});
|
||||
|
||||
const req = {} as ServerRequest;
|
||||
|
|
@ -219,7 +234,8 @@ describe('createSetBalanceConfig', () => {
|
|||
|
||||
const middleware = createSetBalanceConfig({
|
||||
getAppConfig,
|
||||
Balance,
|
||||
findBalanceByUser,
|
||||
upsertBalanceFields,
|
||||
});
|
||||
|
||||
const req = createMockRequest(userId);
|
||||
|
|
@ -271,7 +287,8 @@ describe('createSetBalanceConfig', () => {
|
|||
|
||||
const middleware = createSetBalanceConfig({
|
||||
getAppConfig,
|
||||
Balance,
|
||||
findBalanceByUser,
|
||||
upsertBalanceFields,
|
||||
});
|
||||
|
||||
const req = createMockRequest(userId);
|
||||
|
|
@ -315,7 +332,8 @@ describe('createSetBalanceConfig', () => {
|
|||
|
||||
const middleware = createSetBalanceConfig({
|
||||
getAppConfig,
|
||||
Balance,
|
||||
findBalanceByUser,
|
||||
upsertBalanceFields,
|
||||
});
|
||||
|
||||
const req = createMockRequest(userId);
|
||||
|
|
@ -346,7 +364,8 @@ describe('createSetBalanceConfig', () => {
|
|||
|
||||
const middleware = createSetBalanceConfig({
|
||||
getAppConfig,
|
||||
Balance,
|
||||
findBalanceByUser,
|
||||
upsertBalanceFields,
|
||||
});
|
||||
|
||||
const req = createMockRequest(userId);
|
||||
|
|
@ -392,7 +411,8 @@ describe('createSetBalanceConfig', () => {
|
|||
|
||||
const middleware = createSetBalanceConfig({
|
||||
getAppConfig,
|
||||
Balance,
|
||||
findBalanceByUser,
|
||||
upsertBalanceFields,
|
||||
});
|
||||
|
||||
const req = createMockRequest(userId);
|
||||
|
|
@ -434,21 +454,20 @@ describe('createSetBalanceConfig', () => {
|
|||
},
|
||||
});
|
||||
|
||||
const middleware = createSetBalanceConfig({
|
||||
getAppConfig,
|
||||
Balance,
|
||||
});
|
||||
|
||||
const req = createMockRequest(userId);
|
||||
const res = createMockResponse();
|
||||
|
||||
// Spy on Balance.findOneAndUpdate to verify it's not called
|
||||
const updateSpy = jest.spyOn(Balance, 'findOneAndUpdate');
|
||||
const upsertSpy = jest.fn();
|
||||
const spiedMiddleware = createSetBalanceConfig({
|
||||
getAppConfig,
|
||||
findBalanceByUser,
|
||||
upsertBalanceFields: upsertSpy,
|
||||
});
|
||||
|
||||
await middleware(req as ServerRequest, res as ServerResponse, mockNext);
|
||||
await spiedMiddleware(req as ServerRequest, res as ServerResponse, mockNext);
|
||||
|
||||
expect(mockNext).toHaveBeenCalled();
|
||||
expect(updateSpy).not.toHaveBeenCalled();
|
||||
expect(upsertSpy).not.toHaveBeenCalled();
|
||||
});
|
||||
|
||||
test('should set tokenCredits for user with null tokenCredits', async () => {
|
||||
|
|
@ -470,7 +489,8 @@ describe('createSetBalanceConfig', () => {
|
|||
|
||||
const middleware = createSetBalanceConfig({
|
||||
getAppConfig,
|
||||
Balance,
|
||||
findBalanceByUser,
|
||||
upsertBalanceFields,
|
||||
});
|
||||
|
||||
const req = createMockRequest(userId);
|
||||
|
|
@ -498,16 +518,12 @@ describe('createSetBalanceConfig', () => {
|
|||
});
|
||||
const dbError = new Error('Database error');
|
||||
|
||||
// Mock Balance.findOne to throw an error
|
||||
jest.spyOn(Balance, 'findOne').mockImplementationOnce((() => {
|
||||
return {
|
||||
lean: jest.fn().mockRejectedValue(dbError),
|
||||
};
|
||||
}) as unknown as mongoose.Model<IBalance>['findOne']);
|
||||
const failingFindBalance = () => Promise.reject(dbError);
|
||||
|
||||
const middleware = createSetBalanceConfig({
|
||||
getAppConfig,
|
||||
Balance,
|
||||
findBalanceByUser: failingFindBalance,
|
||||
upsertBalanceFields,
|
||||
});
|
||||
|
||||
const req = createMockRequest(userId);
|
||||
|
|
@ -526,7 +542,8 @@ describe('createSetBalanceConfig', () => {
|
|||
|
||||
const middleware = createSetBalanceConfig({
|
||||
getAppConfig,
|
||||
Balance,
|
||||
findBalanceByUser,
|
||||
upsertBalanceFields,
|
||||
});
|
||||
|
||||
const req = createMockRequest(userId);
|
||||
|
|
@ -556,7 +573,8 @@ describe('createSetBalanceConfig', () => {
|
|||
|
||||
const middleware = createSetBalanceConfig({
|
||||
getAppConfig,
|
||||
Balance,
|
||||
findBalanceByUser,
|
||||
upsertBalanceFields,
|
||||
});
|
||||
|
||||
const req = createMockRequest(userId);
|
||||
|
|
@ -590,7 +608,8 @@ describe('createSetBalanceConfig', () => {
|
|||
|
||||
const middleware = createSetBalanceConfig({
|
||||
getAppConfig,
|
||||
Balance,
|
||||
findBalanceByUser,
|
||||
upsertBalanceFields,
|
||||
});
|
||||
|
||||
const req = createMockRequest(userId);
|
||||
|
|
@ -635,7 +654,8 @@ describe('createSetBalanceConfig', () => {
|
|||
|
||||
const middleware = createSetBalanceConfig({
|
||||
getAppConfig,
|
||||
Balance,
|
||||
findBalanceByUser,
|
||||
upsertBalanceFields,
|
||||
});
|
||||
|
||||
const req = createMockRequest(userId);
|
||||
|
|
|
|||
|
|
@ -1,13 +1,20 @@
|
|||
import { logger } from '@librechat/data-schemas';
|
||||
import type {
|
||||
IBalanceUpdate,
|
||||
BalanceConfig,
|
||||
AppConfig,
|
||||
ObjectId,
|
||||
IBalance,
|
||||
IUser,
|
||||
} from '@librechat/data-schemas';
|
||||
import type { NextFunction, Request as ServerRequest, Response as ServerResponse } from 'express';
|
||||
import type { IBalance, IUser, BalanceConfig, ObjectId, AppConfig } from '@librechat/data-schemas';
|
||||
import type { Model } from 'mongoose';
|
||||
import type { BalanceUpdateFields } from '~/types';
|
||||
import { getBalanceConfig } from '~/app/config';
|
||||
|
||||
export interface BalanceMiddlewareOptions {
|
||||
getAppConfig: (options?: { role?: string; refresh?: boolean }) => Promise<AppConfig>;
|
||||
Balance: Model<IBalance>;
|
||||
findBalanceByUser: (userId: string) => Promise<IBalance | null>;
|
||||
upsertBalanceFields: (userId: string, fields: IBalanceUpdate) => Promise<IBalance | null>;
|
||||
}
|
||||
|
||||
/**
|
||||
|
|
@ -75,7 +82,8 @@ function buildUpdateFields(
|
|||
*/
|
||||
export function createSetBalanceConfig({
|
||||
getAppConfig,
|
||||
Balance,
|
||||
findBalanceByUser,
|
||||
upsertBalanceFields,
|
||||
}: BalanceMiddlewareOptions): (
|
||||
req: ServerRequest,
|
||||
res: ServerResponse,
|
||||
|
|
@ -97,18 +105,14 @@ export function createSetBalanceConfig({
|
|||
return next();
|
||||
}
|
||||
const userId = typeof user._id === 'string' ? user._id : user._id.toString();
|
||||
const userBalanceRecord = await Balance.findOne({ user: userId }).lean();
|
||||
const userBalanceRecord = await findBalanceByUser(userId);
|
||||
const updateFields = buildUpdateFields(balanceConfig, userBalanceRecord, userId);
|
||||
|
||||
if (Object.keys(updateFields).length === 0) {
|
||||
return next();
|
||||
}
|
||||
|
||||
await Balance.findOneAndUpdate(
|
||||
{ user: userId },
|
||||
{ $set: updateFields },
|
||||
{ upsert: true, new: true },
|
||||
);
|
||||
await upsertBalanceFields(userId, updateFields);
|
||||
|
||||
next();
|
||||
} catch (error) {
|
||||
|
|
|
|||
168
packages/api/src/middleware/checkBalance.ts
Normal file
168
packages/api/src/middleware/checkBalance.ts
Normal file
|
|
@ -0,0 +1,168 @@
|
|||
import { logger } from '@librechat/data-schemas';
|
||||
import { ViolationTypes } from 'librechat-data-provider';
|
||||
import type { ServerRequest } from '~/types/http';
|
||||
import type { Response } from 'express';
|
||||
|
||||
type TimeUnit = 'seconds' | 'minutes' | 'hours' | 'days' | 'weeks' | 'months';
|
||||
|
||||
interface BalanceRecord {
|
||||
tokenCredits: number;
|
||||
autoRefillEnabled?: boolean;
|
||||
refillAmount?: number;
|
||||
lastRefill?: Date;
|
||||
refillIntervalValue?: number;
|
||||
refillIntervalUnit?: TimeUnit;
|
||||
}
|
||||
|
||||
interface TxData {
|
||||
user: string;
|
||||
model?: string;
|
||||
endpoint?: string;
|
||||
valueKey?: string;
|
||||
tokenType?: string;
|
||||
amount: number;
|
||||
endpointTokenConfig?: unknown;
|
||||
generations?: unknown[];
|
||||
}
|
||||
|
||||
export interface CheckBalanceDeps {
|
||||
findBalanceByUser: (user: string) => Promise<BalanceRecord | null>;
|
||||
getMultiplier: (params: Record<string, unknown>) => number;
|
||||
createAutoRefillTransaction: (
|
||||
data: Record<string, unknown>,
|
||||
) => Promise<{ balance: number } | undefined>;
|
||||
logViolation: (
|
||||
req: unknown,
|
||||
res: unknown,
|
||||
type: string,
|
||||
errorMessage: Record<string, unknown>,
|
||||
score: number,
|
||||
) => Promise<void>;
|
||||
}
|
||||
|
||||
function addIntervalToDate(date: Date, value: number, unit: TimeUnit): Date {
|
||||
const result = new Date(date);
|
||||
switch (unit) {
|
||||
case 'seconds':
|
||||
result.setSeconds(result.getSeconds() + value);
|
||||
break;
|
||||
case 'minutes':
|
||||
result.setMinutes(result.getMinutes() + value);
|
||||
break;
|
||||
case 'hours':
|
||||
result.setHours(result.getHours() + value);
|
||||
break;
|
||||
case 'days':
|
||||
result.setDate(result.getDate() + value);
|
||||
break;
|
||||
case 'weeks':
|
||||
result.setDate(result.getDate() + value * 7);
|
||||
break;
|
||||
case 'months':
|
||||
result.setMonth(result.getMonth() + value);
|
||||
break;
|
||||
default:
|
||||
break;
|
||||
}
|
||||
return result;
|
||||
}
|
||||
|
||||
/** Checks a user's balance record and handles auto-refill if needed. */
|
||||
async function checkBalanceRecord(
|
||||
txData: TxData,
|
||||
deps: CheckBalanceDeps,
|
||||
): Promise<{ canSpend: boolean; balance: number; tokenCost: number }> {
|
||||
const { user, model, endpoint, valueKey, tokenType, amount, endpointTokenConfig } = txData;
|
||||
const multiplier = deps.getMultiplier({
|
||||
valueKey,
|
||||
tokenType,
|
||||
model,
|
||||
endpoint,
|
||||
endpointTokenConfig,
|
||||
});
|
||||
const tokenCost = amount * multiplier;
|
||||
|
||||
const record = await deps.findBalanceByUser(user);
|
||||
if (!record) {
|
||||
logger.debug('[Balance.check] No balance record found for user', { user });
|
||||
return { canSpend: false, balance: 0, tokenCost };
|
||||
}
|
||||
let balance = record.tokenCredits;
|
||||
|
||||
logger.debug('[Balance.check] Initial state', {
|
||||
user,
|
||||
model,
|
||||
endpoint,
|
||||
valueKey,
|
||||
tokenType,
|
||||
amount,
|
||||
balance,
|
||||
multiplier,
|
||||
endpointTokenConfig: !!endpointTokenConfig,
|
||||
});
|
||||
|
||||
if (
|
||||
balance - tokenCost <= 0 &&
|
||||
record.autoRefillEnabled &&
|
||||
record.refillAmount &&
|
||||
record.refillAmount > 0
|
||||
) {
|
||||
const lastRefillDate = new Date(record.lastRefill ?? 0);
|
||||
const now = new Date();
|
||||
if (
|
||||
isNaN(lastRefillDate.getTime()) ||
|
||||
now >=
|
||||
addIntervalToDate(
|
||||
lastRefillDate,
|
||||
record.refillIntervalValue ?? 0,
|
||||
record.refillIntervalUnit ?? 'days',
|
||||
)
|
||||
) {
|
||||
try {
|
||||
const result = await deps.createAutoRefillTransaction({
|
||||
user,
|
||||
tokenType: 'credits',
|
||||
context: 'autoRefill',
|
||||
rawAmount: record.refillAmount,
|
||||
});
|
||||
if (result) {
|
||||
balance = result.balance;
|
||||
}
|
||||
} catch (error) {
|
||||
logger.error('[Balance.check] Failed to record transaction for auto-refill', error);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
logger.debug('[Balance.check] Token cost', { tokenCost });
|
||||
return { canSpend: balance >= tokenCost, balance, tokenCost };
|
||||
}
|
||||
|
||||
/**
|
||||
* Checks balance for a user and logs a violation if they cannot spend.
|
||||
* Throws an error with the balance info if insufficient funds.
|
||||
*/
|
||||
export async function checkBalance(
|
||||
{ req, res, txData }: { req: ServerRequest; res: Response; txData: TxData },
|
||||
deps: CheckBalanceDeps,
|
||||
): Promise<boolean> {
|
||||
const { canSpend, balance, tokenCost } = await checkBalanceRecord(txData, deps);
|
||||
if (canSpend) {
|
||||
return true;
|
||||
}
|
||||
|
||||
const type = ViolationTypes.TOKEN_BALANCE;
|
||||
const errorMessage: Record<string, unknown> = {
|
||||
type,
|
||||
balance,
|
||||
tokenCost,
|
||||
promptTokens: txData.amount,
|
||||
};
|
||||
|
||||
if (txData.generations && txData.generations.length > 0) {
|
||||
errorMessage.generations = txData.generations;
|
||||
}
|
||||
|
||||
await deps.logViolation(req, res, type, errorMessage, 0);
|
||||
throw new Error(JSON.stringify(errorMessage));
|
||||
}
|
||||
|
|
@ -4,3 +4,4 @@ export * from './error';
|
|||
export * from './balance';
|
||||
export * from './json';
|
||||
export * from './concurrency';
|
||||
export * from './checkBalance';
|
||||
|
|
|
|||
|
|
@ -1,8 +1,8 @@
|
|||
import { escapeRegExp } from '@librechat/data-schemas';
|
||||
import { SystemCategories } from 'librechat-data-provider';
|
||||
import type { IPromptGroupDocument as IPromptGroup } from '@librechat/data-schemas';
|
||||
import type { Types } from 'mongoose';
|
||||
import type { PromptGroupsListResponse } from '~/types';
|
||||
import { escapeRegExp } from '~/utils/common';
|
||||
|
||||
/**
|
||||
* Formats prompt groups for the paginated /groups endpoint response
|
||||
|
|
|
|||
|
|
@ -48,12 +48,3 @@ export function optionalChainWithEmptyCheck(
|
|||
}
|
||||
return values[values.length - 1];
|
||||
}
|
||||
|
||||
/**
|
||||
* Escapes special characters in a string for use in a regular expression.
|
||||
* @param str - The string to escape.
|
||||
* @returns The escaped string safe for use in RegExp.
|
||||
*/
|
||||
export function escapeRegExp(str: string): string {
|
||||
return str.replace(/[.*+?^${}()|[\]\\]/g, '\\$&');
|
||||
}
|
||||
|
|
|
|||
|
|
@ -17,7 +17,6 @@ export * from './oidc';
|
|||
export * from './openid';
|
||||
export * from './promise';
|
||||
export * from './sanitizeTitle';
|
||||
export * from './tempChatRetention';
|
||||
export * from './text';
|
||||
export { default as Tokenizer, countTokens } from './tokenizer';
|
||||
export * from './yaml';
|
||||
|
|
|
|||
|
|
@ -29,7 +29,7 @@ export default {
|
|||
commonjs(),
|
||||
// Compile TypeScript files and generate type declarations
|
||||
typescript({
|
||||
tsconfig: './tsconfig.json',
|
||||
tsconfig: './tsconfig.build.json',
|
||||
declaration: true,
|
||||
declarationDir: 'dist/types',
|
||||
rootDir: 'src',
|
||||
|
|
|
|||
|
|
@ -4,7 +4,15 @@ export * from './crypto';
|
|||
export * from './schema';
|
||||
export * from './utils';
|
||||
export { createModels } from './models';
|
||||
export { createMethods, DEFAULT_REFRESH_TOKEN_EXPIRY, DEFAULT_SESSION_EXPIRY } from './methods';
|
||||
export {
|
||||
createMethods,
|
||||
DEFAULT_REFRESH_TOKEN_EXPIRY,
|
||||
DEFAULT_SESSION_EXPIRY,
|
||||
tokenValues,
|
||||
cacheTokenValues,
|
||||
premiumTokenValues,
|
||||
defaultRate,
|
||||
} from './methods';
|
||||
export type * from './types';
|
||||
export type * from './methods';
|
||||
export { default as logger } from './config/winston';
|
||||
|
|
|
|||
|
|
@ -1,6 +1,12 @@
|
|||
import { Types } from 'mongoose';
|
||||
import { PrincipalType, PrincipalModel } from 'librechat-data-provider';
|
||||
import type { Model, DeleteResult, ClientSession } from 'mongoose';
|
||||
import type {
|
||||
AnyBulkWriteOperation,
|
||||
ClientSession,
|
||||
PipelineStage,
|
||||
DeleteResult,
|
||||
Model,
|
||||
} from 'mongoose';
|
||||
import type { IAclEntry } from '~/types';
|
||||
|
||||
export function createAclEntryMethods(mongoose: typeof import('mongoose')) {
|
||||
|
|
@ -349,6 +355,58 @@ export function createAclEntryMethods(mongoose: typeof import('mongoose')) {
|
|||
return entries;
|
||||
}
|
||||
|
||||
/**
|
||||
* Deletes ACL entries matching the given filter.
|
||||
* @param filter - MongoDB filter query
|
||||
* @param options - Optional query options (e.g., { session })
|
||||
*/
|
||||
async function deleteAclEntries(
|
||||
filter: Record<string, unknown>,
|
||||
options?: { session?: ClientSession },
|
||||
): Promise<DeleteResult> {
|
||||
const AclEntry = mongoose.models.AclEntry as Model<IAclEntry>;
|
||||
return AclEntry.deleteMany(filter, options || {});
|
||||
}
|
||||
|
||||
/**
|
||||
* Performs a bulk write operation on ACL entries.
|
||||
* @param ops - Array of bulk write operations
|
||||
* @param options - Optional query options (e.g., { session })
|
||||
*/
|
||||
async function bulkWriteAclEntries(
|
||||
ops: AnyBulkWriteOperation<IAclEntry>[],
|
||||
options?: { session?: ClientSession },
|
||||
) {
|
||||
const AclEntry = mongoose.models.AclEntry as Model<IAclEntry>;
|
||||
return AclEntry.bulkWrite(ops, options || {});
|
||||
}
|
||||
|
||||
/**
|
||||
* Finds all publicly accessible resource IDs for a given resource type.
|
||||
* @param resourceType - The type of resource
|
||||
* @param requiredPermissions - Required permission bits
|
||||
*/
|
||||
async function findPublicResourceIds(
|
||||
resourceType: string,
|
||||
requiredPermissions: number,
|
||||
): Promise<Types.ObjectId[]> {
|
||||
const AclEntry = mongoose.models.AclEntry as Model<IAclEntry>;
|
||||
return AclEntry.find({
|
||||
principalType: PrincipalType.PUBLIC,
|
||||
resourceType,
|
||||
permBits: { $bitsAllSet: requiredPermissions },
|
||||
}).distinct('resourceId');
|
||||
}
|
||||
|
||||
/**
|
||||
* Runs an aggregation pipeline on the AclEntry collection.
|
||||
* @param pipeline - MongoDB aggregation pipeline stages
|
||||
*/
|
||||
async function aggregateAclEntries(pipeline: PipelineStage[]) {
|
||||
const AclEntry = mongoose.models.AclEntry as Model<IAclEntry>;
|
||||
return AclEntry.aggregate(pipeline);
|
||||
}
|
||||
|
||||
return {
|
||||
findEntriesByPrincipal,
|
||||
findEntriesByResource,
|
||||
|
|
@ -360,6 +418,10 @@ export function createAclEntryMethods(mongoose: typeof import('mongoose')) {
|
|||
revokePermission,
|
||||
modifyPermissionBits,
|
||||
findAccessibleResources,
|
||||
deleteAclEntries,
|
||||
bulkWriteAclEntries,
|
||||
findPublicResourceIds,
|
||||
aggregateAclEntries,
|
||||
};
|
||||
}
|
||||
|
||||
|
|
|
|||
77
packages/data-schemas/src/methods/action.ts
Normal file
77
packages/data-schemas/src/methods/action.ts
Normal file
|
|
@ -0,0 +1,77 @@
|
|||
import type { FilterQuery, Model } from 'mongoose';
|
||||
import type { IAction } from '~/types';
|
||||
|
||||
const sensitiveFields = ['api_key', 'oauth_client_id', 'oauth_client_secret'] as const;
|
||||
|
||||
export function createActionMethods(mongoose: typeof import('mongoose')) {
|
||||
/**
|
||||
* Update an action with new data without overwriting existing properties,
|
||||
* or create a new action if it doesn't exist.
|
||||
*/
|
||||
async function updateAction(
|
||||
searchParams: FilterQuery<IAction>,
|
||||
updateData: Partial<IAction>,
|
||||
): Promise<IAction | null> {
|
||||
const Action = mongoose.models.Action as Model<IAction>;
|
||||
const options = { new: true, upsert: true };
|
||||
return (await Action.findOneAndUpdate(
|
||||
searchParams,
|
||||
updateData,
|
||||
options,
|
||||
).lean()) as IAction | null;
|
||||
}
|
||||
|
||||
/**
|
||||
* Retrieves all actions that match the given search parameters.
|
||||
*/
|
||||
async function getActions(
|
||||
searchParams: FilterQuery<IAction>,
|
||||
includeSensitive = false,
|
||||
): Promise<IAction[]> {
|
||||
const Action = mongoose.models.Action as Model<IAction>;
|
||||
const actions = (await Action.find(searchParams).lean()) as IAction[];
|
||||
|
||||
if (!includeSensitive) {
|
||||
for (let i = 0; i < actions.length; i++) {
|
||||
const metadata = actions[i].metadata;
|
||||
if (!metadata) {
|
||||
continue;
|
||||
}
|
||||
|
||||
for (const field of sensitiveFields) {
|
||||
if (metadata[field]) {
|
||||
delete metadata[field];
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return actions;
|
||||
}
|
||||
|
||||
/**
|
||||
* Deletes an action by params.
|
||||
*/
|
||||
async function deleteAction(searchParams: FilterQuery<IAction>): Promise<IAction | null> {
|
||||
const Action = mongoose.models.Action as Model<IAction>;
|
||||
return (await Action.findOneAndDelete(searchParams).lean()) as IAction | null;
|
||||
}
|
||||
|
||||
/**
|
||||
* Deletes actions by params.
|
||||
*/
|
||||
async function deleteActions(searchParams: FilterQuery<IAction>): Promise<number> {
|
||||
const Action = mongoose.models.Action as Model<IAction>;
|
||||
const result = await Action.deleteMany(searchParams);
|
||||
return result.deletedCount;
|
||||
}
|
||||
|
||||
return {
|
||||
getActions,
|
||||
updateAction,
|
||||
deleteAction,
|
||||
deleteActions,
|
||||
};
|
||||
}
|
||||
|
||||
export type ActionMethods = ReturnType<typeof createActionMethods>;
|
||||
3287
packages/data-schemas/src/methods/agent.spec.ts
Normal file
3287
packages/data-schemas/src/methods/agent.spec.ts
Normal file
File diff suppressed because it is too large
Load diff
716
packages/data-schemas/src/methods/agent.ts
Normal file
716
packages/data-schemas/src/methods/agent.ts
Normal file
|
|
@ -0,0 +1,716 @@
|
|||
import crypto from 'node:crypto';
|
||||
import type { FilterQuery, Model, Types } from 'mongoose';
|
||||
import { Constants, ResourceType, actionDelimiter } from 'librechat-data-provider';
|
||||
import logger from '~/config/winston';
|
||||
import type { IAgent } from '~/types';
|
||||
|
||||
const { mcp_delimiter } = Constants;
|
||||
|
||||
export interface AgentDeps {
|
||||
/** Removes all ACL permissions for a resource. Injected from PermissionService. */
|
||||
removeAllPermissions: (params: { resourceType: string; resourceId: unknown }) => Promise<void>;
|
||||
/** Gets actions. Created by createActionMethods. */
|
||||
getActions: (
|
||||
searchParams: FilterQuery<unknown>,
|
||||
includeSensitive?: boolean,
|
||||
) => Promise<unknown[]>;
|
||||
}
|
||||
|
||||
/**
|
||||
* Extracts unique MCP server names from tools array.
|
||||
* Tools format: "toolName_mcp_serverName" or "sys__server__sys_mcp_serverName"
|
||||
*/
|
||||
function extractMCPServerNames(tools: string[] | undefined | null): string[] {
|
||||
if (!tools || !Array.isArray(tools)) {
|
||||
return [];
|
||||
}
|
||||
const serverNames = new Set<string>();
|
||||
for (const tool of tools) {
|
||||
if (!tool || !tool.includes(mcp_delimiter)) {
|
||||
continue;
|
||||
}
|
||||
const parts = tool.split(mcp_delimiter);
|
||||
if (parts.length >= 2) {
|
||||
serverNames.add(parts[parts.length - 1]);
|
||||
}
|
||||
}
|
||||
return Array.from(serverNames);
|
||||
}
|
||||
|
||||
/**
|
||||
* Check if a version already exists in the versions array, excluding timestamp and author fields.
|
||||
*/
|
||||
function isDuplicateVersion(
|
||||
updateData: Record<string, unknown>,
|
||||
currentData: Record<string, unknown>,
|
||||
versions: Record<string, unknown>[],
|
||||
actionsHash: string | null = null,
|
||||
): Record<string, unknown> | null {
|
||||
if (!versions || versions.length === 0) {
|
||||
return null;
|
||||
}
|
||||
|
||||
const excludeFields = [
|
||||
'_id',
|
||||
'id',
|
||||
'createdAt',
|
||||
'updatedAt',
|
||||
'author',
|
||||
'updatedBy',
|
||||
'created_at',
|
||||
'updated_at',
|
||||
'__v',
|
||||
'versions',
|
||||
'actionsHash',
|
||||
];
|
||||
|
||||
const { $push: _$push, $pull: _$pull, $addToSet: _$addToSet, ...directUpdates } = updateData;
|
||||
|
||||
if (Object.keys(directUpdates).length === 0 && !actionsHash) {
|
||||
return null;
|
||||
}
|
||||
|
||||
const wouldBeVersion = { ...currentData, ...directUpdates } as Record<string, unknown>;
|
||||
const lastVersion = versions[versions.length - 1] as Record<string, unknown>;
|
||||
|
||||
if (actionsHash && lastVersion.actionsHash !== actionsHash) {
|
||||
return null;
|
||||
}
|
||||
|
||||
const allFields = new Set([...Object.keys(wouldBeVersion), ...Object.keys(lastVersion)]);
|
||||
const importantFields = Array.from(allFields).filter((field) => !excludeFields.includes(field));
|
||||
|
||||
let isMatch = true;
|
||||
for (const field of importantFields) {
|
||||
const wouldBeValue = wouldBeVersion[field];
|
||||
const lastVersionValue = lastVersion[field];
|
||||
|
||||
if (!wouldBeValue && !lastVersionValue) {
|
||||
continue;
|
||||
}
|
||||
|
||||
// Handle arrays
|
||||
if (Array.isArray(wouldBeValue) || Array.isArray(lastVersionValue)) {
|
||||
let wouldBeArr: unknown[];
|
||||
if (Array.isArray(wouldBeValue)) {
|
||||
wouldBeArr = wouldBeValue;
|
||||
} else if (wouldBeValue == null) {
|
||||
wouldBeArr = [];
|
||||
} else {
|
||||
wouldBeArr = [wouldBeValue];
|
||||
}
|
||||
|
||||
let lastVersionArr: unknown[];
|
||||
if (Array.isArray(lastVersionValue)) {
|
||||
lastVersionArr = lastVersionValue;
|
||||
} else if (lastVersionValue == null) {
|
||||
lastVersionArr = [];
|
||||
} else {
|
||||
lastVersionArr = [lastVersionValue];
|
||||
}
|
||||
|
||||
if (wouldBeArr.length !== lastVersionArr.length) {
|
||||
isMatch = false;
|
||||
break;
|
||||
}
|
||||
|
||||
if (wouldBeArr.length > 0 && typeof wouldBeArr[0] === 'object' && wouldBeArr[0] !== null) {
|
||||
const sortedWouldBe = [...wouldBeArr].map((item) => JSON.stringify(item)).sort();
|
||||
const sortedVersion = [...lastVersionArr].map((item) => JSON.stringify(item)).sort();
|
||||
|
||||
if (!sortedWouldBe.every((item, i) => item === sortedVersion[i])) {
|
||||
isMatch = false;
|
||||
break;
|
||||
}
|
||||
} else {
|
||||
const sortedWouldBe = [...wouldBeArr].sort() as string[];
|
||||
const sortedVersion = [...lastVersionArr].sort() as string[];
|
||||
|
||||
if (!sortedWouldBe.every((item, i) => item === sortedVersion[i])) {
|
||||
isMatch = false;
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
// Handle objects
|
||||
else if (typeof wouldBeValue === 'object' && wouldBeValue !== null) {
|
||||
const lastVersionObj =
|
||||
typeof lastVersionValue === 'object' && lastVersionValue !== null ? lastVersionValue : {};
|
||||
|
||||
const wouldBeKeys = Object.keys(wouldBeValue as Record<string, unknown>);
|
||||
const lastVersionKeys = Object.keys(lastVersionObj as Record<string, unknown>);
|
||||
|
||||
if (wouldBeKeys.length === 0 && lastVersionKeys.length === 0) {
|
||||
continue;
|
||||
}
|
||||
|
||||
if (JSON.stringify(wouldBeValue) !== JSON.stringify(lastVersionObj)) {
|
||||
isMatch = false;
|
||||
break;
|
||||
}
|
||||
}
|
||||
// Handle primitive values
|
||||
else {
|
||||
if (wouldBeValue !== lastVersionValue) {
|
||||
if (
|
||||
typeof wouldBeValue === 'boolean' &&
|
||||
wouldBeValue === false &&
|
||||
lastVersionValue === undefined
|
||||
) {
|
||||
continue;
|
||||
}
|
||||
if (
|
||||
typeof wouldBeValue === 'string' &&
|
||||
wouldBeValue === '' &&
|
||||
lastVersionValue === undefined
|
||||
) {
|
||||
continue;
|
||||
}
|
||||
isMatch = false;
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return isMatch ? lastVersion : null;
|
||||
}
|
||||
|
||||
/**
|
||||
* Generates a hash of action metadata for version comparison.
|
||||
*/
|
||||
async function generateActionMetadataHash(
|
||||
actionIds: string[] | null | undefined,
|
||||
actions: Array<{ action_id: string; metadata: Record<string, unknown> | null }>,
|
||||
): Promise<string> {
|
||||
if (!actionIds || actionIds.length === 0) {
|
||||
return '';
|
||||
}
|
||||
|
||||
const actionMap = new Map<string, Record<string, unknown> | null>();
|
||||
actions.forEach((action) => {
|
||||
actionMap.set(action.action_id, action.metadata);
|
||||
});
|
||||
|
||||
const sortedActionIds = [...actionIds].sort();
|
||||
|
||||
const metadataString = sortedActionIds
|
||||
.map((actionFullId) => {
|
||||
const parts = actionFullId.split(actionDelimiter);
|
||||
const actionId = parts[1];
|
||||
|
||||
const metadata = actionMap.get(actionId);
|
||||
if (!metadata) {
|
||||
return `${actionId}:null`;
|
||||
}
|
||||
|
||||
const sortedKeys = Object.keys(metadata).sort();
|
||||
const metadataStr = sortedKeys
|
||||
.map((key) => `${key}:${JSON.stringify(metadata[key])}`)
|
||||
.join(',');
|
||||
return `${actionId}:{${metadataStr}}`;
|
||||
})
|
||||
.join(';');
|
||||
|
||||
const encoder = new TextEncoder();
|
||||
const data = encoder.encode(metadataString);
|
||||
const hashBuffer = await crypto.webcrypto.subtle.digest('SHA-256', data);
|
||||
const hashArray = Array.from(new Uint8Array(hashBuffer));
|
||||
const hashHex = hashArray.map((b) => b.toString(16).padStart(2, '0')).join('');
|
||||
|
||||
return hashHex;
|
||||
}
|
||||
|
||||
export function createAgentMethods(mongoose: typeof import('mongoose'), deps: AgentDeps) {
|
||||
const { removeAllPermissions, getActions } = deps;
|
||||
|
||||
/**
|
||||
* Create an agent with the provided data.
|
||||
*/
|
||||
async function createAgent(agentData: Record<string, unknown>): Promise<IAgent> {
|
||||
const Agent = mongoose.models.Agent as Model<IAgent>;
|
||||
const { author: _author, ...versionData } = agentData;
|
||||
const timestamp = new Date();
|
||||
const initialAgentData = {
|
||||
...agentData,
|
||||
versions: [
|
||||
{
|
||||
...versionData,
|
||||
createdAt: timestamp,
|
||||
updatedAt: timestamp,
|
||||
},
|
||||
],
|
||||
category: (agentData.category as string) || 'general',
|
||||
mcpServerNames: extractMCPServerNames(agentData.tools as string[] | undefined),
|
||||
};
|
||||
|
||||
return (await Agent.create(initialAgentData)).toObject() as IAgent;
|
||||
}
|
||||
|
||||
/**
|
||||
* Get an agent document based on the provided search parameter.
|
||||
*/
|
||||
async function getAgent(searchParameter: FilterQuery<IAgent>): Promise<IAgent | null> {
|
||||
const Agent = mongoose.models.Agent as Model<IAgent>;
|
||||
return (await Agent.findOne(searchParameter).lean()) as IAgent | null;
|
||||
}
|
||||
|
||||
/**
|
||||
* Get multiple agent documents based on the provided search parameters.
|
||||
*/
|
||||
async function getAgents(searchParameter: FilterQuery<IAgent>): Promise<IAgent[]> {
|
||||
const Agent = mongoose.models.Agent as Model<IAgent>;
|
||||
return (await Agent.find(searchParameter).lean()) as IAgent[];
|
||||
}
|
||||
|
||||
/**
|
||||
* Update an agent with new data without overwriting existing properties,
|
||||
* or create a new agent if it doesn't exist.
|
||||
* When an agent is updated, a copy of the current state will be saved to the versions array.
|
||||
*/
|
||||
async function updateAgent(
|
||||
searchParameter: FilterQuery<IAgent>,
|
||||
updateData: Record<string, unknown>,
|
||||
options: {
|
||||
updatingUserId?: string | null;
|
||||
forceVersion?: boolean;
|
||||
skipVersioning?: boolean;
|
||||
} = {},
|
||||
): Promise<IAgent | null> {
|
||||
const Agent = mongoose.models.Agent as Model<IAgent>;
|
||||
const { updatingUserId = null, forceVersion = false, skipVersioning = false } = options;
|
||||
const mongoOptions = { new: true, upsert: false };
|
||||
|
||||
const currentAgent = await Agent.findOne(searchParameter);
|
||||
if (currentAgent) {
|
||||
const {
|
||||
__v,
|
||||
_id,
|
||||
id: __id,
|
||||
versions,
|
||||
author: _author,
|
||||
...versionData
|
||||
} = currentAgent.toObject() as unknown as Record<string, unknown>;
|
||||
const { $push, $pull, $addToSet, ...directUpdates } = updateData;
|
||||
|
||||
// Sync mcpServerNames when tools are updated
|
||||
if ((directUpdates as Record<string, unknown>).tools !== undefined) {
|
||||
const mcpServerNames = extractMCPServerNames(
|
||||
(directUpdates as Record<string, unknown>).tools as string[],
|
||||
);
|
||||
(directUpdates as Record<string, unknown>).mcpServerNames = mcpServerNames;
|
||||
updateData.mcpServerNames = mcpServerNames;
|
||||
}
|
||||
|
||||
let actionsHash: string | null = null;
|
||||
|
||||
// Generate actions hash if agent has actions
|
||||
if (currentAgent.actions && currentAgent.actions.length > 0) {
|
||||
const actionIds = currentAgent.actions
|
||||
.map((action: string) => {
|
||||
const parts = action.split(actionDelimiter);
|
||||
return parts[1];
|
||||
})
|
||||
.filter(Boolean);
|
||||
|
||||
if (actionIds.length > 0) {
|
||||
try {
|
||||
const actions = await getActions({ action_id: { $in: actionIds } }, true);
|
||||
|
||||
actionsHash = await generateActionMetadataHash(
|
||||
currentAgent.actions,
|
||||
actions as Array<{ action_id: string; metadata: Record<string, unknown> | null }>,
|
||||
);
|
||||
} catch (error) {
|
||||
logger.error('Error fetching actions for hash generation:', error);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
const shouldCreateVersion =
|
||||
!skipVersioning &&
|
||||
(forceVersion || Object.keys(directUpdates).length > 0 || $push || $pull || $addToSet);
|
||||
|
||||
if (shouldCreateVersion) {
|
||||
const duplicateVersion = isDuplicateVersion(
|
||||
updateData,
|
||||
versionData,
|
||||
versions as Record<string, unknown>[],
|
||||
actionsHash,
|
||||
);
|
||||
if (duplicateVersion && !forceVersion) {
|
||||
const agentObj = currentAgent.toObject() as IAgent & {
|
||||
version?: number;
|
||||
versions?: unknown[];
|
||||
};
|
||||
agentObj.version = (versions as unknown[]).length;
|
||||
return agentObj;
|
||||
}
|
||||
}
|
||||
|
||||
const versionEntry: Record<string, unknown> = {
|
||||
...versionData,
|
||||
...directUpdates,
|
||||
updatedAt: new Date(),
|
||||
};
|
||||
|
||||
if (actionsHash) {
|
||||
versionEntry.actionsHash = actionsHash;
|
||||
}
|
||||
|
||||
if (updatingUserId) {
|
||||
versionEntry.updatedBy = new mongoose.Types.ObjectId(updatingUserId);
|
||||
}
|
||||
|
||||
if (shouldCreateVersion) {
|
||||
updateData.$push = {
|
||||
...(($push as Record<string, unknown>) || {}),
|
||||
versions: versionEntry,
|
||||
};
|
||||
}
|
||||
}
|
||||
|
||||
return (await Agent.findOneAndUpdate(
|
||||
searchParameter,
|
||||
updateData,
|
||||
mongoOptions,
|
||||
).lean()) as IAgent | null;
|
||||
}
|
||||
|
||||
/**
|
||||
* Modifies an agent with the resource file id.
|
||||
*/
|
||||
async function addAgentResourceFile({
|
||||
agent_id,
|
||||
tool_resource,
|
||||
file_id,
|
||||
updatingUserId,
|
||||
}: {
|
||||
agent_id: string;
|
||||
tool_resource: string;
|
||||
file_id: string;
|
||||
updatingUserId?: string;
|
||||
}): Promise<IAgent> {
|
||||
const Agent = mongoose.models.Agent as Model<IAgent>;
|
||||
const searchParameter = { id: agent_id };
|
||||
const agent = await getAgent(searchParameter);
|
||||
if (!agent) {
|
||||
throw new Error('Agent not found for adding resource file');
|
||||
}
|
||||
const fileIdsPath = `tool_resources.${tool_resource}.file_ids`;
|
||||
await Agent.updateOne(
|
||||
{
|
||||
id: agent_id,
|
||||
[`${fileIdsPath}`]: { $exists: false },
|
||||
},
|
||||
{
|
||||
$set: {
|
||||
[`${fileIdsPath}`]: [],
|
||||
},
|
||||
},
|
||||
);
|
||||
|
||||
const updateDataObj: Record<string, unknown> = {
|
||||
$addToSet: {
|
||||
tools: tool_resource,
|
||||
[fileIdsPath]: file_id,
|
||||
},
|
||||
};
|
||||
|
||||
const updatedAgent = await updateAgent(searchParameter, updateDataObj, {
|
||||
updatingUserId,
|
||||
});
|
||||
if (updatedAgent) {
|
||||
return updatedAgent;
|
||||
} else {
|
||||
throw new Error('Agent not found for adding resource file');
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Removes multiple resource files from an agent using atomic operations.
|
||||
*/
|
||||
async function removeAgentResourceFiles({
|
||||
agent_id,
|
||||
files,
|
||||
}: {
|
||||
agent_id: string;
|
||||
files: Array<{ tool_resource: string; file_id: string }>;
|
||||
}): Promise<IAgent> {
|
||||
const Agent = mongoose.models.Agent as Model<IAgent>;
|
||||
const searchParameter = { id: agent_id };
|
||||
|
||||
const filesByResource = files.reduce(
|
||||
(acc: Record<string, string[]>, { tool_resource, file_id }) => {
|
||||
if (!acc[tool_resource]) {
|
||||
acc[tool_resource] = [];
|
||||
}
|
||||
acc[tool_resource].push(file_id);
|
||||
return acc;
|
||||
},
|
||||
{},
|
||||
);
|
||||
|
||||
const pullAllOps: Record<string, string[]> = {};
|
||||
for (const [resource, fileIds] of Object.entries(filesByResource)) {
|
||||
const fileIdsPath = `tool_resources.${resource}.file_ids`;
|
||||
pullAllOps[fileIdsPath] = fileIds;
|
||||
}
|
||||
|
||||
const updatePullData = { $pullAll: pullAllOps };
|
||||
const agentAfterPull = (await Agent.findOneAndUpdate(searchParameter, updatePullData, {
|
||||
new: true,
|
||||
}).lean()) as IAgent | null;
|
||||
|
||||
if (!agentAfterPull) {
|
||||
const agentExists = await getAgent(searchParameter);
|
||||
if (!agentExists) {
|
||||
throw new Error('Agent not found for removing resource files');
|
||||
}
|
||||
throw new Error('Failed to update agent during file removal (pull step)');
|
||||
}
|
||||
|
||||
return agentAfterPull;
|
||||
}
|
||||
|
||||
/**
|
||||
* Deletes an agent based on the provided search parameter.
|
||||
*/
|
||||
async function deleteAgent(searchParameter: FilterQuery<IAgent>): Promise<IAgent | null> {
|
||||
const Agent = mongoose.models.Agent as Model<IAgent>;
|
||||
const User = mongoose.models.User as Model<unknown>;
|
||||
const agent = await Agent.findOneAndDelete(searchParameter);
|
||||
if (agent) {
|
||||
await Promise.all([
|
||||
removeAllPermissions({
|
||||
resourceType: ResourceType.AGENT,
|
||||
resourceId: agent._id,
|
||||
}),
|
||||
removeAllPermissions({
|
||||
resourceType: ResourceType.REMOTE_AGENT,
|
||||
resourceId: agent._id,
|
||||
}),
|
||||
]);
|
||||
try {
|
||||
await Agent.updateMany(
|
||||
{ 'edges.to': (agent as unknown as { id: string }).id },
|
||||
{ $pull: { edges: { to: (agent as unknown as { id: string }).id } } },
|
||||
);
|
||||
} catch (error) {
|
||||
logger.error('[deleteAgent] Error removing agent from handoff edges', error);
|
||||
}
|
||||
try {
|
||||
await User.updateMany(
|
||||
{ 'favorites.agentId': (agent as unknown as { id: string }).id },
|
||||
{ $pull: { favorites: { agentId: (agent as unknown as { id: string }).id } } },
|
||||
);
|
||||
} catch (error) {
|
||||
logger.error('[deleteAgent] Error removing agent from user favorites', error);
|
||||
}
|
||||
}
|
||||
return agent ? (agent.toObject() as IAgent) : null;
|
||||
}
|
||||
|
||||
/**
|
||||
* Deletes all agents created by a specific user.
|
||||
*/
|
||||
async function deleteUserAgents(userId: string): Promise<void> {
|
||||
const Agent = mongoose.models.Agent as Model<IAgent>;
|
||||
const AclEntry = mongoose.models.AclEntry as Model<unknown>;
|
||||
const User = mongoose.models.User as Model<unknown>;
|
||||
|
||||
try {
|
||||
const userAgents = await getAgents({ author: userId });
|
||||
|
||||
if (userAgents.length === 0) {
|
||||
return;
|
||||
}
|
||||
|
||||
const agentIds = userAgents.map((agent) => agent.id);
|
||||
const agentObjectIds = userAgents.map(
|
||||
(agent) => (agent as unknown as { _id: Types.ObjectId })._id,
|
||||
);
|
||||
|
||||
await AclEntry.deleteMany({
|
||||
resourceType: { $in: [ResourceType.AGENT, ResourceType.REMOTE_AGENT] },
|
||||
resourceId: { $in: agentObjectIds },
|
||||
});
|
||||
|
||||
try {
|
||||
await User.updateMany(
|
||||
{ 'favorites.agentId': { $in: agentIds } },
|
||||
{ $pull: { favorites: { agentId: { $in: agentIds } } } },
|
||||
);
|
||||
} catch (error) {
|
||||
logger.error('[deleteUserAgents] Error removing agents from user favorites', error);
|
||||
}
|
||||
|
||||
await Agent.deleteMany({ author: userId });
|
||||
} catch (error) {
|
||||
logger.error('[deleteUserAgents] General error:', error);
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Get agents by accessible IDs with optional cursor-based pagination.
|
||||
*/
|
||||
async function getListAgentsByAccess({
|
||||
accessibleIds = [],
|
||||
otherParams = {},
|
||||
limit = null,
|
||||
after = null,
|
||||
}: {
|
||||
accessibleIds?: Types.ObjectId[];
|
||||
otherParams?: Record<string, unknown>;
|
||||
limit?: number | null;
|
||||
after?: string | null;
|
||||
}): Promise<{
|
||||
object: string;
|
||||
data: Array<Record<string, unknown>>;
|
||||
first_id: string | null;
|
||||
last_id: string | null;
|
||||
has_more: boolean;
|
||||
after: string | null;
|
||||
}> {
|
||||
const Agent = mongoose.models.Agent as Model<IAgent>;
|
||||
const isPaginated = limit !== null && limit !== undefined;
|
||||
const normalizedLimit = isPaginated
|
||||
? Math.min(Math.max(1, parseInt(String(limit)) || 20), 100)
|
||||
: null;
|
||||
|
||||
const baseQuery: Record<string, unknown> = {
|
||||
...otherParams,
|
||||
_id: { $in: accessibleIds },
|
||||
};
|
||||
|
||||
if (after) {
|
||||
try {
|
||||
const cursor = JSON.parse(Buffer.from(after, 'base64').toString('utf8'));
|
||||
const { updatedAt, _id } = cursor;
|
||||
|
||||
const cursorCondition = {
|
||||
$or: [
|
||||
{ updatedAt: { $lt: new Date(updatedAt) } },
|
||||
{
|
||||
updatedAt: new Date(updatedAt),
|
||||
_id: { $gt: new mongoose.Types.ObjectId(_id) },
|
||||
},
|
||||
],
|
||||
};
|
||||
|
||||
if (Object.keys(baseQuery).length > 0) {
|
||||
baseQuery.$and = [{ ...baseQuery }, cursorCondition];
|
||||
Object.keys(baseQuery).forEach((key) => {
|
||||
if (key !== '$and') delete baseQuery[key];
|
||||
});
|
||||
} else {
|
||||
Object.assign(baseQuery, cursorCondition);
|
||||
}
|
||||
} catch (error) {
|
||||
logger.warn('Invalid cursor:', (error as Error).message);
|
||||
}
|
||||
}
|
||||
|
||||
let query = Agent.find(baseQuery, {
|
||||
id: 1,
|
||||
_id: 1,
|
||||
name: 1,
|
||||
avatar: 1,
|
||||
author: 1,
|
||||
description: 1,
|
||||
updatedAt: 1,
|
||||
category: 1,
|
||||
support_contact: 1,
|
||||
is_promoted: 1,
|
||||
}).sort({ updatedAt: -1, _id: 1 });
|
||||
|
||||
if (isPaginated && normalizedLimit) {
|
||||
query = query.limit(normalizedLimit + 1);
|
||||
}
|
||||
|
||||
const agents = (await query.lean()) as Array<Record<string, unknown>>;
|
||||
|
||||
const hasMore = isPaginated && normalizedLimit ? agents.length > normalizedLimit : false;
|
||||
const data = (isPaginated && normalizedLimit ? agents.slice(0, normalizedLimit) : agents).map(
|
||||
(agent) => {
|
||||
if (agent.author) {
|
||||
agent.author = (agent.author as Types.ObjectId).toString();
|
||||
}
|
||||
return agent;
|
||||
},
|
||||
);
|
||||
|
||||
let nextCursor: string | null = null;
|
||||
if (isPaginated && hasMore && data.length > 0 && normalizedLimit) {
|
||||
const lastAgent = agents[normalizedLimit - 1];
|
||||
nextCursor = Buffer.from(
|
||||
JSON.stringify({
|
||||
updatedAt: (lastAgent.updatedAt as Date).toISOString(),
|
||||
_id: (lastAgent._id as Types.ObjectId).toString(),
|
||||
}),
|
||||
).toString('base64');
|
||||
}
|
||||
|
||||
return {
|
||||
object: 'list',
|
||||
data,
|
||||
first_id: data.length > 0 ? (data[0].id as string) : null,
|
||||
last_id: data.length > 0 ? (data[data.length - 1].id as string) : null,
|
||||
has_more: hasMore,
|
||||
after: nextCursor,
|
||||
};
|
||||
}
|
||||
|
||||
/**
|
||||
* Reverts an agent to a specific version in its version history.
|
||||
*/
|
||||
async function revertAgentVersion(
|
||||
searchParameter: FilterQuery<IAgent>,
|
||||
versionIndex: number,
|
||||
): Promise<IAgent> {
|
||||
const Agent = mongoose.models.Agent as Model<IAgent>;
|
||||
const agent = await Agent.findOne(searchParameter);
|
||||
if (!agent) {
|
||||
throw new Error('Agent not found');
|
||||
}
|
||||
|
||||
if (!agent.versions || !agent.versions[versionIndex]) {
|
||||
throw new Error(`Version ${versionIndex} not found`);
|
||||
}
|
||||
|
||||
const revertToVersion = { ...(agent.versions[versionIndex] as Record<string, unknown>) };
|
||||
delete revertToVersion._id;
|
||||
delete revertToVersion.id;
|
||||
delete revertToVersion.versions;
|
||||
delete revertToVersion.author;
|
||||
delete revertToVersion.updatedBy;
|
||||
|
||||
return (await Agent.findOneAndUpdate(searchParameter, revertToVersion, {
|
||||
new: true,
|
||||
}).lean()) as IAgent;
|
||||
}
|
||||
|
||||
/**
|
||||
* Counts the number of promoted agents.
|
||||
*/
|
||||
async function countPromotedAgents(): Promise<number> {
|
||||
const Agent = mongoose.models.Agent as Model<IAgent>;
|
||||
return await Agent.countDocuments({ is_promoted: true });
|
||||
}
|
||||
|
||||
return {
|
||||
createAgent,
|
||||
getAgent,
|
||||
getAgents,
|
||||
updateAgent,
|
||||
deleteAgent,
|
||||
deleteUserAgents,
|
||||
revertAgentVersion,
|
||||
countPromotedAgents,
|
||||
addAgentResourceFile,
|
||||
removeAgentResourceFiles,
|
||||
getListAgentsByAccess,
|
||||
generateActionMetadataHash,
|
||||
};
|
||||
}
|
||||
|
||||
export type AgentMethods = ReturnType<typeof createAgentMethods>;
|
||||
69
packages/data-schemas/src/methods/assistant.ts
Normal file
69
packages/data-schemas/src/methods/assistant.ts
Normal file
|
|
@ -0,0 +1,69 @@
|
|||
import type { FilterQuery, Model } from 'mongoose';
|
||||
import type { IAssistant } from '~/types';
|
||||
|
||||
export function createAssistantMethods(mongoose: typeof import('mongoose')) {
|
||||
/**
|
||||
* Update an assistant with new data without overwriting existing properties,
|
||||
* or create a new assistant if it doesn't exist.
|
||||
*/
|
||||
async function updateAssistantDoc(
|
||||
searchParams: FilterQuery<IAssistant>,
|
||||
updateData: Partial<IAssistant>,
|
||||
): Promise<IAssistant | null> {
|
||||
const Assistant = mongoose.models.Assistant as Model<IAssistant>;
|
||||
const options = { new: true, upsert: true };
|
||||
return (await Assistant.findOneAndUpdate(
|
||||
searchParams,
|
||||
updateData,
|
||||
options,
|
||||
).lean()) as IAssistant | null;
|
||||
}
|
||||
|
||||
/**
|
||||
* Retrieves an assistant document based on the provided search params.
|
||||
*/
|
||||
async function getAssistant(searchParams: FilterQuery<IAssistant>): Promise<IAssistant | null> {
|
||||
const Assistant = mongoose.models.Assistant as Model<IAssistant>;
|
||||
return (await Assistant.findOne(searchParams).lean()) as IAssistant | null;
|
||||
}
|
||||
|
||||
/**
|
||||
* Retrieves all assistants that match the given search parameters.
|
||||
*/
|
||||
async function getAssistants(
|
||||
searchParams: FilterQuery<IAssistant>,
|
||||
select: string | Record<string, number> | null = null,
|
||||
): Promise<IAssistant[]> {
|
||||
const Assistant = mongoose.models.Assistant as Model<IAssistant>;
|
||||
const query = Assistant.find(searchParams);
|
||||
|
||||
return (await (select ? query.select(select) : query).lean()) as IAssistant[];
|
||||
}
|
||||
|
||||
/**
|
||||
* Deletes an assistant based on the provided search params.
|
||||
*/
|
||||
async function deleteAssistant(searchParams: FilterQuery<IAssistant>) {
|
||||
const Assistant = mongoose.models.Assistant as Model<IAssistant>;
|
||||
return await Assistant.findOneAndDelete(searchParams);
|
||||
}
|
||||
|
||||
/**
|
||||
* Deletes all assistants matching the given search parameters.
|
||||
*/
|
||||
async function deleteAssistants(searchParams: FilterQuery<IAssistant>): Promise<number> {
|
||||
const Assistant = mongoose.models.Assistant as Model<IAssistant>;
|
||||
const result = await Assistant.deleteMany(searchParams);
|
||||
return result.deletedCount;
|
||||
}
|
||||
|
||||
return {
|
||||
updateAssistantDoc,
|
||||
deleteAssistant,
|
||||
deleteAssistants,
|
||||
getAssistants,
|
||||
getAssistant,
|
||||
};
|
||||
}
|
||||
|
||||
export type AssistantMethods = ReturnType<typeof createAssistantMethods>;
|
||||
33
packages/data-schemas/src/methods/banner.ts
Normal file
33
packages/data-schemas/src/methods/banner.ts
Normal file
|
|
@ -0,0 +1,33 @@
|
|||
import type { Model } from 'mongoose';
|
||||
import logger from '~/config/winston';
|
||||
import type { IBanner, IUser } from '~/types';
|
||||
|
||||
export function createBannerMethods(mongoose: typeof import('mongoose')) {
|
||||
/**
|
||||
* Retrieves the current active banner.
|
||||
*/
|
||||
async function getBanner(user?: IUser | null): Promise<IBanner | null> {
|
||||
try {
|
||||
const Banner = mongoose.models.Banner as Model<IBanner>;
|
||||
const now = new Date();
|
||||
const banner = (await Banner.findOne({
|
||||
displayFrom: { $lte: now },
|
||||
$or: [{ displayTo: { $gte: now } }, { displayTo: null }],
|
||||
type: 'banner',
|
||||
}).lean()) as IBanner | null;
|
||||
|
||||
if (!banner || banner.isPublic || user != null) {
|
||||
return banner;
|
||||
}
|
||||
|
||||
return null;
|
||||
} catch (error) {
|
||||
logger.error('[getBanners] Error getting banners', error);
|
||||
throw new Error('Error getting banners');
|
||||
}
|
||||
}
|
||||
|
||||
return { getBanner };
|
||||
}
|
||||
|
||||
export type BannerMethods = ReturnType<typeof createBannerMethods>;
|
||||
33
packages/data-schemas/src/methods/categories.ts
Normal file
33
packages/data-schemas/src/methods/categories.ts
Normal file
|
|
@ -0,0 +1,33 @@
|
|||
import logger from '~/config/winston';
|
||||
|
||||
const options = [
|
||||
{ label: 'com_ui_idea', value: 'idea' },
|
||||
{ label: 'com_ui_travel', value: 'travel' },
|
||||
{ label: 'com_ui_teach_or_explain', value: 'teach_or_explain' },
|
||||
{ label: 'com_ui_write', value: 'write' },
|
||||
{ label: 'com_ui_shop', value: 'shop' },
|
||||
{ label: 'com_ui_code', value: 'code' },
|
||||
{ label: 'com_ui_misc', value: 'misc' },
|
||||
{ label: 'com_ui_roleplay', value: 'roleplay' },
|
||||
{ label: 'com_ui_finance', value: 'finance' },
|
||||
] as const;
|
||||
|
||||
export type CategoryOption = { label: string; value: string };
|
||||
|
||||
export function createCategoriesMethods(_mongoose: typeof import('mongoose')) {
|
||||
/**
|
||||
* Retrieves the categories.
|
||||
*/
|
||||
async function getCategories(): Promise<CategoryOption[]> {
|
||||
try {
|
||||
return [...options];
|
||||
} catch (error) {
|
||||
logger.error('Error getting categories', error);
|
||||
return [];
|
||||
}
|
||||
}
|
||||
|
||||
return { getCategories };
|
||||
}
|
||||
|
||||
export type CategoriesMethods = ReturnType<typeof createCategoriesMethods>;
|
||||
909
packages/data-schemas/src/methods/conversation.spec.ts
Normal file
909
packages/data-schemas/src/methods/conversation.spec.ts
Normal file
|
|
@ -0,0 +1,909 @@
|
|||
import mongoose from 'mongoose';
|
||||
import { v4 as uuidv4 } from 'uuid';
|
||||
import { EModelEndpoint } from 'librechat-data-provider';
|
||||
import type { IConversation } from '../types';
|
||||
import { MongoMemoryServer } from 'mongodb-memory-server';
|
||||
import { ConversationMethods, createConversationMethods } from './conversation';
|
||||
import { createModels } from '../models';
|
||||
|
||||
jest.mock('~/config/winston', () => ({
|
||||
error: jest.fn(),
|
||||
warn: jest.fn(),
|
||||
info: jest.fn(),
|
||||
debug: jest.fn(),
|
||||
}));
|
||||
|
||||
let mongoServer: InstanceType<typeof MongoMemoryServer>;
|
||||
let Conversation: mongoose.Model<IConversation>;
|
||||
let modelsToCleanup: string[] = [];
|
||||
|
||||
// Mock message methods (same as original test mocking ./Message)
|
||||
const getMessages = jest.fn().mockResolvedValue([]);
|
||||
const deleteMessages = jest.fn().mockResolvedValue({ deletedCount: 0 });
|
||||
|
||||
let methods: ConversationMethods;
|
||||
|
||||
beforeAll(async () => {
|
||||
mongoServer = await MongoMemoryServer.create();
|
||||
const mongoUri = mongoServer.getUri();
|
||||
|
||||
const models = createModels(mongoose);
|
||||
modelsToCleanup = Object.keys(models);
|
||||
Object.assign(mongoose.models, models);
|
||||
Conversation = mongoose.models.Conversation as mongoose.Model<IConversation>;
|
||||
|
||||
methods = createConversationMethods(mongoose, { getMessages, deleteMessages });
|
||||
|
||||
await mongoose.connect(mongoUri);
|
||||
});
|
||||
|
||||
afterAll(async () => {
|
||||
const collections = mongoose.connection.collections;
|
||||
for (const key in collections) {
|
||||
await collections[key].deleteMany({});
|
||||
}
|
||||
|
||||
for (const modelName of modelsToCleanup) {
|
||||
if (mongoose.models[modelName]) {
|
||||
delete mongoose.models[modelName];
|
||||
}
|
||||
}
|
||||
|
||||
await mongoose.disconnect();
|
||||
await mongoServer.stop();
|
||||
});
|
||||
|
||||
const saveConvo = (...args: Parameters<ConversationMethods['saveConvo']>) =>
|
||||
methods.saveConvo(...args) as Promise<IConversation | null>;
|
||||
const getConvo = (...args: Parameters<ConversationMethods['getConvo']>) =>
|
||||
methods.getConvo(...args);
|
||||
const getConvoTitle = (...args: Parameters<ConversationMethods['getConvoTitle']>) =>
|
||||
methods.getConvoTitle(...args);
|
||||
const getConvoFiles = (...args: Parameters<ConversationMethods['getConvoFiles']>) =>
|
||||
methods.getConvoFiles(...args);
|
||||
const deleteConvos = (...args: Parameters<ConversationMethods['deleteConvos']>) =>
|
||||
methods.deleteConvos(...args);
|
||||
const getConvosByCursor = (...args: Parameters<ConversationMethods['getConvosByCursor']>) =>
|
||||
methods.getConvosByCursor(...args);
|
||||
const getConvosQueried = (...args: Parameters<ConversationMethods['getConvosQueried']>) =>
|
||||
methods.getConvosQueried(...args);
|
||||
const deleteNullOrEmptyConversations = (
|
||||
...args: Parameters<ConversationMethods['deleteNullOrEmptyConversations']>
|
||||
) => methods.deleteNullOrEmptyConversations(...args);
|
||||
const searchConversation = (...args: Parameters<ConversationMethods['searchConversation']>) =>
|
||||
methods.searchConversation(...args);
|
||||
|
||||
describe('Conversation Operations', () => {
|
||||
let mockCtx: {
|
||||
userId: string;
|
||||
isTemporary?: boolean;
|
||||
interfaceConfig?: { temporaryChatRetention?: number };
|
||||
};
|
||||
let mockConversationData: {
|
||||
conversationId: string;
|
||||
title: string;
|
||||
endpoint: string;
|
||||
};
|
||||
|
||||
beforeEach(async () => {
|
||||
// Clear database
|
||||
await Conversation.deleteMany({});
|
||||
|
||||
// Reset mocks
|
||||
jest.clearAllMocks();
|
||||
getMessages.mockResolvedValue([]);
|
||||
deleteMessages.mockResolvedValue({ deletedCount: 0 });
|
||||
|
||||
mockCtx = {
|
||||
userId: 'user123',
|
||||
interfaceConfig: {
|
||||
temporaryChatRetention: 24, // Default 24 hours
|
||||
},
|
||||
};
|
||||
|
||||
mockConversationData = {
|
||||
conversationId: uuidv4(),
|
||||
title: 'Test Conversation',
|
||||
endpoint: EModelEndpoint.openAI,
|
||||
};
|
||||
});
|
||||
|
||||
describe('saveConvo', () => {
|
||||
it('should save a conversation for an authenticated user', async () => {
|
||||
const result = await saveConvo(mockCtx, mockConversationData);
|
||||
|
||||
expect(result?.conversationId).toBe(mockConversationData.conversationId);
|
||||
expect(result?.user).toBe('user123');
|
||||
expect(result?.title).toBe('Test Conversation');
|
||||
expect(result?.endpoint).toBe(EModelEndpoint.openAI);
|
||||
|
||||
// Verify the conversation was actually saved to the database
|
||||
const savedConvo = await Conversation.findOne<IConversation>({
|
||||
conversationId: mockConversationData.conversationId,
|
||||
user: 'user123',
|
||||
});
|
||||
expect(savedConvo).toBeTruthy();
|
||||
expect(savedConvo?.title).toBe('Test Conversation');
|
||||
});
|
||||
|
||||
it('should query messages when saving a conversation', async () => {
|
||||
// Mock messages as ObjectIds
|
||||
const mockMessages = [new mongoose.Types.ObjectId(), new mongoose.Types.ObjectId()];
|
||||
getMessages.mockResolvedValue(mockMessages);
|
||||
|
||||
await saveConvo(mockCtx, mockConversationData);
|
||||
|
||||
// Verify that getMessages was called with correct parameters
|
||||
expect(getMessages).toHaveBeenCalledWith(
|
||||
{ conversationId: mockConversationData.conversationId },
|
||||
'_id',
|
||||
);
|
||||
});
|
||||
|
||||
it('should handle newConversationId when provided', async () => {
|
||||
const newConversationId = uuidv4();
|
||||
const result = await saveConvo(mockCtx, {
|
||||
...mockConversationData,
|
||||
newConversationId,
|
||||
});
|
||||
|
||||
expect(result?.conversationId).toBe(newConversationId);
|
||||
});
|
||||
|
||||
it('should not create a conversation when noUpsert is true and conversation does not exist', async () => {
|
||||
const nonExistentId = uuidv4();
|
||||
const result = await saveConvo(
|
||||
mockCtx,
|
||||
{ conversationId: nonExistentId, title: 'Ghost Title' },
|
||||
{ noUpsert: true },
|
||||
);
|
||||
|
||||
expect(result).toBeNull();
|
||||
|
||||
const dbConvo = await Conversation.findOne({ conversationId: nonExistentId });
|
||||
expect(dbConvo).toBeNull();
|
||||
});
|
||||
|
||||
it('should update an existing conversation when noUpsert is true', async () => {
|
||||
await saveConvo(mockCtx, mockConversationData);
|
||||
|
||||
const result = await saveConvo(
|
||||
mockCtx,
|
||||
{ conversationId: mockConversationData.conversationId, title: 'Updated Title' },
|
||||
{ noUpsert: true },
|
||||
);
|
||||
|
||||
expect(result).not.toBeNull();
|
||||
expect(result?.title).toBe('Updated Title');
|
||||
expect(result?.conversationId).toBe(mockConversationData.conversationId);
|
||||
});
|
||||
|
||||
it('should still upsert by default when noUpsert is not provided', async () => {
|
||||
const newId = uuidv4();
|
||||
const result = await saveConvo(mockCtx, {
|
||||
conversationId: newId,
|
||||
title: 'New Conversation',
|
||||
endpoint: EModelEndpoint.openAI,
|
||||
});
|
||||
|
||||
expect(result).not.toBeNull();
|
||||
expect(result?.conversationId).toBe(newId);
|
||||
expect(result?.title).toBe('New Conversation');
|
||||
});
|
||||
|
||||
it('should handle unsetFields metadata', async () => {
|
||||
const metadata = {
|
||||
unsetFields: { someField: 1 },
|
||||
};
|
||||
|
||||
await saveConvo(mockCtx, mockConversationData, metadata);
|
||||
|
||||
const savedConvo = await Conversation.findOne<IConversation & { someField?: string }>({
|
||||
conversationId: mockConversationData.conversationId,
|
||||
});
|
||||
expect(savedConvo?.someField).toBeUndefined();
|
||||
});
|
||||
});
|
||||
|
||||
describe('isTemporary conversation handling', () => {
|
||||
it('should save a conversation with expiredAt when isTemporary is true', async () => {
|
||||
mockCtx.interfaceConfig = { temporaryChatRetention: 24 };
|
||||
mockCtx.isTemporary = true;
|
||||
|
||||
const beforeSave = new Date();
|
||||
const result = await saveConvo(mockCtx, mockConversationData);
|
||||
const afterSave = new Date();
|
||||
|
||||
expect(result?.conversationId).toBe(mockConversationData.conversationId);
|
||||
expect(result?.expiredAt).toBeDefined();
|
||||
expect(result?.expiredAt).toBeInstanceOf(Date);
|
||||
|
||||
const expectedExpirationTime = new Date(beforeSave.getTime() + 24 * 60 * 60 * 1000);
|
||||
const actualExpirationTime = new Date(result?.expiredAt ?? 0);
|
||||
|
||||
expect(actualExpirationTime.getTime()).toBeGreaterThanOrEqual(
|
||||
expectedExpirationTime.getTime() - 1000,
|
||||
);
|
||||
expect(actualExpirationTime.getTime()).toBeLessThanOrEqual(
|
||||
new Date(afterSave.getTime() + 24 * 60 * 60 * 1000 + 1000).getTime(),
|
||||
);
|
||||
});
|
||||
|
||||
it('should save a conversation without expiredAt when isTemporary is false', async () => {
|
||||
mockCtx.isTemporary = false;
|
||||
|
||||
const result = await saveConvo(mockCtx, mockConversationData);
|
||||
|
||||
expect(result?.conversationId).toBe(mockConversationData.conversationId);
|
||||
expect(result?.expiredAt).toBeNull();
|
||||
});
|
||||
|
||||
it('should save a conversation without expiredAt when isTemporary is not provided', async () => {
|
||||
mockCtx.isTemporary = undefined;
|
||||
|
||||
const result = await saveConvo(mockCtx, mockConversationData);
|
||||
|
||||
expect(result?.conversationId).toBe(mockConversationData.conversationId);
|
||||
expect(result?.expiredAt).toBeNull();
|
||||
});
|
||||
|
||||
it('should use custom retention period from config', async () => {
|
||||
mockCtx.interfaceConfig = { temporaryChatRetention: 48 };
|
||||
mockCtx.isTemporary = true;
|
||||
|
||||
const beforeSave = new Date();
|
||||
const result = await saveConvo(mockCtx, mockConversationData);
|
||||
|
||||
expect(result?.expiredAt).toBeDefined();
|
||||
|
||||
// Verify expiredAt is approximately 48 hours in the future
|
||||
const expectedExpirationTime = new Date(beforeSave.getTime() + 48 * 60 * 60 * 1000);
|
||||
const actualExpirationTime = new Date(result?.expiredAt ?? 0);
|
||||
|
||||
expect(actualExpirationTime.getTime()).toBeGreaterThanOrEqual(
|
||||
expectedExpirationTime.getTime() - 1000,
|
||||
);
|
||||
expect(actualExpirationTime.getTime()).toBeLessThanOrEqual(
|
||||
expectedExpirationTime.getTime() + 1000,
|
||||
);
|
||||
});
|
||||
|
||||
it('should handle minimum retention period (1 hour)', async () => {
|
||||
// Mock app config with less than minimum retention
|
||||
mockCtx.interfaceConfig = { temporaryChatRetention: 0.5 }; // Half hour - should be clamped to 1 hour
|
||||
mockCtx.isTemporary = true;
|
||||
|
||||
const beforeSave = new Date();
|
||||
const result = await saveConvo(mockCtx, mockConversationData);
|
||||
|
||||
expect(result?.expiredAt).toBeDefined();
|
||||
|
||||
// Verify expiredAt is approximately 1 hour in the future (minimum)
|
||||
const expectedExpirationTime = new Date(beforeSave.getTime() + 1 * 60 * 60 * 1000);
|
||||
const actualExpirationTime = new Date(result?.expiredAt ?? 0);
|
||||
|
||||
expect(actualExpirationTime.getTime()).toBeGreaterThanOrEqual(
|
||||
expectedExpirationTime.getTime() - 1000,
|
||||
);
|
||||
expect(actualExpirationTime.getTime()).toBeLessThanOrEqual(
|
||||
expectedExpirationTime.getTime() + 1000,
|
||||
);
|
||||
});
|
||||
|
||||
it('should handle maximum retention period (8760 hours)', async () => {
|
||||
// Mock app config with more than maximum retention
|
||||
mockCtx.interfaceConfig = { temporaryChatRetention: 10000 }; // Should be clamped to 8760 hours
|
||||
mockCtx.isTemporary = true;
|
||||
|
||||
const beforeSave = new Date();
|
||||
const result = await saveConvo(mockCtx, mockConversationData);
|
||||
|
||||
expect(result?.expiredAt).toBeDefined();
|
||||
|
||||
// Verify expiredAt is approximately 8760 hours (1 year) in the future
|
||||
const expectedExpirationTime = new Date(beforeSave.getTime() + 8760 * 60 * 60 * 1000);
|
||||
const actualExpirationTime = new Date(result?.expiredAt ?? 0);
|
||||
|
||||
expect(actualExpirationTime.getTime()).toBeGreaterThanOrEqual(
|
||||
expectedExpirationTime.getTime() - 1000,
|
||||
);
|
||||
expect(actualExpirationTime.getTime()).toBeLessThanOrEqual(
|
||||
expectedExpirationTime.getTime() + 1000,
|
||||
);
|
||||
});
|
||||
|
||||
it('should handle missing config gracefully', async () => {
|
||||
// Simulate missing config - should use default retention period
|
||||
mockCtx.interfaceConfig = undefined;
|
||||
mockCtx.isTemporary = true;
|
||||
|
||||
const beforeSave = new Date();
|
||||
const result = await saveConvo(mockCtx, mockConversationData);
|
||||
const afterSave = new Date();
|
||||
|
||||
// Should still save the conversation with default retention period (30 days)
|
||||
expect(result?.conversationId).toBe(mockConversationData.conversationId);
|
||||
expect(result?.expiredAt).toBeDefined();
|
||||
expect(result?.expiredAt).toBeInstanceOf(Date);
|
||||
|
||||
// Verify expiredAt is approximately 30 days in the future (720 hours)
|
||||
const expectedExpirationTime = new Date(beforeSave.getTime() + 720 * 60 * 60 * 1000);
|
||||
const actualExpirationTime = new Date(result?.expiredAt ?? 0);
|
||||
|
||||
expect(actualExpirationTime.getTime()).toBeGreaterThanOrEqual(
|
||||
expectedExpirationTime.getTime() - 1000,
|
||||
);
|
||||
expect(actualExpirationTime.getTime()).toBeLessThanOrEqual(
|
||||
new Date(afterSave.getTime() + 720 * 60 * 60 * 1000 + 1000).getTime(),
|
||||
);
|
||||
});
|
||||
|
||||
it('should use default retention when config is not provided', async () => {
|
||||
// Mock getAppConfig to return empty config
|
||||
mockCtx.interfaceConfig = undefined; // Empty config
|
||||
mockCtx.isTemporary = true;
|
||||
|
||||
const beforeSave = new Date();
|
||||
const result = await saveConvo(mockCtx, mockConversationData);
|
||||
|
||||
expect(result?.expiredAt).toBeDefined();
|
||||
|
||||
// Default retention is 30 days (720 hours)
|
||||
const expectedExpirationTime = new Date(beforeSave.getTime() + 30 * 24 * 60 * 60 * 1000);
|
||||
const actualExpirationTime = new Date(result?.expiredAt ?? 0);
|
||||
|
||||
expect(actualExpirationTime.getTime()).toBeGreaterThanOrEqual(
|
||||
expectedExpirationTime.getTime() - 1000,
|
||||
);
|
||||
expect(actualExpirationTime.getTime()).toBeLessThanOrEqual(
|
||||
expectedExpirationTime.getTime() + 1000,
|
||||
);
|
||||
});
|
||||
|
||||
it('should update expiredAt when saving existing temporary conversation', async () => {
|
||||
// First save a temporary conversation
|
||||
mockCtx.interfaceConfig = { temporaryChatRetention: 24 };
|
||||
mockCtx.isTemporary = true;
|
||||
const firstSave = await saveConvo(mockCtx, mockConversationData);
|
||||
const originalExpiredAt = firstSave?.expiredAt ?? new Date(0);
|
||||
|
||||
// Wait a bit to ensure time difference
|
||||
await new Promise((resolve) => setTimeout(resolve, 100));
|
||||
|
||||
// Save again with same conversationId but different title
|
||||
const updatedData = { ...mockConversationData, title: 'Updated Title' };
|
||||
const secondSave = await saveConvo(mockCtx, updatedData);
|
||||
|
||||
// Should update title and create new expiredAt
|
||||
expect(secondSave?.title).toBe('Updated Title');
|
||||
expect(secondSave?.expiredAt).toBeDefined();
|
||||
expect(new Date(secondSave?.expiredAt ?? 0).getTime()).toBeGreaterThan(
|
||||
new Date(originalExpiredAt).getTime(),
|
||||
);
|
||||
});
|
||||
|
||||
it('should not set expiredAt when updating non-temporary conversation', async () => {
|
||||
// First save a non-temporary conversation
|
||||
mockCtx.isTemporary = false;
|
||||
const firstSave = await saveConvo(mockCtx, mockConversationData);
|
||||
expect(firstSave?.expiredAt).toBeNull();
|
||||
|
||||
// Update without isTemporary flag
|
||||
mockCtx.isTemporary = undefined;
|
||||
const updatedData = { ...mockConversationData, title: 'Updated Title' };
|
||||
const secondSave = await saveConvo(mockCtx, updatedData);
|
||||
|
||||
expect(secondSave?.title).toBe('Updated Title');
|
||||
expect(secondSave?.expiredAt).toBeNull();
|
||||
});
|
||||
|
||||
it('should filter out expired conversations in getConvosByCursor', async () => {
|
||||
// Create some test conversations
|
||||
const nonExpiredConvo = await Conversation.create({
|
||||
conversationId: uuidv4(),
|
||||
user: 'user123',
|
||||
title: 'Non-expired',
|
||||
endpoint: EModelEndpoint.openAI,
|
||||
expiredAt: null,
|
||||
updatedAt: new Date(),
|
||||
});
|
||||
|
||||
await Conversation.create({
|
||||
conversationId: uuidv4(),
|
||||
user: 'user123',
|
||||
title: 'Future expired',
|
||||
endpoint: EModelEndpoint.openAI,
|
||||
expiredAt: new Date(Date.now() + 24 * 60 * 60 * 1000), // 24 hours from now
|
||||
updatedAt: new Date(),
|
||||
});
|
||||
|
||||
// Mock Meili search
|
||||
Object.assign(Conversation, { meiliSearch: jest.fn().mockResolvedValue({ hits: [] }) });
|
||||
|
||||
const result = await getConvosByCursor('user123');
|
||||
|
||||
// Should only return conversations with null or non-existent expiredAt
|
||||
expect(result?.conversations).toHaveLength(1);
|
||||
expect(result?.conversations[0]?.conversationId).toBe(nonExpiredConvo.conversationId);
|
||||
});
|
||||
|
||||
it('should filter out expired conversations in getConvosQueried', async () => {
|
||||
// Create test conversations
|
||||
const nonExpiredConvo = await Conversation.create({
|
||||
conversationId: uuidv4(),
|
||||
user: 'user123',
|
||||
title: 'Non-expired',
|
||||
endpoint: EModelEndpoint.openAI,
|
||||
expiredAt: null,
|
||||
});
|
||||
|
||||
const expiredConvo = await Conversation.create({
|
||||
conversationId: uuidv4(),
|
||||
user: 'user123',
|
||||
title: 'Expired',
|
||||
endpoint: EModelEndpoint.openAI,
|
||||
expiredAt: new Date(Date.now() + 24 * 60 * 60 * 1000),
|
||||
});
|
||||
|
||||
const convoIds = [
|
||||
{ conversationId: nonExpiredConvo.conversationId },
|
||||
{ conversationId: expiredConvo.conversationId },
|
||||
];
|
||||
|
||||
const result = await getConvosQueried('user123', convoIds);
|
||||
|
||||
// Should only return the non-expired conversation
|
||||
expect(result?.conversations).toHaveLength(1);
|
||||
expect(result?.conversations[0].conversationId).toBe(nonExpiredConvo.conversationId);
|
||||
expect(result?.convoMap[nonExpiredConvo.conversationId]).toBeDefined();
|
||||
expect(result?.convoMap[expiredConvo.conversationId]).toBeUndefined();
|
||||
});
|
||||
});
|
||||
|
||||
describe('searchConversation', () => {
|
||||
it('should find a conversation by conversationId', async () => {
|
||||
await Conversation.create({
|
||||
conversationId: mockConversationData.conversationId,
|
||||
user: 'user123',
|
||||
title: 'Test',
|
||||
endpoint: EModelEndpoint.openAI,
|
||||
});
|
||||
|
||||
const result = await searchConversation(mockConversationData.conversationId);
|
||||
|
||||
expect(result).toBeTruthy();
|
||||
expect(result!.conversationId).toBe(mockConversationData.conversationId);
|
||||
expect(result!.user).toBe('user123');
|
||||
expect((result as unknown as { title?: string }).title).toBeUndefined(); // Only returns conversationId and user
|
||||
});
|
||||
|
||||
it('should return null if conversation not found', async () => {
|
||||
const result = await searchConversation('non-existent-id');
|
||||
expect(result).toBeNull();
|
||||
});
|
||||
});
|
||||
|
||||
describe('getConvo', () => {
|
||||
it('should retrieve a conversation for a user', async () => {
|
||||
await Conversation.create({
|
||||
conversationId: mockConversationData.conversationId,
|
||||
user: 'user123',
|
||||
title: 'Test Conversation',
|
||||
endpoint: EModelEndpoint.openAI,
|
||||
});
|
||||
|
||||
const result = await getConvo('user123', mockConversationData.conversationId);
|
||||
|
||||
expect(result!.conversationId).toBe(mockConversationData.conversationId);
|
||||
expect(result!.user).toBe('user123');
|
||||
expect(result!.title).toBe('Test Conversation');
|
||||
});
|
||||
|
||||
it('should return null if conversation not found', async () => {
|
||||
const result = await getConvo('user123', 'non-existent-id');
|
||||
expect(result).toBeNull();
|
||||
});
|
||||
});
|
||||
|
||||
describe('getConvoTitle', () => {
|
||||
it('should return the conversation title', async () => {
|
||||
await Conversation.create({
|
||||
conversationId: mockConversationData.conversationId,
|
||||
user: 'user123',
|
||||
title: 'Test Title',
|
||||
endpoint: EModelEndpoint.openAI,
|
||||
});
|
||||
|
||||
const result = await getConvoTitle('user123', mockConversationData.conversationId);
|
||||
expect(result).toBe('Test Title');
|
||||
});
|
||||
|
||||
it('should return null if conversation has no title', async () => {
|
||||
await Conversation.create({
|
||||
conversationId: mockConversationData.conversationId,
|
||||
user: 'user123',
|
||||
title: null,
|
||||
endpoint: EModelEndpoint.openAI,
|
||||
});
|
||||
|
||||
const result = await getConvoTitle('user123', mockConversationData.conversationId);
|
||||
expect(result).toBeNull();
|
||||
});
|
||||
|
||||
it('should return "New Chat" if conversation not found', async () => {
|
||||
const result = await getConvoTitle('user123', 'non-existent-id');
|
||||
expect(result).toBe('New Chat');
|
||||
});
|
||||
});
|
||||
|
||||
describe('getConvoFiles', () => {
|
||||
it('should return conversation files', async () => {
|
||||
const files = ['file1', 'file2'];
|
||||
await Conversation.create({
|
||||
conversationId: mockConversationData.conversationId,
|
||||
user: 'user123',
|
||||
endpoint: EModelEndpoint.openAI,
|
||||
files,
|
||||
});
|
||||
|
||||
const result = await getConvoFiles(mockConversationData.conversationId);
|
||||
expect(result).toEqual(files);
|
||||
});
|
||||
|
||||
it('should return empty array if no files', async () => {
|
||||
await Conversation.create({
|
||||
conversationId: mockConversationData.conversationId,
|
||||
user: 'user123',
|
||||
endpoint: EModelEndpoint.openAI,
|
||||
});
|
||||
|
||||
const result = await getConvoFiles(mockConversationData.conversationId);
|
||||
expect(result).toEqual([]);
|
||||
});
|
||||
|
||||
it('should return empty array if conversation not found', async () => {
|
||||
const result = await getConvoFiles('non-existent-id');
|
||||
expect(result).toEqual([]);
|
||||
});
|
||||
});
|
||||
|
||||
describe('deleteConvos', () => {
|
||||
it('should delete conversations and associated messages', async () => {
|
||||
await Conversation.create({
|
||||
conversationId: mockConversationData.conversationId,
|
||||
user: 'user123',
|
||||
title: 'To Delete',
|
||||
endpoint: EModelEndpoint.openAI,
|
||||
});
|
||||
|
||||
deleteMessages.mockResolvedValue({ deletedCount: 5 });
|
||||
|
||||
const result = await deleteConvos('user123', {
|
||||
conversationId: mockConversationData.conversationId,
|
||||
});
|
||||
|
||||
expect(result?.deletedCount).toBe(1);
|
||||
expect(result?.messages.deletedCount).toBe(5);
|
||||
expect(deleteMessages).toHaveBeenCalledWith({
|
||||
conversationId: { $in: [mockConversationData.conversationId] },
|
||||
});
|
||||
|
||||
// Verify conversation was deleted
|
||||
const deletedConvo = await Conversation.findOne({
|
||||
conversationId: mockConversationData.conversationId,
|
||||
});
|
||||
expect(deletedConvo).toBeNull();
|
||||
});
|
||||
|
||||
it('should throw error if no conversations found', async () => {
|
||||
await expect(deleteConvos('user123', { conversationId: 'non-existent' })).rejects.toThrow(
|
||||
'Conversation not found or already deleted.',
|
||||
);
|
||||
});
|
||||
});
|
||||
|
||||
describe('deleteNullOrEmptyConversations', () => {
|
||||
it('should delete conversations with null, empty, or missing conversationIds', async () => {
|
||||
// Since conversationId is required by the schema, we can't create documents with null/missing IDs
|
||||
// This test should verify the function works when such documents exist (e.g., from data corruption)
|
||||
|
||||
// For this test, let's create a valid conversation and verify the function doesn't delete it
|
||||
await Conversation.create({
|
||||
conversationId: mockConversationData.conversationId,
|
||||
user: 'user4',
|
||||
endpoint: EModelEndpoint.openAI,
|
||||
});
|
||||
|
||||
deleteMessages.mockResolvedValue({ deletedCount: 0 });
|
||||
|
||||
const result = await deleteNullOrEmptyConversations();
|
||||
|
||||
expect(result?.conversations.deletedCount).toBe(0); // No invalid conversations to delete
|
||||
expect(result?.messages.deletedCount).toBe(0);
|
||||
|
||||
// Verify valid conversation remains
|
||||
const remainingConvos = await Conversation.find({});
|
||||
expect(remainingConvos).toHaveLength(1);
|
||||
expect(remainingConvos[0].conversationId).toBe(mockConversationData.conversationId);
|
||||
});
|
||||
});
|
||||
|
||||
describe('Error Handling', () => {
|
||||
it('should handle database errors in saveConvo', async () => {
|
||||
// Force a database error by disconnecting
|
||||
await mongoose.disconnect();
|
||||
|
||||
const result = await saveConvo(mockCtx, mockConversationData);
|
||||
|
||||
expect(result).toEqual({ message: 'Error saving conversation' });
|
||||
|
||||
// Reconnect for other tests
|
||||
await mongoose.connect(mongoServer.getUri());
|
||||
});
|
||||
});
|
||||
|
||||
describe('getConvosByCursor pagination', () => {
|
||||
/**
|
||||
* Helper to create conversations with specific timestamps
|
||||
* Uses collection.insertOne to bypass Mongoose timestamps entirely
|
||||
*/
|
||||
const createConvoWithTimestamps = async (index: number, createdAt: Date, updatedAt: Date) => {
|
||||
const conversationId = uuidv4();
|
||||
// Use collection-level insert to bypass Mongoose timestamps
|
||||
await Conversation.collection.insertOne({
|
||||
conversationId,
|
||||
user: 'user123',
|
||||
title: `Conversation ${index}`,
|
||||
endpoint: EModelEndpoint.openAI,
|
||||
expiredAt: null,
|
||||
isArchived: false,
|
||||
createdAt,
|
||||
updatedAt,
|
||||
});
|
||||
return Conversation.findOne({ conversationId }).lean();
|
||||
};
|
||||
|
||||
it('should not skip conversations at page boundaries', async () => {
|
||||
// Create 30 conversations to ensure pagination (limit is 25)
|
||||
const baseTime = new Date('2026-01-01T00:00:00.000Z');
|
||||
const convos: unknown[] = [];
|
||||
|
||||
for (let i = 0; i < 30; i++) {
|
||||
const updatedAt = new Date(baseTime.getTime() - i * 60000); // Each 1 minute apart
|
||||
const convo = await createConvoWithTimestamps(i, updatedAt, updatedAt);
|
||||
convos.push(convo);
|
||||
}
|
||||
|
||||
// Fetch first page
|
||||
const page1 = await getConvosByCursor('user123', { limit: 25 });
|
||||
|
||||
expect(page1.conversations).toHaveLength(25);
|
||||
expect(page1.nextCursor).toBeTruthy();
|
||||
|
||||
// Fetch second page using cursor
|
||||
const page2 = await getConvosByCursor('user123', {
|
||||
limit: 25,
|
||||
cursor: page1.nextCursor,
|
||||
});
|
||||
|
||||
// Should get remaining 5 conversations
|
||||
expect(page2.conversations).toHaveLength(5);
|
||||
expect(page2.nextCursor).toBeNull();
|
||||
|
||||
// Verify no duplicates and no gaps
|
||||
const allIds = [
|
||||
...page1.conversations.map((c: IConversation) => c.conversationId),
|
||||
...page2.conversations.map((c: IConversation) => c.conversationId),
|
||||
];
|
||||
const uniqueIds = new Set(allIds);
|
||||
|
||||
expect(uniqueIds.size).toBe(30); // All 30 conversations accounted for
|
||||
expect(allIds.length).toBe(30); // No duplicates
|
||||
});
|
||||
|
||||
it('should include conversation at exact page boundary (item 26 bug fix)', async () => {
|
||||
// This test specifically verifies the fix for the bug where item 26
|
||||
// (the first item that should appear on page 2) was being skipped
|
||||
|
||||
const baseTime = new Date('2026-01-01T12:00:00.000Z');
|
||||
|
||||
// Create exactly 26 conversations
|
||||
const convos: (IConversation | null)[] = [];
|
||||
for (let i = 0; i < 26; i++) {
|
||||
const updatedAt = new Date(baseTime.getTime() - i * 60000);
|
||||
const convo = await createConvoWithTimestamps(i, updatedAt, updatedAt);
|
||||
convos.push(convo);
|
||||
}
|
||||
|
||||
// The 26th conversation (index 25) should be on page 2
|
||||
const item26 = convos[25];
|
||||
|
||||
// Fetch first page with limit 25
|
||||
const page1 = await getConvosByCursor('user123', { limit: 25 });
|
||||
|
||||
expect(page1.conversations).toHaveLength(25);
|
||||
expect(page1.nextCursor).toBeTruthy();
|
||||
|
||||
// Item 26 should NOT be in page 1
|
||||
const page1Ids = page1.conversations.map((c: IConversation) => c.conversationId);
|
||||
expect(page1Ids).not.toContain(item26!.conversationId);
|
||||
|
||||
// Fetch second page
|
||||
const page2 = await getConvosByCursor('user123', {
|
||||
limit: 25,
|
||||
cursor: page1.nextCursor,
|
||||
});
|
||||
|
||||
// Item 26 MUST be in page 2 (this was the bug - it was being skipped)
|
||||
expect(page2.conversations).toHaveLength(1);
|
||||
expect(page2.conversations[0].conversationId).toBe(item26!.conversationId);
|
||||
});
|
||||
|
||||
it('should sort by updatedAt DESC by default', async () => {
|
||||
// Create conversations with different updatedAt times
|
||||
// Note: createdAt is older but updatedAt varies
|
||||
const convo1 = await createConvoWithTimestamps(
|
||||
1,
|
||||
new Date('2026-01-01T00:00:00.000Z'), // oldest created
|
||||
new Date('2026-01-03T00:00:00.000Z'), // most recently updated
|
||||
);
|
||||
|
||||
const convo2 = await createConvoWithTimestamps(
|
||||
2,
|
||||
new Date('2026-01-02T00:00:00.000Z'), // middle created
|
||||
new Date('2026-01-02T00:00:00.000Z'), // middle updated
|
||||
);
|
||||
|
||||
const convo3 = await createConvoWithTimestamps(
|
||||
3,
|
||||
new Date('2026-01-03T00:00:00.000Z'), // newest created
|
||||
new Date('2026-01-01T00:00:00.000Z'), // oldest updated
|
||||
);
|
||||
|
||||
const result = await getConvosByCursor('user123');
|
||||
|
||||
// Should be sorted by updatedAt DESC (most recent first)
|
||||
expect(result?.conversations).toHaveLength(3);
|
||||
expect(result?.conversations[0].conversationId).toBe(convo1!.conversationId); // Jan 3 updatedAt
|
||||
expect(result?.conversations[1].conversationId).toBe(convo2!.conversationId); // Jan 2 updatedAt
|
||||
expect(result?.conversations[2].conversationId).toBe(convo3!.conversationId); // Jan 1 updatedAt
|
||||
});
|
||||
|
||||
it('should handle conversations with same updatedAt (tie-breaker)', async () => {
|
||||
const sameTime = new Date('2026-01-01T12:00:00.000Z');
|
||||
|
||||
// Create 3 conversations with exact same updatedAt
|
||||
const convo1 = await createConvoWithTimestamps(1, sameTime, sameTime);
|
||||
const convo2 = await createConvoWithTimestamps(2, sameTime, sameTime);
|
||||
const convo3 = await createConvoWithTimestamps(3, sameTime, sameTime);
|
||||
|
||||
const result = await getConvosByCursor('user123');
|
||||
|
||||
// All 3 should be returned (no skipping due to same timestamps)
|
||||
expect(result?.conversations).toHaveLength(3);
|
||||
|
||||
const returnedIds = result?.conversations.map((c: IConversation) => c.conversationId);
|
||||
expect(returnedIds).toContain(convo1!.conversationId);
|
||||
expect(returnedIds).toContain(convo2!.conversationId);
|
||||
expect(returnedIds).toContain(convo3!.conversationId);
|
||||
});
|
||||
|
||||
it('should handle cursor pagination with conversations updated during pagination', async () => {
|
||||
// Simulate the scenario where a conversation is updated between page fetches
|
||||
const baseTime = new Date('2026-01-01T00:00:00.000Z');
|
||||
|
||||
// Create 30 conversations
|
||||
for (let i = 0; i < 30; i++) {
|
||||
const updatedAt = new Date(baseTime.getTime() - i * 60000);
|
||||
await createConvoWithTimestamps(i, updatedAt, updatedAt);
|
||||
}
|
||||
|
||||
// Fetch first page
|
||||
const page1 = await getConvosByCursor('user123', { limit: 25 });
|
||||
expect(page1.conversations).toHaveLength(25);
|
||||
|
||||
// Now update one of the conversations that should be on page 2
|
||||
// to have a newer updatedAt (simulating user activity during pagination)
|
||||
const convosOnPage2 = await Conversation.find({ user: 'user123' })
|
||||
.sort({ updatedAt: -1 })
|
||||
.skip(25)
|
||||
.limit(5);
|
||||
|
||||
if (convosOnPage2.length > 0) {
|
||||
const updatedConvo = convosOnPage2[0];
|
||||
await Conversation.updateOne(
|
||||
{ _id: updatedConvo._id },
|
||||
{ updatedAt: new Date('2026-01-02T00:00:00.000Z') }, // Much newer
|
||||
);
|
||||
}
|
||||
|
||||
// Fetch second page with original cursor
|
||||
const page2 = await getConvosByCursor('user123', {
|
||||
limit: 25,
|
||||
cursor: page1.nextCursor,
|
||||
});
|
||||
|
||||
// The updated conversation might not be in page 2 anymore
|
||||
// (it moved to the front), but we should still get remaining items
|
||||
// without errors and without infinite loops
|
||||
expect(page2.conversations.length).toBeGreaterThanOrEqual(0);
|
||||
});
|
||||
|
||||
it('should correctly decode and use cursor for pagination', async () => {
|
||||
const baseTime = new Date('2026-01-01T00:00:00.000Z');
|
||||
|
||||
// Create 30 conversations
|
||||
for (let i = 0; i < 30; i++) {
|
||||
const updatedAt = new Date(baseTime.getTime() - i * 60000);
|
||||
await createConvoWithTimestamps(i, updatedAt, updatedAt);
|
||||
}
|
||||
|
||||
// Fetch first page
|
||||
const page1 = await getConvosByCursor('user123', { limit: 25 });
|
||||
|
||||
// Decode the cursor to verify it's based on the last RETURNED item
|
||||
const decodedCursor = JSON.parse(
|
||||
Buffer.from(page1.nextCursor as string, 'base64').toString(),
|
||||
);
|
||||
|
||||
// The cursor should match the last item in page1 (item at index 24)
|
||||
const lastReturnedItem = page1.conversations[24] as IConversation;
|
||||
|
||||
expect(new Date(decodedCursor.primary).getTime()).toBe(
|
||||
new Date(lastReturnedItem.updatedAt ?? 0).getTime(),
|
||||
);
|
||||
});
|
||||
|
||||
it('should support sortBy createdAt when explicitly requested', async () => {
|
||||
// Create conversations with different timestamps
|
||||
const convo1 = await createConvoWithTimestamps(
|
||||
1,
|
||||
new Date('2026-01-03T00:00:00.000Z'), // newest created
|
||||
new Date('2026-01-01T00:00:00.000Z'), // oldest updated
|
||||
);
|
||||
|
||||
const convo2 = await createConvoWithTimestamps(
|
||||
2,
|
||||
new Date('2026-01-01T00:00:00.000Z'), // oldest created
|
||||
new Date('2026-01-03T00:00:00.000Z'), // newest updated
|
||||
);
|
||||
|
||||
// Verify timestamps were set correctly
|
||||
expect(new Date(convo1!.createdAt ?? 0).getTime()).toBe(
|
||||
new Date('2026-01-03T00:00:00.000Z').getTime(),
|
||||
);
|
||||
expect(new Date(convo2!.createdAt ?? 0).getTime()).toBe(
|
||||
new Date('2026-01-01T00:00:00.000Z').getTime(),
|
||||
);
|
||||
|
||||
const result = await getConvosByCursor('user123', { sortBy: 'createdAt' });
|
||||
|
||||
// Should be sorted by createdAt DESC
|
||||
expect(result?.conversations).toHaveLength(2);
|
||||
expect(result?.conversations[0].conversationId).toBe(convo1!.conversationId); // Jan 3 createdAt
|
||||
expect(result?.conversations[1].conversationId).toBe(convo2!.conversationId); // Jan 1 createdAt
|
||||
});
|
||||
|
||||
it('should handle empty result set gracefully', async () => {
|
||||
const result = await getConvosByCursor('user123');
|
||||
|
||||
expect(result?.conversations).toHaveLength(0);
|
||||
expect(result?.nextCursor).toBeNull();
|
||||
});
|
||||
|
||||
it('should handle exactly limit number of conversations (no next page)', async () => {
|
||||
const baseTime = new Date('2026-01-01T00:00:00.000Z');
|
||||
|
||||
// Create exactly 25 conversations (equal to default limit)
|
||||
for (let i = 0; i < 25; i++) {
|
||||
const updatedAt = new Date(baseTime.getTime() - i * 60000);
|
||||
await createConvoWithTimestamps(i, updatedAt, updatedAt);
|
||||
}
|
||||
|
||||
const result = await getConvosByCursor('user123', { limit: 25 });
|
||||
|
||||
expect(result?.conversations).toHaveLength(25);
|
||||
expect(result?.nextCursor).toBeNull(); // No next page
|
||||
});
|
||||
});
|
||||
});
|
||||
487
packages/data-schemas/src/methods/conversation.ts
Normal file
487
packages/data-schemas/src/methods/conversation.ts
Normal file
|
|
@ -0,0 +1,487 @@
|
|||
import type { FilterQuery, Model, SortOrder } from 'mongoose';
|
||||
import logger from '~/config/winston';
|
||||
import { createTempChatExpirationDate } from '~/utils/tempChatRetention';
|
||||
import type { AppConfig, IConversation } from '~/types';
|
||||
import type { MessageMethods } from './message';
|
||||
import type { DeleteResult } from 'mongoose';
|
||||
|
||||
export interface ConversationMethods {
|
||||
getConvoFiles(conversationId: string): Promise<string[]>;
|
||||
searchConversation(conversationId: string): Promise<IConversation | null>;
|
||||
deleteNullOrEmptyConversations(): Promise<{
|
||||
conversations: { deletedCount?: number };
|
||||
messages: { deletedCount?: number };
|
||||
}>;
|
||||
saveConvo(
|
||||
ctx: { userId: string; isTemporary?: boolean; interfaceConfig?: AppConfig['interfaceConfig'] },
|
||||
data: { conversationId: string; newConversationId?: string; [key: string]: unknown },
|
||||
metadata?: { context?: string; unsetFields?: Record<string, number>; noUpsert?: boolean },
|
||||
): Promise<IConversation | { message: string } | null>;
|
||||
bulkSaveConvos(conversations: Array<Record<string, unknown>>): Promise<unknown>;
|
||||
getConvosByCursor(
|
||||
user: string,
|
||||
options?: {
|
||||
cursor?: string | null;
|
||||
limit?: number;
|
||||
isArchived?: boolean;
|
||||
tags?: string[];
|
||||
search?: string;
|
||||
sortBy?: string;
|
||||
sortDirection?: string;
|
||||
},
|
||||
): Promise<{ conversations: IConversation[]; nextCursor: string | null }>;
|
||||
getConvosQueried(
|
||||
user: string,
|
||||
convoIds: Array<{ conversationId: string }> | null,
|
||||
cursor?: string | null,
|
||||
limit?: number,
|
||||
): Promise<{
|
||||
conversations: IConversation[];
|
||||
nextCursor: string | null;
|
||||
convoMap: Record<string, unknown>;
|
||||
}>;
|
||||
getConvo(user: string, conversationId: string): Promise<IConversation | null>;
|
||||
getConvoTitle(user: string, conversationId: string): Promise<string | null>;
|
||||
deleteConvos(
|
||||
user: string,
|
||||
filter: FilterQuery<IConversation>,
|
||||
): Promise<DeleteResult & { messages: DeleteResult }>;
|
||||
}
|
||||
|
||||
export function createConversationMethods(
|
||||
mongoose: typeof import('mongoose'),
|
||||
messageMethods?: Pick<MessageMethods, 'getMessages' | 'deleteMessages'>,
|
||||
): ConversationMethods {
|
||||
function getMessageMethods() {
|
||||
if (!messageMethods) {
|
||||
throw new Error('Message methods not injected into conversation methods');
|
||||
}
|
||||
return messageMethods;
|
||||
}
|
||||
|
||||
/**
|
||||
* Searches for a conversation by conversationId and returns a lean document with only conversationId and user.
|
||||
*/
|
||||
async function searchConversation(conversationId: string) {
|
||||
try {
|
||||
const Conversation = mongoose.models.Conversation as Model<IConversation>;
|
||||
return await Conversation.findOne({ conversationId }, 'conversationId user').lean();
|
||||
} catch (error) {
|
||||
logger.error('[searchConversation] Error searching conversation', error);
|
||||
throw new Error('Error searching conversation');
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Retrieves a single conversation for a given user and conversation ID.
|
||||
*/
|
||||
async function getConvo(user: string, conversationId: string) {
|
||||
try {
|
||||
const Conversation = mongoose.models.Conversation as Model<IConversation>;
|
||||
return await Conversation.findOne({ user, conversationId }).lean();
|
||||
} catch (error) {
|
||||
logger.error('[getConvo] Error getting single conversation', error);
|
||||
throw new Error('Error getting single conversation');
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Deletes conversations and messages with null or empty IDs.
|
||||
*/
|
||||
async function deleteNullOrEmptyConversations() {
|
||||
try {
|
||||
const Conversation = mongoose.models.Conversation as Model<IConversation>;
|
||||
const { deleteMessages } = getMessageMethods();
|
||||
const filter = {
|
||||
$or: [
|
||||
{ conversationId: null },
|
||||
{ conversationId: '' },
|
||||
{ conversationId: { $exists: false } },
|
||||
],
|
||||
};
|
||||
|
||||
const result = await Conversation.deleteMany(filter);
|
||||
const messageDeleteResult = await deleteMessages(filter);
|
||||
|
||||
logger.info(
|
||||
`[deleteNullOrEmptyConversations] Deleted ${result.deletedCount} conversations and ${messageDeleteResult.deletedCount} messages`,
|
||||
);
|
||||
|
||||
return {
|
||||
conversations: result,
|
||||
messages: messageDeleteResult,
|
||||
};
|
||||
} catch (error) {
|
||||
logger.error('[deleteNullOrEmptyConversations] Error deleting conversations', error);
|
||||
throw new Error('Error deleting conversations with null or empty conversationId');
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Searches for a conversation by conversationId and returns associated file ids.
|
||||
*/
|
||||
async function getConvoFiles(conversationId: string): Promise<string[]> {
|
||||
try {
|
||||
const Conversation = mongoose.models.Conversation as Model<IConversation>;
|
||||
return (
|
||||
((await Conversation.findOne({ conversationId }, 'files').lean()) as IConversation | null)
|
||||
?.files ?? []
|
||||
);
|
||||
} catch (error) {
|
||||
logger.error('[getConvoFiles] Error getting conversation files', error);
|
||||
throw new Error('Error getting conversation files');
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Saves a conversation to the database.
|
||||
*/
|
||||
async function saveConvo(
|
||||
{
|
||||
userId,
|
||||
isTemporary,
|
||||
interfaceConfig,
|
||||
}: {
|
||||
userId: string;
|
||||
isTemporary?: boolean;
|
||||
interfaceConfig?: AppConfig['interfaceConfig'];
|
||||
},
|
||||
{
|
||||
conversationId,
|
||||
newConversationId,
|
||||
...convo
|
||||
}: {
|
||||
conversationId: string;
|
||||
newConversationId?: string;
|
||||
[key: string]: unknown;
|
||||
},
|
||||
metadata?: { context?: string; unsetFields?: Record<string, number>; noUpsert?: boolean },
|
||||
) {
|
||||
try {
|
||||
const Conversation = mongoose.models.Conversation as Model<IConversation>;
|
||||
const { getMessages } = getMessageMethods();
|
||||
|
||||
if (metadata?.context) {
|
||||
logger.debug(`[saveConvo] ${metadata.context}`);
|
||||
}
|
||||
|
||||
const messages = await getMessages({ conversationId }, '_id');
|
||||
const update: Record<string, unknown> = { ...convo, messages, user: userId };
|
||||
|
||||
if (newConversationId) {
|
||||
update.conversationId = newConversationId;
|
||||
}
|
||||
|
||||
if (isTemporary) {
|
||||
try {
|
||||
update.expiredAt = createTempChatExpirationDate(interfaceConfig);
|
||||
} catch (err) {
|
||||
logger.error('Error creating temporary chat expiration date:', err);
|
||||
logger.info(`---\`saveConvo\` context: ${metadata?.context}`);
|
||||
update.expiredAt = null;
|
||||
}
|
||||
} else {
|
||||
update.expiredAt = null;
|
||||
}
|
||||
|
||||
const updateOperation: Record<string, unknown> = { $set: update };
|
||||
if (metadata?.unsetFields && Object.keys(metadata.unsetFields).length > 0) {
|
||||
updateOperation.$unset = metadata.unsetFields;
|
||||
}
|
||||
|
||||
const conversation = await Conversation.findOneAndUpdate(
|
||||
{ conversationId, user: userId },
|
||||
updateOperation,
|
||||
{
|
||||
new: true,
|
||||
upsert: metadata?.noUpsert !== true,
|
||||
},
|
||||
);
|
||||
|
||||
if (!conversation) {
|
||||
logger.debug('[saveConvo] Conversation not found, skipping update');
|
||||
return null;
|
||||
}
|
||||
|
||||
return conversation.toObject();
|
||||
} catch (error) {
|
||||
logger.error('[saveConvo] Error saving conversation', error);
|
||||
if (metadata?.context) {
|
||||
logger.info(`[saveConvo] ${metadata.context}`);
|
||||
}
|
||||
return { message: 'Error saving conversation' };
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Saves multiple conversations in bulk.
|
||||
*/
|
||||
async function bulkSaveConvos(conversations: Array<Record<string, unknown>>) {
|
||||
try {
|
||||
const Conversation = mongoose.models.Conversation as Model<IConversation>;
|
||||
const bulkOps = conversations.map((convo) => ({
|
||||
updateOne: {
|
||||
filter: { conversationId: convo.conversationId, user: convo.user },
|
||||
update: convo,
|
||||
upsert: true,
|
||||
timestamps: false,
|
||||
},
|
||||
}));
|
||||
|
||||
const result = await Conversation.bulkWrite(bulkOps);
|
||||
return result;
|
||||
} catch (error) {
|
||||
logger.error('[bulkSaveConvos] Error saving conversations in bulk', error);
|
||||
throw new Error('Failed to save conversations in bulk.');
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Retrieves conversations using cursor-based pagination.
|
||||
*/
|
||||
async function getConvosByCursor(
|
||||
user: string,
|
||||
{
|
||||
cursor,
|
||||
limit = 25,
|
||||
isArchived = false,
|
||||
tags,
|
||||
search,
|
||||
sortBy = 'updatedAt',
|
||||
sortDirection = 'desc',
|
||||
}: {
|
||||
cursor?: string | null;
|
||||
limit?: number;
|
||||
isArchived?: boolean;
|
||||
tags?: string[];
|
||||
search?: string;
|
||||
sortBy?: string;
|
||||
sortDirection?: string;
|
||||
} = {},
|
||||
) {
|
||||
const Conversation = mongoose.models.Conversation as Model<IConversation>;
|
||||
const filters: FilterQuery<IConversation>[] = [{ user } as FilterQuery<IConversation>];
|
||||
if (isArchived) {
|
||||
filters.push({ isArchived: true } as FilterQuery<IConversation>);
|
||||
} else {
|
||||
filters.push({
|
||||
$or: [{ isArchived: false }, { isArchived: { $exists: false } }],
|
||||
} as FilterQuery<IConversation>);
|
||||
}
|
||||
|
||||
if (Array.isArray(tags) && tags.length > 0) {
|
||||
filters.push({ tags: { $in: tags } } as FilterQuery<IConversation>);
|
||||
}
|
||||
|
||||
filters.push({
|
||||
$or: [{ expiredAt: null }, { expiredAt: { $exists: false } }],
|
||||
} as FilterQuery<IConversation>);
|
||||
|
||||
if (search) {
|
||||
try {
|
||||
const meiliResults = await (
|
||||
Conversation as unknown as {
|
||||
meiliSearch: (
|
||||
query: string,
|
||||
options: Record<string, string>,
|
||||
) => Promise<{
|
||||
hits: Array<{ conversationId: string }>;
|
||||
}>;
|
||||
}
|
||||
).meiliSearch(search, { filter: `user = "${user}"` });
|
||||
const matchingIds = Array.isArray(meiliResults.hits)
|
||||
? meiliResults.hits.map((result) => result.conversationId)
|
||||
: [];
|
||||
if (!matchingIds.length) {
|
||||
return { conversations: [], nextCursor: null };
|
||||
}
|
||||
filters.push({ conversationId: { $in: matchingIds } } as FilterQuery<IConversation>);
|
||||
} catch (error) {
|
||||
logger.error('[getConvosByCursor] Error during meiliSearch', error);
|
||||
throw new Error('Error during meiliSearch');
|
||||
}
|
||||
}
|
||||
|
||||
const validSortFields = ['title', 'createdAt', 'updatedAt'];
|
||||
if (!validSortFields.includes(sortBy)) {
|
||||
throw new Error(
|
||||
`Invalid sortBy field: ${sortBy}. Must be one of ${validSortFields.join(', ')}`,
|
||||
);
|
||||
}
|
||||
const finalSortBy = sortBy;
|
||||
const finalSortDirection = sortDirection === 'asc' ? 'asc' : 'desc';
|
||||
|
||||
let cursorFilter: FilterQuery<IConversation> | null = null;
|
||||
if (cursor) {
|
||||
try {
|
||||
const decoded = JSON.parse(Buffer.from(cursor, 'base64').toString());
|
||||
const { primary, secondary } = decoded;
|
||||
const primaryValue = finalSortBy === 'title' ? primary : new Date(primary);
|
||||
const secondaryValue = new Date(secondary);
|
||||
const op = finalSortDirection === 'asc' ? '$gt' : '$lt';
|
||||
|
||||
cursorFilter = {
|
||||
$or: [
|
||||
{ [finalSortBy]: { [op]: primaryValue } },
|
||||
{
|
||||
[finalSortBy]: primaryValue,
|
||||
updatedAt: { [op]: secondaryValue },
|
||||
},
|
||||
],
|
||||
} as FilterQuery<IConversation>;
|
||||
} catch {
|
||||
logger.warn('[getConvosByCursor] Invalid cursor format, starting from beginning');
|
||||
}
|
||||
if (cursorFilter) {
|
||||
filters.push(cursorFilter);
|
||||
}
|
||||
}
|
||||
|
||||
const query: FilterQuery<IConversation> =
|
||||
filters.length === 1 ? filters[0] : ({ $and: filters } as FilterQuery<IConversation>);
|
||||
|
||||
try {
|
||||
const sortOrder: SortOrder = finalSortDirection === 'asc' ? 1 : -1;
|
||||
const sortObj: Record<string, SortOrder> = { [finalSortBy]: sortOrder };
|
||||
|
||||
if (finalSortBy !== 'updatedAt') {
|
||||
sortObj.updatedAt = sortOrder;
|
||||
}
|
||||
|
||||
const convos = await Conversation.find(query)
|
||||
.select(
|
||||
'conversationId endpoint title createdAt updatedAt user model agent_id assistant_id spec iconURL',
|
||||
)
|
||||
.sort(sortObj)
|
||||
.limit(limit + 1)
|
||||
.lean();
|
||||
|
||||
let nextCursor: string | null = null;
|
||||
if (convos.length > limit) {
|
||||
convos.pop();
|
||||
const lastReturned = convos[convos.length - 1] as Record<string, unknown>;
|
||||
const primaryValue = lastReturned[finalSortBy];
|
||||
const primaryStr =
|
||||
finalSortBy === 'title' ? primaryValue : (primaryValue as Date).toISOString();
|
||||
const secondaryStr = (lastReturned.updatedAt as Date).toISOString();
|
||||
const composite = { primary: primaryStr, secondary: secondaryStr };
|
||||
nextCursor = Buffer.from(JSON.stringify(composite)).toString('base64');
|
||||
}
|
||||
|
||||
return { conversations: convos, nextCursor };
|
||||
} catch (error) {
|
||||
logger.error('[getConvosByCursor] Error getting conversations', error);
|
||||
throw new Error('Error getting conversations');
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Fetches specific conversations by ID array with pagination.
|
||||
*/
|
||||
async function getConvosQueried(
|
||||
user: string,
|
||||
convoIds: Array<{ conversationId: string }> | null,
|
||||
cursor: string | null = null,
|
||||
limit = 25,
|
||||
) {
|
||||
try {
|
||||
const Conversation = mongoose.models.Conversation as Model<IConversation>;
|
||||
if (!convoIds?.length) {
|
||||
return { conversations: [], nextCursor: null, convoMap: {} };
|
||||
}
|
||||
|
||||
const conversationIds = convoIds.map((convo) => convo.conversationId);
|
||||
|
||||
const results = await Conversation.find({
|
||||
user,
|
||||
conversationId: { $in: conversationIds },
|
||||
$or: [{ expiredAt: { $exists: false } }, { expiredAt: null }],
|
||||
}).lean();
|
||||
|
||||
results.sort(
|
||||
(a, b) => new Date(b.updatedAt ?? 0).getTime() - new Date(a.updatedAt ?? 0).getTime(),
|
||||
);
|
||||
|
||||
let filtered = results;
|
||||
if (cursor && cursor !== 'start') {
|
||||
const cursorDate = new Date(cursor);
|
||||
filtered = results.filter((convo) => new Date(convo.updatedAt ?? 0) < cursorDate);
|
||||
}
|
||||
|
||||
const limited = filtered.slice(0, limit + 1);
|
||||
let nextCursor: string | null = null;
|
||||
if (limited.length > limit) {
|
||||
limited.pop();
|
||||
nextCursor = (limited[limited.length - 1].updatedAt as Date).toISOString();
|
||||
}
|
||||
|
||||
const convoMap: Record<string, unknown> = {};
|
||||
limited.forEach((convo) => {
|
||||
convoMap[convo.conversationId] = convo;
|
||||
});
|
||||
|
||||
return { conversations: limited, nextCursor, convoMap };
|
||||
} catch (error) {
|
||||
logger.error('[getConvosQueried] Error getting conversations', error);
|
||||
throw new Error('Error fetching conversations');
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Gets conversation title, returning 'New Chat' as default.
|
||||
*/
|
||||
async function getConvoTitle(user: string, conversationId: string) {
|
||||
try {
|
||||
const convo = await getConvo(user, conversationId);
|
||||
if (convo && !convo.title) {
|
||||
return null;
|
||||
} else {
|
||||
return convo?.title || 'New Chat';
|
||||
}
|
||||
} catch (error) {
|
||||
logger.error('[getConvoTitle] Error getting conversation title', error);
|
||||
throw new Error('Error getting conversation title');
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Deletes conversations and their associated messages for a given user and filter.
|
||||
*/
|
||||
async function deleteConvos(user: string, filter: FilterQuery<IConversation>) {
|
||||
try {
|
||||
const Conversation = mongoose.models.Conversation as Model<IConversation>;
|
||||
const { deleteMessages } = getMessageMethods();
|
||||
const userFilter = { ...filter, user };
|
||||
const conversations = await Conversation.find(userFilter).select('conversationId');
|
||||
const conversationIds = conversations.map((c) => c.conversationId);
|
||||
|
||||
if (!conversationIds.length) {
|
||||
throw new Error('Conversation not found or already deleted.');
|
||||
}
|
||||
|
||||
const deleteConvoResult = await Conversation.deleteMany(userFilter);
|
||||
|
||||
const deleteMessagesResult = await deleteMessages({
|
||||
conversationId: { $in: conversationIds },
|
||||
});
|
||||
|
||||
return { ...deleteConvoResult, messages: deleteMessagesResult };
|
||||
} catch (error) {
|
||||
logger.error('[deleteConvos] Error deleting conversations and messages', error);
|
||||
throw error;
|
||||
}
|
||||
}
|
||||
|
||||
return {
|
||||
getConvoFiles,
|
||||
searchConversation,
|
||||
deleteNullOrEmptyConversations,
|
||||
saveConvo,
|
||||
bulkSaveConvos,
|
||||
getConvosByCursor,
|
||||
getConvosQueried,
|
||||
getConvo,
|
||||
getConvoTitle,
|
||||
deleteConvos,
|
||||
};
|
||||
}
|
||||
|
|
@ -0,0 +1,139 @@
|
|||
import mongoose from 'mongoose';
|
||||
import { MongoMemoryServer } from 'mongodb-memory-server';
|
||||
import { createConversationTagMethods } from './conversationTag';
|
||||
import { createModels } from '~/models';
|
||||
import type { IConversationTag } from '~/schema/conversationTag';
|
||||
import type { IConversation } from '..';
|
||||
|
||||
jest.mock('~/config/winston', () => ({
|
||||
error: jest.fn(),
|
||||
warn: jest.fn(),
|
||||
info: jest.fn(),
|
||||
debug: jest.fn(),
|
||||
}));
|
||||
|
||||
let mongoServer: InstanceType<typeof MongoMemoryServer>;
|
||||
let ConversationTag: mongoose.Model<IConversationTag>;
|
||||
let Conversation: mongoose.Model<IConversation>;
|
||||
let deleteConversationTag: ReturnType<typeof createConversationTagMethods>['deleteConversationTag'];
|
||||
|
||||
beforeAll(async () => {
|
||||
mongoServer = await MongoMemoryServer.create();
|
||||
const mongoUri = mongoServer.getUri();
|
||||
|
||||
// Register models
|
||||
const models = createModels(mongoose);
|
||||
Object.assign(mongoose.models, models);
|
||||
|
||||
ConversationTag = mongoose.models.ConversationTag;
|
||||
Conversation = mongoose.models.Conversation;
|
||||
|
||||
// Create methods from factory
|
||||
const methods = createConversationTagMethods(mongoose);
|
||||
deleteConversationTag = methods.deleteConversationTag;
|
||||
|
||||
await mongoose.connect(mongoUri);
|
||||
});
|
||||
|
||||
afterAll(async () => {
|
||||
await mongoose.disconnect();
|
||||
await mongoServer.stop();
|
||||
});
|
||||
|
||||
afterEach(async () => {
|
||||
await ConversationTag.deleteMany({});
|
||||
await Conversation.deleteMany({});
|
||||
});
|
||||
|
||||
describe('ConversationTag model - $pullAll operations', () => {
|
||||
const userId = new mongoose.Types.ObjectId().toString();
|
||||
|
||||
describe('deleteConversationTag', () => {
|
||||
it('should remove the tag from all conversations that have it', async () => {
|
||||
await ConversationTag.create({ tag: 'work', user: userId, position: 1 });
|
||||
|
||||
await Conversation.create([
|
||||
{ conversationId: 'conv1', user: userId, endpoint: 'openAI', tags: ['work', 'important'] },
|
||||
{ conversationId: 'conv2', user: userId, endpoint: 'openAI', tags: ['work'] },
|
||||
{ conversationId: 'conv3', user: userId, endpoint: 'openAI', tags: ['personal'] },
|
||||
]);
|
||||
|
||||
await deleteConversationTag(userId, 'work');
|
||||
|
||||
const convos = await Conversation.find({ user: userId }).sort({ conversationId: 1 }).lean();
|
||||
expect(convos[0].tags).toEqual(['important']);
|
||||
expect(convos[1].tags).toEqual([]);
|
||||
expect(convos[2].tags).toEqual(['personal']);
|
||||
});
|
||||
|
||||
it('should delete the tag document itself', async () => {
|
||||
await ConversationTag.create({ tag: 'temp', user: userId, position: 1 });
|
||||
|
||||
const result = await deleteConversationTag(userId, 'temp');
|
||||
|
||||
expect(result).toBeDefined();
|
||||
expect(result!.tag).toBe('temp');
|
||||
|
||||
const remaining = await ConversationTag.find({ user: userId }).lean();
|
||||
expect(remaining).toHaveLength(0);
|
||||
});
|
||||
|
||||
it('should return null when the tag does not exist', async () => {
|
||||
const result = await deleteConversationTag(userId, 'nonexistent');
|
||||
expect(result).toBeNull();
|
||||
});
|
||||
|
||||
it('should adjust positions of tags after the deleted one', async () => {
|
||||
await ConversationTag.create([
|
||||
{ tag: 'first', user: userId, position: 1 },
|
||||
{ tag: 'second', user: userId, position: 2 },
|
||||
{ tag: 'third', user: userId, position: 3 },
|
||||
]);
|
||||
|
||||
await deleteConversationTag(userId, 'first');
|
||||
|
||||
const tags = await ConversationTag.find({ user: userId }).sort({ position: 1 }).lean();
|
||||
expect(tags).toHaveLength(2);
|
||||
expect(tags[0].tag).toBe('second');
|
||||
expect(tags[0].position).toBe(1);
|
||||
expect(tags[1].tag).toBe('third');
|
||||
expect(tags[1].position).toBe(2);
|
||||
});
|
||||
|
||||
it('should not affect conversations of other users', async () => {
|
||||
const otherUser = new mongoose.Types.ObjectId().toString();
|
||||
|
||||
await ConversationTag.create({ tag: 'shared-name', user: userId, position: 1 });
|
||||
await ConversationTag.create({ tag: 'shared-name', user: otherUser, position: 1 });
|
||||
|
||||
await Conversation.create([
|
||||
{ conversationId: 'mine', user: userId, endpoint: 'openAI', tags: ['shared-name'] },
|
||||
{ conversationId: 'theirs', user: otherUser, endpoint: 'openAI', tags: ['shared-name'] },
|
||||
]);
|
||||
|
||||
await deleteConversationTag(userId, 'shared-name');
|
||||
|
||||
const myConvo = await Conversation.findOne({ conversationId: 'mine' }).lean();
|
||||
const theirConvo = await Conversation.findOne({ conversationId: 'theirs' }).lean();
|
||||
|
||||
expect(myConvo?.tags).toEqual([]);
|
||||
expect(theirConvo?.tags).toEqual(['shared-name']);
|
||||
});
|
||||
|
||||
it('should handle duplicate tags in conversations correctly', async () => {
|
||||
await ConversationTag.create({ tag: 'dup', user: userId, position: 1 });
|
||||
|
||||
const conv = await Conversation.create({
|
||||
conversationId: 'conv-dup',
|
||||
user: userId,
|
||||
endpoint: 'openAI',
|
||||
tags: ['dup', 'other', 'dup'],
|
||||
});
|
||||
|
||||
await deleteConversationTag(userId, 'dup');
|
||||
|
||||
const updated = await Conversation.findById(conv._id).lean();
|
||||
expect(updated?.tags).toEqual(['other']);
|
||||
});
|
||||
});
|
||||
});
|
||||
312
packages/data-schemas/src/methods/conversationTag.ts
Normal file
312
packages/data-schemas/src/methods/conversationTag.ts
Normal file
|
|
@ -0,0 +1,312 @@
|
|||
import type { Model } from 'mongoose';
|
||||
import logger from '~/config/winston';
|
||||
|
||||
interface IConversationTag {
|
||||
user: string;
|
||||
tag: string;
|
||||
description?: string;
|
||||
position: number;
|
||||
count: number;
|
||||
createdAt?: Date;
|
||||
[key: string]: unknown;
|
||||
}
|
||||
|
||||
export function createConversationTagMethods(mongoose: typeof import('mongoose')) {
|
||||
/**
|
||||
* Retrieves all conversation tags for a user.
|
||||
*/
|
||||
async function getConversationTags(user: string) {
|
||||
try {
|
||||
const ConversationTag = mongoose.models.ConversationTag as Model<IConversationTag>;
|
||||
return await ConversationTag.find({ user }).sort({ position: 1 }).lean();
|
||||
} catch (error) {
|
||||
logger.error('[getConversationTags] Error getting conversation tags', error);
|
||||
throw new Error('Error getting conversation tags');
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Creates a new conversation tag.
|
||||
*/
|
||||
async function createConversationTag(
|
||||
user: string,
|
||||
data: {
|
||||
tag: string;
|
||||
description?: string;
|
||||
addToConversation?: boolean;
|
||||
conversationId?: string;
|
||||
},
|
||||
) {
|
||||
try {
|
||||
const ConversationTag = mongoose.models.ConversationTag as Model<IConversationTag>;
|
||||
const Conversation = mongoose.models.Conversation;
|
||||
const { tag, description, addToConversation, conversationId } = data;
|
||||
|
||||
const existingTag = await ConversationTag.findOne({ user, tag }).lean();
|
||||
if (existingTag) {
|
||||
return existingTag;
|
||||
}
|
||||
|
||||
const maxPosition = await ConversationTag.findOne({ user }).sort('-position').lean();
|
||||
const position = (maxPosition?.position || 0) + 1;
|
||||
|
||||
const newTag = await ConversationTag.findOneAndUpdate(
|
||||
{ tag, user },
|
||||
{
|
||||
tag,
|
||||
user,
|
||||
count: addToConversation ? 1 : 0,
|
||||
position,
|
||||
description,
|
||||
$setOnInsert: { createdAt: new Date() },
|
||||
},
|
||||
{
|
||||
new: true,
|
||||
upsert: true,
|
||||
lean: true,
|
||||
},
|
||||
);
|
||||
|
||||
if (addToConversation && conversationId) {
|
||||
await Conversation.findOneAndUpdate(
|
||||
{ user, conversationId },
|
||||
{ $addToSet: { tags: tag } },
|
||||
{ new: true },
|
||||
);
|
||||
}
|
||||
|
||||
return newTag;
|
||||
} catch (error) {
|
||||
logger.error('[createConversationTag] Error creating conversation tag', error);
|
||||
throw new Error('Error creating conversation tag');
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Adjusts positions of tags when a tag's position is changed.
|
||||
*/
|
||||
async function adjustPositions(user: string, oldPosition: number, newPosition: number) {
|
||||
if (oldPosition === newPosition) {
|
||||
return;
|
||||
}
|
||||
|
||||
const ConversationTag = mongoose.models.ConversationTag as Model<IConversationTag>;
|
||||
|
||||
const update =
|
||||
oldPosition < newPosition ? { $inc: { position: -1 } } : { $inc: { position: 1 } };
|
||||
const position =
|
||||
oldPosition < newPosition
|
||||
? {
|
||||
$gt: Math.min(oldPosition, newPosition),
|
||||
$lte: Math.max(oldPosition, newPosition),
|
||||
}
|
||||
: {
|
||||
$gte: Math.min(oldPosition, newPosition),
|
||||
$lt: Math.max(oldPosition, newPosition),
|
||||
};
|
||||
|
||||
await ConversationTag.updateMany({ user, position }, update);
|
||||
}
|
||||
|
||||
/**
|
||||
* Updates an existing conversation tag.
|
||||
*/
|
||||
async function updateConversationTag(
|
||||
user: string,
|
||||
oldTag: string,
|
||||
data: { tag?: string; description?: string; position?: number },
|
||||
) {
|
||||
try {
|
||||
const ConversationTag = mongoose.models.ConversationTag as Model<IConversationTag>;
|
||||
const Conversation = mongoose.models.Conversation;
|
||||
const { tag: newTag, description, position } = data;
|
||||
|
||||
const existingTag = await ConversationTag.findOne({ user, tag: oldTag }).lean();
|
||||
if (!existingTag) {
|
||||
return null;
|
||||
}
|
||||
|
||||
if (newTag && newTag !== oldTag) {
|
||||
const tagAlreadyExists = await ConversationTag.findOne({ user, tag: newTag }).lean();
|
||||
if (tagAlreadyExists) {
|
||||
throw new Error('Tag already exists');
|
||||
}
|
||||
|
||||
await Conversation.updateMany({ user, tags: oldTag }, { $set: { 'tags.$': newTag } });
|
||||
}
|
||||
|
||||
const updateData: Record<string, unknown> = {};
|
||||
if (newTag) {
|
||||
updateData.tag = newTag;
|
||||
}
|
||||
if (description !== undefined) {
|
||||
updateData.description = description;
|
||||
}
|
||||
if (position !== undefined) {
|
||||
await adjustPositions(user, existingTag.position, position);
|
||||
updateData.position = position;
|
||||
}
|
||||
|
||||
return await ConversationTag.findOneAndUpdate({ user, tag: oldTag }, updateData, {
|
||||
new: true,
|
||||
lean: true,
|
||||
});
|
||||
} catch (error) {
|
||||
logger.error('[updateConversationTag] Error updating conversation tag', error);
|
||||
throw new Error('Error updating conversation tag');
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Deletes a conversation tag.
|
||||
*/
|
||||
async function deleteConversationTag(user: string, tag: string) {
|
||||
try {
|
||||
const ConversationTag = mongoose.models.ConversationTag as Model<IConversationTag>;
|
||||
const Conversation = mongoose.models.Conversation;
|
||||
|
||||
const deletedTag = await ConversationTag.findOneAndDelete({ user, tag }).lean();
|
||||
if (!deletedTag) {
|
||||
return null;
|
||||
}
|
||||
|
||||
await Conversation.updateMany({ user, tags: tag }, { $pullAll: { tags: [tag] } });
|
||||
|
||||
await ConversationTag.updateMany(
|
||||
{ user, position: { $gt: deletedTag.position } },
|
||||
{ $inc: { position: -1 } },
|
||||
);
|
||||
|
||||
return deletedTag;
|
||||
} catch (error) {
|
||||
logger.error('[deleteConversationTag] Error deleting conversation tag', error);
|
||||
throw new Error('Error deleting conversation tag');
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Updates tags for a specific conversation.
|
||||
*/
|
||||
async function updateTagsForConversation(user: string, conversationId: string, tags: string[]) {
|
||||
try {
|
||||
const ConversationTag = mongoose.models.ConversationTag as Model<IConversationTag>;
|
||||
const Conversation = mongoose.models.Conversation;
|
||||
|
||||
const conversation = await Conversation.findOne({ user, conversationId }).lean();
|
||||
if (!conversation) {
|
||||
throw new Error('Conversation not found');
|
||||
}
|
||||
|
||||
const oldTags = new Set<string>(
|
||||
((conversation as Record<string, unknown>).tags as string[]) ?? [],
|
||||
);
|
||||
const newTags = new Set(tags);
|
||||
|
||||
const addedTags = [...newTags].filter((tag) => !oldTags.has(tag));
|
||||
const removedTags = [...oldTags].filter((tag) => !newTags.has(tag));
|
||||
|
||||
const bulkOps: Array<{
|
||||
updateOne: {
|
||||
filter: Record<string, unknown>;
|
||||
update: Record<string, unknown>;
|
||||
upsert?: boolean;
|
||||
};
|
||||
}> = [];
|
||||
|
||||
for (const tag of addedTags) {
|
||||
bulkOps.push({
|
||||
updateOne: {
|
||||
filter: { user, tag },
|
||||
update: { $inc: { count: 1 } },
|
||||
upsert: true,
|
||||
},
|
||||
});
|
||||
}
|
||||
|
||||
for (const tag of removedTags) {
|
||||
bulkOps.push({
|
||||
updateOne: {
|
||||
filter: { user, tag },
|
||||
update: { $inc: { count: -1 } },
|
||||
},
|
||||
});
|
||||
}
|
||||
|
||||
if (bulkOps.length > 0) {
|
||||
await ConversationTag.bulkWrite(bulkOps);
|
||||
}
|
||||
|
||||
const updatedConversation = (
|
||||
await Conversation.findOneAndUpdate(
|
||||
{ user, conversationId },
|
||||
{ $set: { tags: [...newTags] } },
|
||||
{ new: true },
|
||||
)
|
||||
).toObject();
|
||||
|
||||
return updatedConversation.tags;
|
||||
} catch (error) {
|
||||
logger.error('[updateTagsForConversation] Error updating tags', error);
|
||||
throw new Error('Error updating tags for conversation');
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Increments tag counts for existing tags only.
|
||||
*/
|
||||
async function bulkIncrementTagCounts(user: string, tags: string[]) {
|
||||
if (!tags || tags.length === 0) {
|
||||
return;
|
||||
}
|
||||
|
||||
try {
|
||||
const ConversationTag = mongoose.models.ConversationTag as Model<IConversationTag>;
|
||||
const uniqueTags = [...new Set(tags.filter(Boolean))];
|
||||
if (uniqueTags.length === 0) {
|
||||
return;
|
||||
}
|
||||
|
||||
const bulkOps = uniqueTags.map((tag) => ({
|
||||
updateOne: {
|
||||
filter: { user, tag },
|
||||
update: { $inc: { count: 1 } },
|
||||
},
|
||||
}));
|
||||
|
||||
const result = await ConversationTag.bulkWrite(bulkOps);
|
||||
if (result && result.modifiedCount > 0) {
|
||||
logger.debug(
|
||||
`user: ${user} | Incremented tag counts - modified ${result.modifiedCount} tags`,
|
||||
);
|
||||
}
|
||||
} catch (error) {
|
||||
logger.error('[bulkIncrementTagCounts] Error incrementing tag counts', error);
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Deletes all conversation tags matching the given filter.
|
||||
*/
|
||||
async function deleteConversationTags(filter: Record<string, unknown>): Promise<number> {
|
||||
try {
|
||||
const ConversationTag = mongoose.models.ConversationTag as Model<IConversationTag>;
|
||||
const result = await ConversationTag.deleteMany(filter);
|
||||
return result.deletedCount;
|
||||
} catch (error) {
|
||||
logger.error('[deleteConversationTags] Error deleting conversation tags', error);
|
||||
throw new Error('Error deleting conversation tags');
|
||||
}
|
||||
}
|
||||
|
||||
return {
|
||||
getConversationTags,
|
||||
createConversationTag,
|
||||
updateConversationTag,
|
||||
deleteConversationTag,
|
||||
deleteConversationTags,
|
||||
bulkIncrementTagCounts,
|
||||
updateTagsForConversation,
|
||||
};
|
||||
}
|
||||
|
||||
export type ConversationTagMethods = ReturnType<typeof createConversationTagMethods>;
|
||||
297
packages/data-schemas/src/methods/convoStructure.spec.ts
Normal file
297
packages/data-schemas/src/methods/convoStructure.spec.ts
Normal file
|
|
@ -0,0 +1,297 @@
|
|||
import mongoose from 'mongoose';
|
||||
import type { TMessage } from 'librechat-data-provider';
|
||||
import { buildTree } from 'librechat-data-provider';
|
||||
import { MongoMemoryServer } from 'mongodb-memory-server';
|
||||
import { createModels } from '~/models';
|
||||
import { createMessageMethods } from './message';
|
||||
import type { IMessage } from '..';
|
||||
|
||||
jest.mock('~/config/winston', () => ({
|
||||
error: jest.fn(),
|
||||
warn: jest.fn(),
|
||||
info: jest.fn(),
|
||||
debug: jest.fn(),
|
||||
}));
|
||||
|
||||
let mongod: InstanceType<typeof MongoMemoryServer>;
|
||||
let Message: mongoose.Model<IMessage>;
|
||||
let getMessages: ReturnType<typeof createMessageMethods>['getMessages'];
|
||||
let bulkSaveMessages: ReturnType<typeof createMessageMethods>['bulkSaveMessages'];
|
||||
|
||||
beforeAll(async () => {
|
||||
mongod = await MongoMemoryServer.create();
|
||||
const uri = mongod.getUri();
|
||||
|
||||
const models = createModels(mongoose);
|
||||
Object.assign(mongoose.models, models);
|
||||
Message = mongoose.models.Message;
|
||||
|
||||
const methods = createMessageMethods(mongoose);
|
||||
getMessages = methods.getMessages;
|
||||
bulkSaveMessages = methods.bulkSaveMessages;
|
||||
|
||||
await mongoose.connect(uri);
|
||||
});
|
||||
|
||||
afterAll(async () => {
|
||||
await mongoose.disconnect();
|
||||
await mongod.stop();
|
||||
});
|
||||
|
||||
beforeEach(async () => {
|
||||
await Message.deleteMany({});
|
||||
});
|
||||
|
||||
describe('Conversation Structure Tests', () => {
|
||||
test('Conversation folding/corrupting with inconsistent timestamps', async () => {
|
||||
const userId = 'testUser';
|
||||
const conversationId = 'testConversation';
|
||||
|
||||
// Create messages with inconsistent timestamps
|
||||
const messages = [
|
||||
{
|
||||
messageId: 'message0',
|
||||
parentMessageId: null,
|
||||
text: 'Message 0',
|
||||
createdAt: new Date('2023-01-01T00:00:00Z'),
|
||||
},
|
||||
{
|
||||
messageId: 'message1',
|
||||
parentMessageId: 'message0',
|
||||
text: 'Message 1',
|
||||
createdAt: new Date('2023-01-01T00:02:00Z'),
|
||||
},
|
||||
{
|
||||
messageId: 'message2',
|
||||
parentMessageId: 'message1',
|
||||
text: 'Message 2',
|
||||
createdAt: new Date('2023-01-01T00:01:00Z'),
|
||||
}, // Note: Earlier than its parent
|
||||
{
|
||||
messageId: 'message3',
|
||||
parentMessageId: 'message1',
|
||||
text: 'Message 3',
|
||||
createdAt: new Date('2023-01-01T00:03:00Z'),
|
||||
},
|
||||
{
|
||||
messageId: 'message4',
|
||||
parentMessageId: 'message2',
|
||||
text: 'Message 4',
|
||||
createdAt: new Date('2023-01-01T00:04:00Z'),
|
||||
},
|
||||
];
|
||||
|
||||
// Add common properties to all messages
|
||||
messages.forEach((msg) => {
|
||||
Object.assign(msg, {
|
||||
conversationId,
|
||||
user: userId,
|
||||
isCreatedByUser: false,
|
||||
error: false,
|
||||
unfinished: false,
|
||||
});
|
||||
});
|
||||
|
||||
// Save messages with overrideTimestamp omitted (default is false)
|
||||
await bulkSaveMessages(messages, true);
|
||||
|
||||
// Retrieve messages (this will sort by createdAt)
|
||||
const retrievedMessages = await getMessages({ conversationId, user: userId });
|
||||
|
||||
// Build tree
|
||||
const tree = buildTree({ messages: retrievedMessages as TMessage[] });
|
||||
|
||||
// Check if the tree is incorrect (folded/corrupted)
|
||||
expect(tree!.length).toBeGreaterThan(1); // Should have multiple root messages, indicating corruption
|
||||
});
|
||||
|
||||
test('Fix: Conversation structure maintained with more than 16 messages', async () => {
|
||||
const userId = 'testUser';
|
||||
const conversationId = 'testConversation';
|
||||
|
||||
// Create more than 16 messages
|
||||
const messages = Array.from({ length: 20 }, (_, i) => ({
|
||||
messageId: `message${i}`,
|
||||
parentMessageId: i === 0 ? null : `message${i - 1}`,
|
||||
conversationId,
|
||||
user: userId,
|
||||
text: `Message ${i}`,
|
||||
createdAt: new Date(Date.now() + (i % 2 === 0 ? i * 500000 : -i * 500000)),
|
||||
}));
|
||||
|
||||
// Save messages with new timestamps being generated (message objects ignored)
|
||||
await bulkSaveMessages(messages);
|
||||
|
||||
// Retrieve messages (this will sort by createdAt, but it shouldn't matter now)
|
||||
const retrievedMessages = await getMessages({ conversationId, user: userId });
|
||||
|
||||
// Build tree
|
||||
const tree = buildTree({ messages: retrievedMessages as TMessage[] });
|
||||
|
||||
// Check if the tree is correct
|
||||
expect(tree!.length).toBe(1); // Should have only one root message
|
||||
let currentNode = tree![0];
|
||||
for (let i = 1; i < 20; i++) {
|
||||
expect(currentNode.children!.length).toBe(1);
|
||||
currentNode = currentNode.children![0];
|
||||
expect(currentNode.text).toBe(`Message ${i}`);
|
||||
}
|
||||
expect(currentNode.children!.length).toBe(0); // Last message should have no children
|
||||
});
|
||||
|
||||
test('Simulate MongoDB ordering issue with more than 16 messages and close timestamps', async () => {
|
||||
const userId = 'testUser';
|
||||
const conversationId = 'testConversation';
|
||||
|
||||
// Create more than 16 messages with very close timestamps
|
||||
const messages = Array.from({ length: 20 }, (_, i) => ({
|
||||
messageId: `message${i}`,
|
||||
parentMessageId: i === 0 ? null : `message${i - 1}`,
|
||||
conversationId,
|
||||
user: userId,
|
||||
text: `Message ${i}`,
|
||||
createdAt: new Date(Date.now() + (i % 2 === 0 ? i * 1 : -i * 1)),
|
||||
}));
|
||||
|
||||
// Add common properties to all messages
|
||||
messages.forEach((msg) => {
|
||||
Object.assign(msg, { isCreatedByUser: false, error: false, unfinished: false });
|
||||
});
|
||||
|
||||
await bulkSaveMessages(messages, true);
|
||||
const retrievedMessages = await getMessages({ conversationId, user: userId });
|
||||
const tree = buildTree({ messages: retrievedMessages as TMessage[] });
|
||||
expect(tree!.length).toBeGreaterThan(1);
|
||||
});
|
||||
|
||||
test('Fix: Preserve order with more than 16 messages by maintaining original timestamps', async () => {
|
||||
const userId = 'testUser';
|
||||
const conversationId = 'testConversation';
|
||||
|
||||
// Create more than 16 messages with distinct timestamps
|
||||
const messages = Array.from({ length: 20 }, (_, i) => ({
|
||||
messageId: `message${i}`,
|
||||
parentMessageId: i === 0 ? null : `message${i - 1}`,
|
||||
conversationId,
|
||||
user: userId,
|
||||
text: `Message ${i}`,
|
||||
createdAt: new Date(Date.now() + i * 1000), // Ensure each message has a distinct timestamp
|
||||
}));
|
||||
|
||||
// Add common properties to all messages
|
||||
messages.forEach((msg) => {
|
||||
Object.assign(msg, { isCreatedByUser: false, error: false, unfinished: false });
|
||||
});
|
||||
|
||||
// Save messages with overriding timestamps (preserve original timestamps)
|
||||
await bulkSaveMessages(messages, true);
|
||||
|
||||
// Retrieve messages (this will sort by createdAt)
|
||||
const retrievedMessages = await getMessages({ conversationId, user: userId });
|
||||
|
||||
// Build tree
|
||||
const tree = buildTree({ messages: retrievedMessages as TMessage[] });
|
||||
|
||||
// Check if the tree is correct
|
||||
expect(tree!.length).toBe(1); // Should have only one root message
|
||||
let currentNode = tree![0];
|
||||
for (let i = 1; i < 20; i++) {
|
||||
expect(currentNode.children!.length).toBe(1);
|
||||
currentNode = currentNode.children![0];
|
||||
expect(currentNode.text).toBe(`Message ${i}`);
|
||||
}
|
||||
expect(currentNode.children!.length).toBe(0); // Last message should have no children
|
||||
});
|
||||
|
||||
test('Random order dates between parent and children messages', async () => {
|
||||
const userId = 'testUser';
|
||||
const conversationId = 'testConversation';
|
||||
|
||||
// Create messages with deliberately out-of-order timestamps but sequential creation
|
||||
const messages = [
|
||||
{
|
||||
messageId: 'parent',
|
||||
parentMessageId: null,
|
||||
text: 'Parent Message',
|
||||
createdAt: new Date('2023-01-01T00:00:00Z'), // Make parent earliest
|
||||
},
|
||||
{
|
||||
messageId: 'child1',
|
||||
parentMessageId: 'parent',
|
||||
text: 'Child Message 1',
|
||||
createdAt: new Date('2023-01-01T00:01:00Z'),
|
||||
},
|
||||
{
|
||||
messageId: 'child2',
|
||||
parentMessageId: 'parent',
|
||||
text: 'Child Message 2',
|
||||
createdAt: new Date('2023-01-01T00:02:00Z'),
|
||||
},
|
||||
{
|
||||
messageId: 'grandchild1',
|
||||
parentMessageId: 'child1',
|
||||
text: 'Grandchild Message 1',
|
||||
createdAt: new Date('2023-01-01T00:03:00Z'),
|
||||
},
|
||||
];
|
||||
|
||||
// Add common properties to all messages
|
||||
messages.forEach((msg) => {
|
||||
Object.assign(msg, {
|
||||
conversationId,
|
||||
user: userId,
|
||||
isCreatedByUser: false,
|
||||
error: false,
|
||||
unfinished: false,
|
||||
});
|
||||
});
|
||||
|
||||
// Save messages with overrideTimestamp set to true
|
||||
await bulkSaveMessages(messages, true);
|
||||
|
||||
// Retrieve messages
|
||||
const retrievedMessages = await getMessages({ conversationId, user: userId });
|
||||
|
||||
// Debug log to see what's being returned
|
||||
console.log(
|
||||
'Retrieved Messages:',
|
||||
retrievedMessages.map((msg) => ({
|
||||
messageId: msg.messageId,
|
||||
parentMessageId: msg.parentMessageId,
|
||||
createdAt: msg.createdAt,
|
||||
})),
|
||||
);
|
||||
|
||||
// Build tree
|
||||
const tree = buildTree({ messages: retrievedMessages as TMessage[] });
|
||||
|
||||
// Debug log to see the tree structure
|
||||
console.log(
|
||||
'Tree structure:',
|
||||
tree!.map((root) => ({
|
||||
messageId: root.messageId,
|
||||
children: root.children!.map((child) => ({
|
||||
messageId: child.messageId,
|
||||
children: child.children!.map((grandchild) => ({
|
||||
messageId: grandchild.messageId,
|
||||
})),
|
||||
})),
|
||||
})),
|
||||
);
|
||||
|
||||
// Verify the structure before making assertions
|
||||
expect(retrievedMessages.length).toBe(4); // Should have all 4 messages
|
||||
|
||||
// Check if messages are properly linked
|
||||
const parentMsg = retrievedMessages.find((msg) => msg.messageId === 'parent');
|
||||
expect(parentMsg!.parentMessageId).toBeNull(); // Parent should have null parentMessageId
|
||||
|
||||
const childMsg1 = retrievedMessages.find((msg) => msg.messageId === 'child1');
|
||||
expect(childMsg1!.parentMessageId).toBe('parent');
|
||||
|
||||
// Then check tree structure
|
||||
expect(tree!.length).toBe(1); // Should have only one root message
|
||||
expect(tree![0].messageId).toBe('parent');
|
||||
expect(tree![0].children!.length).toBe(2); // Should have two children
|
||||
});
|
||||
});
|
||||
405
packages/data-schemas/src/methods/file.acl.spec.ts
Normal file
405
packages/data-schemas/src/methods/file.acl.spec.ts
Normal file
|
|
@ -0,0 +1,405 @@
|
|||
import mongoose from 'mongoose';
|
||||
import { v4 as uuidv4 } from 'uuid';
|
||||
import { MongoMemoryServer } from 'mongodb-memory-server';
|
||||
import {
|
||||
ResourceType,
|
||||
AccessRoleIds,
|
||||
PrincipalType,
|
||||
PermissionBits,
|
||||
} from 'librechat-data-provider';
|
||||
import type { AccessRole as TAccessRole, AclEntry as TAclEntry } from '..';
|
||||
import type { Types } from 'mongoose';
|
||||
import { createAclEntryMethods } from './aclEntry';
|
||||
import { createModels } from '../models';
|
||||
import { createMethods } from './index';
|
||||
|
||||
/** Lean access role object from .lean() */
|
||||
type LeanAccessRole = TAccessRole & { _id: mongoose.Types.ObjectId };
|
||||
|
||||
/** Lean ACL entry from .lean() */
|
||||
type LeanAclEntry = TAclEntry & { _id: mongoose.Types.ObjectId };
|
||||
|
||||
/** Tool resources shape for agent file access */
|
||||
type AgentToolResources = {
|
||||
file_search?: { file_ids?: string[] };
|
||||
code_interpreter?: { file_ids?: string[] };
|
||||
};
|
||||
|
||||
let File: mongoose.Model<unknown>;
|
||||
let Agent: mongoose.Model<unknown>;
|
||||
let AclEntry: mongoose.Model<unknown>;
|
||||
let AccessRole: mongoose.Model<unknown>;
|
||||
let User: mongoose.Model<unknown>;
|
||||
let methods: ReturnType<typeof createMethods>;
|
||||
let aclMethods: ReturnType<typeof createAclEntryMethods>;
|
||||
|
||||
describe('File Access Control', () => {
|
||||
let mongoServer: MongoMemoryServer;
|
||||
|
||||
beforeAll(async () => {
|
||||
mongoServer = await MongoMemoryServer.create();
|
||||
const mongoUri = mongoServer.getUri();
|
||||
await mongoose.connect(mongoUri);
|
||||
|
||||
createModels(mongoose);
|
||||
File = mongoose.models.File;
|
||||
Agent = mongoose.models.Agent;
|
||||
AclEntry = mongoose.models.AclEntry;
|
||||
AccessRole = mongoose.models.AccessRole;
|
||||
User = mongoose.models.User;
|
||||
|
||||
methods = createMethods(mongoose);
|
||||
aclMethods = createAclEntryMethods(mongoose);
|
||||
|
||||
// Seed default access roles
|
||||
await methods.seedDefaultRoles();
|
||||
});
|
||||
|
||||
afterAll(async () => {
|
||||
const collections = mongoose.connection.collections;
|
||||
for (const key in collections) {
|
||||
await collections[key].deleteMany({});
|
||||
}
|
||||
await mongoose.disconnect();
|
||||
await mongoServer.stop();
|
||||
});
|
||||
|
||||
beforeEach(async () => {
|
||||
await File.deleteMany({});
|
||||
await Agent.deleteMany({});
|
||||
await AclEntry.deleteMany({});
|
||||
await User.deleteMany({});
|
||||
});
|
||||
|
||||
describe('File ACL entry operations', () => {
|
||||
it('should create ACL entries for agent file access', async () => {
|
||||
const userId = new mongoose.Types.ObjectId();
|
||||
const authorId = new mongoose.Types.ObjectId();
|
||||
const agentId = uuidv4();
|
||||
const fileIds = [uuidv4(), uuidv4(), uuidv4(), uuidv4()];
|
||||
|
||||
// Create users
|
||||
await User.create({
|
||||
_id: userId,
|
||||
email: 'user@example.com',
|
||||
emailVerified: true,
|
||||
provider: 'local',
|
||||
});
|
||||
|
||||
await User.create({
|
||||
_id: authorId,
|
||||
email: 'author@example.com',
|
||||
emailVerified: true,
|
||||
provider: 'local',
|
||||
});
|
||||
|
||||
// Create files
|
||||
for (const fileId of fileIds) {
|
||||
await methods.createFile({
|
||||
user: authorId,
|
||||
file_id: fileId,
|
||||
filename: `file-${fileId}.txt`,
|
||||
filepath: `/uploads/${fileId}`,
|
||||
});
|
||||
}
|
||||
|
||||
// Create agent with only first two files attached
|
||||
const agent = await methods.createAgent({
|
||||
id: agentId,
|
||||
name: 'Test Agent',
|
||||
author: authorId,
|
||||
model: 'gpt-4',
|
||||
provider: 'openai',
|
||||
tool_resources: {
|
||||
file_search: {
|
||||
file_ids: [fileIds[0], fileIds[1]],
|
||||
},
|
||||
},
|
||||
});
|
||||
|
||||
// Grant EDIT permission to user on the agent
|
||||
const editorRole = (await AccessRole.findOne({
|
||||
accessRoleId: AccessRoleIds.AGENT_EDITOR,
|
||||
}).lean()) as LeanAccessRole | null;
|
||||
|
||||
if (editorRole) {
|
||||
await aclMethods.grantPermission(
|
||||
PrincipalType.USER,
|
||||
userId,
|
||||
ResourceType.AGENT,
|
||||
agent._id as string | Types.ObjectId,
|
||||
editorRole.permBits,
|
||||
authorId,
|
||||
undefined,
|
||||
editorRole._id,
|
||||
);
|
||||
}
|
||||
|
||||
// Verify ACL entry exists for the user
|
||||
const aclEntry = (await AclEntry.findOne({
|
||||
principalType: PrincipalType.USER,
|
||||
principalId: userId,
|
||||
resourceType: ResourceType.AGENT,
|
||||
resourceId: agent._id,
|
||||
}).lean()) as LeanAclEntry | null;
|
||||
|
||||
expect(aclEntry).toBeTruthy();
|
||||
|
||||
// Check that agent has correct file_ids in tool_resources
|
||||
const agentRecord = await methods.getAgent({ id: agentId });
|
||||
const toolResources = agentRecord?.tool_resources as AgentToolResources | undefined;
|
||||
expect(toolResources?.file_search?.file_ids).toContain(fileIds[0]);
|
||||
expect(toolResources?.file_search?.file_ids).toContain(fileIds[1]);
|
||||
expect(toolResources?.file_search?.file_ids).not.toContain(fileIds[2]);
|
||||
expect(toolResources?.file_search?.file_ids).not.toContain(fileIds[3]);
|
||||
});
|
||||
|
||||
it('should grant access to agent author via ACL', async () => {
|
||||
const authorId = new mongoose.Types.ObjectId();
|
||||
const agentId = uuidv4();
|
||||
|
||||
await User.create({
|
||||
_id: authorId,
|
||||
email: 'author@example.com',
|
||||
emailVerified: true,
|
||||
provider: 'local',
|
||||
});
|
||||
|
||||
const agent = await methods.createAgent({
|
||||
id: agentId,
|
||||
name: 'Test Agent',
|
||||
author: authorId,
|
||||
model: 'gpt-4',
|
||||
provider: 'openai',
|
||||
});
|
||||
|
||||
// Grant owner permissions
|
||||
const ownerRole = (await AccessRole.findOne({
|
||||
accessRoleId: AccessRoleIds.AGENT_OWNER,
|
||||
}).lean()) as LeanAccessRole | null;
|
||||
|
||||
if (ownerRole) {
|
||||
await aclMethods.grantPermission(
|
||||
PrincipalType.USER,
|
||||
authorId,
|
||||
ResourceType.AGENT,
|
||||
agent._id as string | Types.ObjectId,
|
||||
ownerRole.permBits,
|
||||
authorId,
|
||||
undefined,
|
||||
ownerRole._id,
|
||||
);
|
||||
}
|
||||
|
||||
// Author should have full permission bits on the agent
|
||||
const hasView = await aclMethods.hasPermission(
|
||||
[{ principalType: PrincipalType.USER, principalId: authorId }],
|
||||
ResourceType.AGENT,
|
||||
agent._id as string | Types.ObjectId,
|
||||
PermissionBits.VIEW,
|
||||
);
|
||||
|
||||
const hasEdit = await aclMethods.hasPermission(
|
||||
[{ principalType: PrincipalType.USER, principalId: authorId }],
|
||||
ResourceType.AGENT,
|
||||
agent._id as string | Types.ObjectId,
|
||||
PermissionBits.EDIT,
|
||||
);
|
||||
|
||||
expect(hasView).toBe(true);
|
||||
expect(hasEdit).toBe(true);
|
||||
});
|
||||
|
||||
it('should deny access when no ACL entry exists', async () => {
|
||||
const userId = new mongoose.Types.ObjectId();
|
||||
const agentId = new mongoose.Types.ObjectId();
|
||||
|
||||
const hasAccess = await aclMethods.hasPermission(
|
||||
[{ principalType: PrincipalType.USER, principalId: userId }],
|
||||
ResourceType.AGENT,
|
||||
agentId,
|
||||
PermissionBits.VIEW,
|
||||
);
|
||||
|
||||
expect(hasAccess).toBe(false);
|
||||
});
|
||||
|
||||
it('should deny EDIT when user only has VIEW permission', async () => {
|
||||
const userId = new mongoose.Types.ObjectId();
|
||||
const authorId = new mongoose.Types.ObjectId();
|
||||
const agentId = uuidv4();
|
||||
|
||||
await User.create({
|
||||
_id: userId,
|
||||
email: 'user@example.com',
|
||||
emailVerified: true,
|
||||
provider: 'local',
|
||||
});
|
||||
|
||||
await User.create({
|
||||
_id: authorId,
|
||||
email: 'author@example.com',
|
||||
emailVerified: true,
|
||||
provider: 'local',
|
||||
});
|
||||
|
||||
const agent = await methods.createAgent({
|
||||
id: agentId,
|
||||
name: 'View-Only Agent',
|
||||
author: authorId,
|
||||
model: 'gpt-4',
|
||||
provider: 'openai',
|
||||
});
|
||||
|
||||
// Grant only VIEW permission
|
||||
const viewerRole = (await AccessRole.findOne({
|
||||
accessRoleId: AccessRoleIds.AGENT_VIEWER,
|
||||
}).lean()) as LeanAccessRole | null;
|
||||
|
||||
if (viewerRole) {
|
||||
await aclMethods.grantPermission(
|
||||
PrincipalType.USER,
|
||||
userId,
|
||||
ResourceType.AGENT,
|
||||
agent._id as string | Types.ObjectId,
|
||||
viewerRole.permBits,
|
||||
authorId,
|
||||
undefined,
|
||||
viewerRole._id,
|
||||
);
|
||||
}
|
||||
|
||||
const canView = await aclMethods.hasPermission(
|
||||
[{ principalType: PrincipalType.USER, principalId: userId }],
|
||||
ResourceType.AGENT,
|
||||
agent._id as string | Types.ObjectId,
|
||||
PermissionBits.VIEW,
|
||||
);
|
||||
|
||||
const canEdit = await aclMethods.hasPermission(
|
||||
[{ principalType: PrincipalType.USER, principalId: userId }],
|
||||
ResourceType.AGENT,
|
||||
agent._id as string | Types.ObjectId,
|
||||
PermissionBits.EDIT,
|
||||
);
|
||||
|
||||
expect(canView).toBe(true);
|
||||
expect(canEdit).toBe(false);
|
||||
});
|
||||
|
||||
it('should support role-based permission grants', async () => {
|
||||
const userId = new mongoose.Types.ObjectId();
|
||||
const authorId = new mongoose.Types.ObjectId();
|
||||
const agentId = uuidv4();
|
||||
|
||||
await User.create({
|
||||
_id: userId,
|
||||
email: 'user@example.com',
|
||||
emailVerified: true,
|
||||
provider: 'local',
|
||||
role: 'ADMIN',
|
||||
});
|
||||
|
||||
await User.create({
|
||||
_id: authorId,
|
||||
email: 'author@example.com',
|
||||
emailVerified: true,
|
||||
provider: 'local',
|
||||
});
|
||||
|
||||
const agent = await methods.createAgent({
|
||||
id: agentId,
|
||||
name: 'Test Agent',
|
||||
author: authorId,
|
||||
model: 'gpt-4',
|
||||
provider: 'openai',
|
||||
});
|
||||
|
||||
// Grant permission to ADMIN role
|
||||
const editorRole = (await AccessRole.findOne({
|
||||
accessRoleId: AccessRoleIds.AGENT_EDITOR,
|
||||
}).lean()) as LeanAccessRole | null;
|
||||
|
||||
if (editorRole) {
|
||||
await aclMethods.grantPermission(
|
||||
PrincipalType.ROLE,
|
||||
'ADMIN',
|
||||
ResourceType.AGENT,
|
||||
agent._id as string | Types.ObjectId,
|
||||
editorRole.permBits,
|
||||
authorId,
|
||||
undefined,
|
||||
editorRole._id,
|
||||
);
|
||||
}
|
||||
|
||||
// User with ADMIN role should have access through role-based ACL
|
||||
const hasAccess = await aclMethods.hasPermission(
|
||||
[
|
||||
{ principalType: PrincipalType.USER, principalId: userId },
|
||||
{
|
||||
principalType: PrincipalType.ROLE,
|
||||
principalId: 'ADMIN' as unknown as mongoose.Types.ObjectId,
|
||||
},
|
||||
],
|
||||
ResourceType.AGENT,
|
||||
agent._id as string | Types.ObjectId,
|
||||
PermissionBits.VIEW,
|
||||
);
|
||||
|
||||
expect(hasAccess).toBe(true);
|
||||
});
|
||||
});
|
||||
|
||||
describe('getFiles with file queries', () => {
|
||||
it('should return files created by user', async () => {
|
||||
const userId = new mongoose.Types.ObjectId();
|
||||
const fileId1 = `file_${uuidv4()}`;
|
||||
const fileId2 = `file_${uuidv4()}`;
|
||||
|
||||
await methods.createFile({
|
||||
file_id: fileId1,
|
||||
user: userId,
|
||||
filename: 'file1.txt',
|
||||
filepath: '/uploads/file1.txt',
|
||||
type: 'text/plain',
|
||||
bytes: 100,
|
||||
});
|
||||
|
||||
await methods.createFile({
|
||||
file_id: fileId2,
|
||||
user: new mongoose.Types.ObjectId(),
|
||||
filename: 'file2.txt',
|
||||
filepath: '/uploads/file2.txt',
|
||||
type: 'text/plain',
|
||||
bytes: 200,
|
||||
});
|
||||
|
||||
const files = await methods.getFiles({ file_id: { $in: [fileId1, fileId2] } });
|
||||
expect(files).toHaveLength(2);
|
||||
});
|
||||
|
||||
it('should return all files matching query', async () => {
|
||||
const userId = new mongoose.Types.ObjectId();
|
||||
const fileId1 = `file_${uuidv4()}`;
|
||||
const fileId2 = `file_${uuidv4()}`;
|
||||
|
||||
await methods.createFile({
|
||||
file_id: fileId1,
|
||||
user: userId,
|
||||
filename: 'file1.txt',
|
||||
filepath: '/uploads/file1.txt',
|
||||
});
|
||||
|
||||
await methods.createFile({
|
||||
file_id: fileId2,
|
||||
user: userId,
|
||||
filename: 'file2.txt',
|
||||
filepath: '/uploads/file2.txt',
|
||||
});
|
||||
|
||||
const files = await methods.getFiles({ user: userId });
|
||||
expect(files).toHaveLength(2);
|
||||
});
|
||||
});
|
||||
});
|
||||
|
|
@ -1,6 +1,6 @@
|
|||
import { createSessionMethods, DEFAULT_REFRESH_TOKEN_EXPIRY, type SessionMethods } from './session';
|
||||
import { createTokenMethods, type TokenMethods } from './token';
|
||||
import { createRoleMethods, type RoleMethods } from './role';
|
||||
import { createRoleMethods, type RoleMethods, type RoleDeps } from './role';
|
||||
import { createUserMethods, DEFAULT_SESSION_EXPIRY, type UserMethods } from './user';
|
||||
|
||||
export { DEFAULT_REFRESH_TOKEN_EXPIRY, DEFAULT_SESSION_EXPIRY };
|
||||
|
|
@ -21,6 +21,34 @@ import { createAccessRoleMethods, type AccessRoleMethods } from './accessRole';
|
|||
import { createUserGroupMethods, type UserGroupMethods } from './userGroup';
|
||||
import { createAclEntryMethods, type AclEntryMethods } from './aclEntry';
|
||||
import { createShareMethods, type ShareMethods } from './share';
|
||||
/* Tier 1 — Simple CRUD */
|
||||
import { createActionMethods, type ActionMethods } from './action';
|
||||
import { createAssistantMethods, type AssistantMethods } from './assistant';
|
||||
import { createBannerMethods, type BannerMethods } from './banner';
|
||||
import { createToolCallMethods, type ToolCallMethods } from './toolCall';
|
||||
import { createCategoriesMethods, type CategoriesMethods } from './categories';
|
||||
import { createPresetMethods, type PresetMethods } from './preset';
|
||||
/* Tier 2 — Moderate (service deps injected) */
|
||||
import { createConversationTagMethods, type ConversationTagMethods } from './conversationTag';
|
||||
import { createMessageMethods, type MessageMethods } from './message';
|
||||
import { createConversationMethods, type ConversationMethods } from './conversation';
|
||||
/* Tier 3 — Complex (heavier injection) */
|
||||
import {
|
||||
createTxMethods,
|
||||
type TxMethods,
|
||||
type TxDeps,
|
||||
tokenValues,
|
||||
cacheTokenValues,
|
||||
premiumTokenValues,
|
||||
defaultRate,
|
||||
} from './tx';
|
||||
import { createTransactionMethods, type TransactionMethods } from './transaction';
|
||||
import { createSpendTokensMethods, type SpendTokensMethods } from './spendTokens';
|
||||
import { createPromptMethods, type PromptMethods, type PromptDeps } from './prompt';
|
||||
/* Tier 5 — Agent */
|
||||
import { createAgentMethods, type AgentMethods, type AgentDeps } from './agent';
|
||||
|
||||
export { tokenValues, cacheTokenValues, premiumTokenValues, defaultRate };
|
||||
|
||||
export type AllMethods = UserMethods &
|
||||
SessionMethods &
|
||||
|
|
@ -36,18 +64,102 @@ export type AllMethods = UserMethods &
|
|||
AclEntryMethods &
|
||||
ShareMethods &
|
||||
AccessRoleMethods &
|
||||
PluginAuthMethods;
|
||||
PluginAuthMethods &
|
||||
ActionMethods &
|
||||
AssistantMethods &
|
||||
BannerMethods &
|
||||
ToolCallMethods &
|
||||
CategoriesMethods &
|
||||
PresetMethods &
|
||||
ConversationTagMethods &
|
||||
MessageMethods &
|
||||
ConversationMethods &
|
||||
TxMethods &
|
||||
TransactionMethods &
|
||||
SpendTokensMethods &
|
||||
PromptMethods &
|
||||
AgentMethods;
|
||||
|
||||
/** Dependencies injected from the api layer into createMethods */
|
||||
export interface CreateMethodsDeps {
|
||||
/** Matches a model name to a canonical key. From @librechat/api. */
|
||||
matchModelName?: (model: string, endpoint?: string) => string | undefined;
|
||||
/** Finds the first key in values whose key is a substring of model. From @librechat/api. */
|
||||
findMatchingPattern?: (model: string, values: Record<string, unknown>) => string | undefined;
|
||||
/** Removes all ACL permissions for a resource. From PermissionService. */
|
||||
removeAllPermissions?: (params: { resourceType: string; resourceId: unknown }) => Promise<void>;
|
||||
/** Returns a cache store for the given key. From getLogStores. */
|
||||
getCache?: RoleDeps['getCache'];
|
||||
}
|
||||
|
||||
/**
|
||||
* Creates all database methods for all collections
|
||||
* @param mongoose - Mongoose instance
|
||||
* @param deps - Optional dependencies injected from the api layer
|
||||
*/
|
||||
export function createMethods(mongoose: typeof import('mongoose')): AllMethods {
|
||||
export function createMethods(
|
||||
mongoose: typeof import('mongoose'),
|
||||
deps: CreateMethodsDeps = {},
|
||||
): AllMethods {
|
||||
// Tier 3: tx methods need matchModelName and findMatchingPattern
|
||||
const txDeps: TxDeps = {
|
||||
matchModelName: deps.matchModelName ?? (() => undefined),
|
||||
findMatchingPattern: deps.findMatchingPattern ?? (() => undefined),
|
||||
};
|
||||
const txMethods = createTxMethods(mongoose, txDeps);
|
||||
|
||||
// Tier 3: transaction methods need tx's getMultiplier/getCacheMultiplier
|
||||
const transactionMethods = createTransactionMethods(mongoose, {
|
||||
getMultiplier: txMethods.getMultiplier,
|
||||
getCacheMultiplier: txMethods.getCacheMultiplier,
|
||||
});
|
||||
|
||||
// Tier 3: spendTokens methods need transaction methods
|
||||
const spendTokensMethods = createSpendTokensMethods(mongoose, {
|
||||
createTransaction: transactionMethods.createTransaction,
|
||||
createStructuredTransaction: transactionMethods.createStructuredTransaction,
|
||||
});
|
||||
|
||||
const messageMethods = createMessageMethods(mongoose);
|
||||
|
||||
const conversationMethods = createConversationMethods(mongoose, {
|
||||
getMessages: messageMethods.getMessages,
|
||||
deleteMessages: messageMethods.deleteMessages,
|
||||
});
|
||||
|
||||
// ACL entry methods (used internally for removeAllPermissions)
|
||||
const aclEntryMethods = createAclEntryMethods(mongoose);
|
||||
|
||||
// Internal removeAllPermissions: use deleteAclEntries from aclEntryMethods
|
||||
// instead of requiring it as an external dep from PermissionService
|
||||
const removeAllPermissions =
|
||||
deps.removeAllPermissions ??
|
||||
(async ({ resourceType, resourceId }: { resourceType: string; resourceId: unknown }) => {
|
||||
await aclEntryMethods.deleteAclEntries({ resourceType, resourceId });
|
||||
});
|
||||
|
||||
const promptDeps: PromptDeps = { removeAllPermissions };
|
||||
const promptMethods = createPromptMethods(mongoose, promptDeps);
|
||||
|
||||
// Role methods with optional cache injection
|
||||
const roleDeps: RoleDeps = { getCache: deps.getCache };
|
||||
const roleMethods = createRoleMethods(mongoose, roleDeps);
|
||||
|
||||
// Tier 1: action methods (created as variable for agent dependency)
|
||||
const actionMethods = createActionMethods(mongoose);
|
||||
|
||||
// Tier 5: agent methods need removeAllPermissions + getActions
|
||||
const agentDeps: AgentDeps = {
|
||||
removeAllPermissions,
|
||||
getActions: actionMethods.getActions,
|
||||
};
|
||||
const agentMethods = createAgentMethods(mongoose, agentDeps);
|
||||
|
||||
return {
|
||||
...createUserMethods(mongoose),
|
||||
...createSessionMethods(mongoose),
|
||||
...createTokenMethods(mongoose),
|
||||
...createRoleMethods(mongoose),
|
||||
...roleMethods,
|
||||
...createKeyMethods(mongoose),
|
||||
...createFileMethods(mongoose),
|
||||
...createMemoryMethods(mongoose),
|
||||
|
|
@ -56,9 +168,27 @@ export function createMethods(mongoose: typeof import('mongoose')): AllMethods {
|
|||
...createMCPServerMethods(mongoose),
|
||||
...createAccessRoleMethods(mongoose),
|
||||
...createUserGroupMethods(mongoose),
|
||||
...createAclEntryMethods(mongoose),
|
||||
...aclEntryMethods,
|
||||
...createShareMethods(mongoose),
|
||||
...createPluginAuthMethods(mongoose),
|
||||
/* Tier 1 */
|
||||
...actionMethods,
|
||||
...createAssistantMethods(mongoose),
|
||||
...createBannerMethods(mongoose),
|
||||
...createToolCallMethods(mongoose),
|
||||
...createCategoriesMethods(mongoose),
|
||||
...createPresetMethods(mongoose),
|
||||
/* Tier 2 */
|
||||
...createConversationTagMethods(mongoose),
|
||||
...messageMethods,
|
||||
...conversationMethods,
|
||||
/* Tier 3 */
|
||||
...txMethods,
|
||||
...transactionMethods,
|
||||
...spendTokensMethods,
|
||||
...promptMethods,
|
||||
/* Tier 5 */
|
||||
...agentMethods,
|
||||
};
|
||||
}
|
||||
|
||||
|
|
@ -78,4 +208,18 @@ export type {
|
|||
ShareMethods,
|
||||
AccessRoleMethods,
|
||||
PluginAuthMethods,
|
||||
ActionMethods,
|
||||
AssistantMethods,
|
||||
BannerMethods,
|
||||
ToolCallMethods,
|
||||
CategoriesMethods,
|
||||
PresetMethods,
|
||||
ConversationTagMethods,
|
||||
MessageMethods,
|
||||
ConversationMethods,
|
||||
TxMethods,
|
||||
TransactionMethods,
|
||||
SpendTokensMethods,
|
||||
PromptMethods,
|
||||
AgentMethods,
|
||||
};
|
||||
|
|
|
|||
|
|
@ -158,12 +158,28 @@ export function createMemoryMethods(mongoose: typeof import('mongoose')) {
|
|||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Deletes all memory entries for a user
|
||||
*/
|
||||
async function deleteAllUserMemories(userId: string | Types.ObjectId): Promise<number> {
|
||||
try {
|
||||
const MemoryEntry = mongoose.models.MemoryEntry;
|
||||
const result = await MemoryEntry.deleteMany({ userId });
|
||||
return result.deletedCount;
|
||||
} catch (error) {
|
||||
throw new Error(
|
||||
`Failed to delete all user memories: ${error instanceof Error ? error.message : 'Unknown error'}`,
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
return {
|
||||
setMemory,
|
||||
createMemory,
|
||||
deleteMemory,
|
||||
getAllUserMemories,
|
||||
getFormattedMemories,
|
||||
deleteAllUserMemories,
|
||||
};
|
||||
}
|
||||
|
||||
|
|
|
|||
940
packages/data-schemas/src/methods/message.spec.ts
Normal file
940
packages/data-schemas/src/methods/message.spec.ts
Normal file
|
|
@ -0,0 +1,940 @@
|
|||
import mongoose from 'mongoose';
|
||||
import { v4 as uuidv4 } from 'uuid';
|
||||
import { MongoMemoryServer } from 'mongodb-memory-server';
|
||||
import type { IMessage } from '..';
|
||||
import { createMessageMethods } from './message';
|
||||
import { createModels } from '../models';
|
||||
|
||||
jest.mock('~/config/winston', () => ({
|
||||
error: jest.fn(),
|
||||
warn: jest.fn(),
|
||||
info: jest.fn(),
|
||||
debug: jest.fn(),
|
||||
}));
|
||||
|
||||
let mongoServer: InstanceType<typeof MongoMemoryServer>;
|
||||
let Message: mongoose.Model<IMessage>;
|
||||
let saveMessage: ReturnType<typeof createMessageMethods>['saveMessage'];
|
||||
let getMessages: ReturnType<typeof createMessageMethods>['getMessages'];
|
||||
let updateMessage: ReturnType<typeof createMessageMethods>['updateMessage'];
|
||||
let deleteMessages: ReturnType<typeof createMessageMethods>['deleteMessages'];
|
||||
let bulkSaveMessages: ReturnType<typeof createMessageMethods>['bulkSaveMessages'];
|
||||
let updateMessageText: ReturnType<typeof createMessageMethods>['updateMessageText'];
|
||||
let deleteMessagesSince: ReturnType<typeof createMessageMethods>['deleteMessagesSince'];
|
||||
|
||||
beforeAll(async () => {
|
||||
mongoServer = await MongoMemoryServer.create();
|
||||
const mongoUri = mongoServer.getUri();
|
||||
|
||||
const models = createModels(mongoose);
|
||||
Object.assign(mongoose.models, models);
|
||||
Message = mongoose.models.Message;
|
||||
|
||||
const methods = createMessageMethods(mongoose);
|
||||
saveMessage = methods.saveMessage;
|
||||
getMessages = methods.getMessages;
|
||||
updateMessage = methods.updateMessage;
|
||||
deleteMessages = methods.deleteMessages;
|
||||
bulkSaveMessages = methods.bulkSaveMessages;
|
||||
updateMessageText = methods.updateMessageText;
|
||||
deleteMessagesSince = methods.deleteMessagesSince;
|
||||
|
||||
await mongoose.connect(mongoUri);
|
||||
});
|
||||
|
||||
afterAll(async () => {
|
||||
await mongoose.disconnect();
|
||||
await mongoServer.stop();
|
||||
});
|
||||
|
||||
describe('Message Operations', () => {
|
||||
let mockCtx: {
|
||||
userId: string;
|
||||
isTemporary?: boolean;
|
||||
interfaceConfig?: { temporaryChatRetention?: number };
|
||||
};
|
||||
let mockMessageData: Partial<IMessage> = {
|
||||
messageId: 'msg123',
|
||||
conversationId: uuidv4(),
|
||||
text: 'Hello, world!',
|
||||
user: 'user123',
|
||||
};
|
||||
|
||||
beforeEach(async () => {
|
||||
// Clear database
|
||||
await Message.deleteMany({});
|
||||
|
||||
mockCtx = {
|
||||
userId: 'user123',
|
||||
interfaceConfig: {
|
||||
temporaryChatRetention: 24, // Default 24 hours
|
||||
},
|
||||
};
|
||||
|
||||
mockMessageData = {
|
||||
messageId: 'msg123',
|
||||
conversationId: uuidv4(),
|
||||
text: 'Hello, world!',
|
||||
user: 'user123',
|
||||
};
|
||||
});
|
||||
|
||||
describe('saveMessage', () => {
|
||||
it('should save a message for an authenticated user', async () => {
|
||||
const result = await saveMessage(mockCtx, mockMessageData);
|
||||
|
||||
expect(result?.messageId).toBe('msg123');
|
||||
expect(result?.user).toBe('user123');
|
||||
expect(result?.text).toBe('Hello, world!');
|
||||
|
||||
// Verify the message was actually saved to the database
|
||||
const savedMessage = await Message.findOne({ messageId: 'msg123', user: 'user123' });
|
||||
expect(savedMessage).toBeTruthy();
|
||||
expect(savedMessage?.text).toBe('Hello, world!');
|
||||
});
|
||||
|
||||
it('should throw an error for unauthenticated user', async () => {
|
||||
mockCtx.userId = null as unknown as string;
|
||||
await expect(saveMessage(mockCtx, mockMessageData)).rejects.toThrow('User not authenticated');
|
||||
});
|
||||
|
||||
it('should handle invalid conversation ID gracefully', async () => {
|
||||
mockMessageData.conversationId = 'invalid-id';
|
||||
const result = await saveMessage(mockCtx, mockMessageData);
|
||||
expect(result).toBeUndefined();
|
||||
});
|
||||
});
|
||||
|
||||
describe('updateMessageText', () => {
|
||||
it('should update message text for the authenticated user', async () => {
|
||||
// First save a message
|
||||
await saveMessage(mockCtx, mockMessageData);
|
||||
|
||||
// Then update it
|
||||
await updateMessageText(mockCtx.userId, { messageId: 'msg123', text: 'Updated text' });
|
||||
|
||||
// Verify the update
|
||||
const updatedMessage = await Message.findOne({ messageId: 'msg123', user: 'user123' });
|
||||
expect(updatedMessage?.text).toBe('Updated text');
|
||||
});
|
||||
});
|
||||
|
||||
describe('updateMessage', () => {
|
||||
it('should update a message for the authenticated user', async () => {
|
||||
// First save a message
|
||||
await saveMessage(mockCtx, mockMessageData);
|
||||
|
||||
const result = await updateMessage(mockCtx.userId, {
|
||||
messageId: 'msg123',
|
||||
text: 'Updated text',
|
||||
});
|
||||
|
||||
expect(result?.messageId).toBe('msg123');
|
||||
expect(result?.text).toBe('Updated text');
|
||||
|
||||
// Verify in database
|
||||
const updatedMessage = await Message.findOne({ messageId: 'msg123', user: 'user123' });
|
||||
expect(updatedMessage?.text).toBe('Updated text');
|
||||
});
|
||||
|
||||
it('should throw an error if message is not found', async () => {
|
||||
await expect(
|
||||
updateMessage(mockCtx.userId, { messageId: 'nonexistent', text: 'Test' }),
|
||||
).rejects.toThrow('Message not found or user not authorized.');
|
||||
});
|
||||
});
|
||||
|
||||
describe('deleteMessagesSince', () => {
|
||||
it('should delete messages only for the authenticated user', async () => {
|
||||
const conversationId = uuidv4();
|
||||
|
||||
// Create multiple messages in the same conversation
|
||||
await saveMessage(mockCtx, {
|
||||
messageId: 'msg1',
|
||||
conversationId,
|
||||
text: 'First message',
|
||||
user: 'user123',
|
||||
});
|
||||
|
||||
await saveMessage(mockCtx, {
|
||||
messageId: 'msg2',
|
||||
conversationId,
|
||||
text: 'Second message',
|
||||
user: 'user123',
|
||||
});
|
||||
|
||||
await saveMessage(mockCtx, {
|
||||
messageId: 'msg3',
|
||||
conversationId,
|
||||
text: 'Third message',
|
||||
user: 'user123',
|
||||
});
|
||||
|
||||
// Delete messages since message2 (this should only delete messages created AFTER msg2)
|
||||
await deleteMessagesSince(mockCtx.userId, {
|
||||
messageId: 'msg2',
|
||||
conversationId,
|
||||
});
|
||||
|
||||
// Verify msg1 and msg2 remain, msg3 is deleted
|
||||
const remainingMessages = await Message.find({ conversationId, user: 'user123' });
|
||||
expect(remainingMessages).toHaveLength(2);
|
||||
expect(remainingMessages.map((m) => m.messageId)).toContain('msg1');
|
||||
expect(remainingMessages.map((m) => m.messageId)).toContain('msg2');
|
||||
expect(remainingMessages.map((m) => m.messageId)).not.toContain('msg3');
|
||||
});
|
||||
|
||||
it('should return undefined if no message is found', async () => {
|
||||
const result = await deleteMessagesSince(mockCtx.userId, {
|
||||
messageId: 'nonexistent',
|
||||
conversationId: 'convo123',
|
||||
});
|
||||
expect(result).toBeUndefined();
|
||||
});
|
||||
});
|
||||
|
||||
describe('getMessages', () => {
|
||||
it('should retrieve messages with the correct filter', async () => {
|
||||
const conversationId = uuidv4();
|
||||
|
||||
// Save some messages
|
||||
await saveMessage(mockCtx, {
|
||||
messageId: 'msg1',
|
||||
conversationId,
|
||||
text: 'First message',
|
||||
user: 'user123',
|
||||
});
|
||||
|
||||
await saveMessage(mockCtx, {
|
||||
messageId: 'msg2',
|
||||
conversationId,
|
||||
text: 'Second message',
|
||||
user: 'user123',
|
||||
});
|
||||
|
||||
const messages = await getMessages({ conversationId });
|
||||
expect(messages).toHaveLength(2);
|
||||
expect(messages[0].text).toBe('First message');
|
||||
expect(messages[1].text).toBe('Second message');
|
||||
});
|
||||
});
|
||||
|
||||
describe('deleteMessages', () => {
|
||||
it('should delete messages with the correct filter', async () => {
|
||||
// Save some messages for different users
|
||||
await saveMessage(mockCtx, mockMessageData);
|
||||
await saveMessage(
|
||||
{ userId: 'user456' },
|
||||
{
|
||||
messageId: 'msg456',
|
||||
conversationId: uuidv4(),
|
||||
text: 'Other user message',
|
||||
user: 'user456',
|
||||
},
|
||||
);
|
||||
|
||||
await deleteMessages({ user: 'user123' });
|
||||
|
||||
// Verify only user123's messages were deleted
|
||||
const user123Messages = await Message.find({ user: 'user123' });
|
||||
const user456Messages = await Message.find({ user: 'user456' });
|
||||
|
||||
expect(user123Messages).toHaveLength(0);
|
||||
expect(user456Messages).toHaveLength(1);
|
||||
});
|
||||
});
|
||||
|
||||
describe('Conversation Hijacking Prevention', () => {
|
||||
it("should not allow editing a message in another user's conversation", async () => {
|
||||
const victimConversationId = uuidv4();
|
||||
const victimMessageId = 'victim-msg-123';
|
||||
|
||||
// First, save a message as the victim (but we'll try to edit as attacker)
|
||||
await saveMessage(
|
||||
{ userId: 'victim123' },
|
||||
{
|
||||
messageId: victimMessageId,
|
||||
conversationId: victimConversationId,
|
||||
text: 'Victim message',
|
||||
user: 'victim123',
|
||||
},
|
||||
);
|
||||
|
||||
// Attacker tries to edit the victim's message
|
||||
await expect(
|
||||
updateMessage('attacker123', {
|
||||
messageId: victimMessageId,
|
||||
conversationId: victimConversationId,
|
||||
text: 'Hacked message',
|
||||
}),
|
||||
).rejects.toThrow('Message not found or user not authorized.');
|
||||
|
||||
// Verify the original message is unchanged
|
||||
const originalMessage = await Message.findOne({
|
||||
messageId: victimMessageId,
|
||||
user: 'victim123',
|
||||
});
|
||||
expect(originalMessage?.text).toBe('Victim message');
|
||||
});
|
||||
|
||||
it("should not allow deleting messages from another user's conversation", async () => {
|
||||
const victimConversationId = uuidv4();
|
||||
const victimMessageId = 'victim-msg-123';
|
||||
|
||||
// Save a message as the victim
|
||||
await saveMessage(
|
||||
{ userId: 'victim123' },
|
||||
{
|
||||
messageId: victimMessageId,
|
||||
conversationId: victimConversationId,
|
||||
text: 'Victim message',
|
||||
user: 'victim123',
|
||||
},
|
||||
);
|
||||
|
||||
// Attacker tries to delete from victim's conversation
|
||||
const result = await deleteMessagesSince('attacker123', {
|
||||
messageId: victimMessageId,
|
||||
conversationId: victimConversationId,
|
||||
});
|
||||
|
||||
expect(result).toBeUndefined();
|
||||
|
||||
// Verify the victim's message still exists
|
||||
const victimMessage = await Message.findOne({
|
||||
messageId: victimMessageId,
|
||||
user: 'victim123',
|
||||
});
|
||||
expect(victimMessage).toBeTruthy();
|
||||
expect(victimMessage?.text).toBe('Victim message');
|
||||
});
|
||||
|
||||
it("should not allow inserting a new message into another user's conversation", async () => {
|
||||
const victimConversationId = uuidv4();
|
||||
|
||||
// Attacker tries to save a message - this should succeed but with attacker's user ID
|
||||
const result = await saveMessage(
|
||||
{ userId: 'attacker123' },
|
||||
{
|
||||
conversationId: victimConversationId,
|
||||
text: 'Inserted malicious message',
|
||||
messageId: 'new-msg-123',
|
||||
user: 'attacker123',
|
||||
},
|
||||
);
|
||||
|
||||
expect(result).toBeTruthy();
|
||||
expect(result?.user).toBe('attacker123');
|
||||
|
||||
// Verify the message was saved with the attacker's user ID, not as an anonymous message
|
||||
const savedMessage = await Message.findOne({ messageId: 'new-msg-123' });
|
||||
expect(savedMessage?.user).toBe('attacker123');
|
||||
expect(savedMessage?.conversationId).toBe(victimConversationId);
|
||||
});
|
||||
|
||||
it('should allow retrieving messages from any conversation', async () => {
|
||||
const victimConversationId = uuidv4();
|
||||
|
||||
// Save a message in the victim's conversation
|
||||
await saveMessage(
|
||||
{ userId: 'victim123' },
|
||||
{
|
||||
messageId: 'victim-msg',
|
||||
conversationId: victimConversationId,
|
||||
text: 'Victim message',
|
||||
user: 'victim123',
|
||||
},
|
||||
);
|
||||
|
||||
// Anyone should be able to retrieve messages by conversation ID
|
||||
const messages = await getMessages({ conversationId: victimConversationId });
|
||||
expect(messages).toHaveLength(1);
|
||||
expect(messages[0].text).toBe('Victim message');
|
||||
});
|
||||
});
|
||||
|
||||
describe('isTemporary message handling', () => {
|
||||
beforeEach(() => {
|
||||
// Reset mocks before each test
|
||||
jest.clearAllMocks();
|
||||
});
|
||||
|
||||
it('should save a message with expiredAt when isTemporary is true', async () => {
|
||||
// Mock app config with 24 hour retention
|
||||
mockCtx.interfaceConfig = { temporaryChatRetention: 24 };
|
||||
|
||||
mockCtx.isTemporary = true;
|
||||
|
||||
const beforeSave = new Date();
|
||||
const result = await saveMessage(mockCtx, mockMessageData);
|
||||
const afterSave = new Date();
|
||||
|
||||
expect(result?.messageId).toBe('msg123');
|
||||
expect(result?.expiredAt).toBeDefined();
|
||||
expect(result?.expiredAt).toBeInstanceOf(Date);
|
||||
|
||||
// Verify expiredAt is approximately 24 hours in the future
|
||||
const expectedExpirationTime = new Date(beforeSave.getTime() + 24 * 60 * 60 * 1000);
|
||||
const actualExpirationTime = new Date(result?.expiredAt ?? 0);
|
||||
|
||||
expect(actualExpirationTime.getTime()).toBeGreaterThanOrEqual(
|
||||
expectedExpirationTime.getTime() - 1000,
|
||||
);
|
||||
expect(actualExpirationTime.getTime()).toBeLessThanOrEqual(
|
||||
new Date(afterSave.getTime() + 24 * 60 * 60 * 1000 + 1000).getTime(),
|
||||
);
|
||||
});
|
||||
|
||||
it('should save a message without expiredAt when isTemporary is false', async () => {
|
||||
mockCtx.isTemporary = false;
|
||||
|
||||
const result = await saveMessage(mockCtx, mockMessageData);
|
||||
|
||||
expect(result?.messageId).toBe('msg123');
|
||||
expect(result?.expiredAt).toBeNull();
|
||||
});
|
||||
|
||||
it('should save a message without expiredAt when isTemporary is not provided', async () => {
|
||||
// No isTemporary set
|
||||
|
||||
const result = await saveMessage(mockCtx, mockMessageData);
|
||||
|
||||
expect(result?.messageId).toBe('msg123');
|
||||
expect(result?.expiredAt).toBeNull();
|
||||
});
|
||||
|
||||
it('should use custom retention period from config', async () => {
|
||||
// Mock app config with 48 hour retention
|
||||
mockCtx.interfaceConfig = { temporaryChatRetention: 48 };
|
||||
|
||||
mockCtx.isTemporary = true;
|
||||
|
||||
const beforeSave = new Date();
|
||||
const result = await saveMessage(mockCtx, mockMessageData);
|
||||
|
||||
expect(result?.expiredAt).toBeDefined();
|
||||
|
||||
// Verify expiredAt is approximately 48 hours in the future
|
||||
const expectedExpirationTime = new Date(beforeSave.getTime() + 48 * 60 * 60 * 1000);
|
||||
const actualExpirationTime = new Date(result?.expiredAt ?? 0);
|
||||
|
||||
expect(actualExpirationTime.getTime()).toBeGreaterThanOrEqual(
|
||||
expectedExpirationTime.getTime() - 1000,
|
||||
);
|
||||
expect(actualExpirationTime.getTime()).toBeLessThanOrEqual(
|
||||
expectedExpirationTime.getTime() + 1000,
|
||||
);
|
||||
});
|
||||
|
||||
it('should handle minimum retention period (1 hour)', async () => {
|
||||
// Mock app config with less than minimum retention
|
||||
mockCtx.interfaceConfig = { temporaryChatRetention: 0.5 }; // Half hour - should be clamped to 1 hour
|
||||
|
||||
mockCtx.isTemporary = true;
|
||||
|
||||
const beforeSave = new Date();
|
||||
const result = await saveMessage(mockCtx, mockMessageData);
|
||||
|
||||
expect(result?.expiredAt).toBeDefined();
|
||||
|
||||
// Verify expiredAt is approximately 1 hour in the future (minimum)
|
||||
const expectedExpirationTime = new Date(beforeSave.getTime() + 1 * 60 * 60 * 1000);
|
||||
const actualExpirationTime = new Date(result?.expiredAt ?? 0);
|
||||
|
||||
expect(actualExpirationTime.getTime()).toBeGreaterThanOrEqual(
|
||||
expectedExpirationTime.getTime() - 1000,
|
||||
);
|
||||
expect(actualExpirationTime.getTime()).toBeLessThanOrEqual(
|
||||
expectedExpirationTime.getTime() + 1000,
|
||||
);
|
||||
});
|
||||
|
||||
it('should handle maximum retention period (8760 hours)', async () => {
|
||||
// Mock app config with more than maximum retention
|
||||
mockCtx.interfaceConfig = { temporaryChatRetention: 10000 }; // Should be clamped to 8760 hours
|
||||
|
||||
mockCtx.isTemporary = true;
|
||||
|
||||
const beforeSave = new Date();
|
||||
const result = await saveMessage(mockCtx, mockMessageData);
|
||||
|
||||
expect(result?.expiredAt).toBeDefined();
|
||||
|
||||
// Verify expiredAt is approximately 8760 hours (1 year) in the future
|
||||
const expectedExpirationTime = new Date(beforeSave.getTime() + 8760 * 60 * 60 * 1000);
|
||||
const actualExpirationTime = new Date(result?.expiredAt ?? 0);
|
||||
|
||||
expect(actualExpirationTime.getTime()).toBeGreaterThanOrEqual(
|
||||
expectedExpirationTime.getTime() - 1000,
|
||||
);
|
||||
expect(actualExpirationTime.getTime()).toBeLessThanOrEqual(
|
||||
expectedExpirationTime.getTime() + 1000,
|
||||
);
|
||||
});
|
||||
|
||||
it('should handle missing config gracefully', async () => {
|
||||
// Simulate missing config - should use default retention period
|
||||
delete mockCtx.interfaceConfig;
|
||||
|
||||
mockCtx.isTemporary = true;
|
||||
|
||||
const beforeSave = new Date();
|
||||
const result = await saveMessage(mockCtx, mockMessageData);
|
||||
const afterSave = new Date();
|
||||
|
||||
// Should still save the message with default retention period (30 days)
|
||||
expect(result?.messageId).toBe('msg123');
|
||||
expect(result?.expiredAt).toBeDefined();
|
||||
expect(result?.expiredAt).toBeInstanceOf(Date);
|
||||
|
||||
// Verify expiredAt is approximately 30 days in the future (720 hours)
|
||||
const expectedExpirationTime = new Date(beforeSave.getTime() + 720 * 60 * 60 * 1000);
|
||||
const actualExpirationTime = new Date(result?.expiredAt ?? 0);
|
||||
|
||||
expect(actualExpirationTime.getTime()).toBeGreaterThanOrEqual(
|
||||
expectedExpirationTime.getTime() - 1000,
|
||||
);
|
||||
expect(actualExpirationTime.getTime()).toBeLessThanOrEqual(
|
||||
new Date(afterSave.getTime() + 720 * 60 * 60 * 1000 + 1000).getTime(),
|
||||
);
|
||||
});
|
||||
|
||||
it('should use default retention when config is not provided', async () => {
|
||||
// Mock getAppConfig to return empty config
|
||||
mockCtx.interfaceConfig = undefined; // Empty config
|
||||
|
||||
mockCtx.isTemporary = true;
|
||||
|
||||
const beforeSave = new Date();
|
||||
const result = await saveMessage(mockCtx, mockMessageData);
|
||||
|
||||
expect(result?.expiredAt).toBeDefined();
|
||||
|
||||
// Default retention is 30 days (720 hours)
|
||||
const expectedExpirationTime = new Date(beforeSave.getTime() + 30 * 24 * 60 * 60 * 1000);
|
||||
const actualExpirationTime = new Date(result?.expiredAt ?? 0);
|
||||
|
||||
expect(actualExpirationTime.getTime()).toBeGreaterThanOrEqual(
|
||||
expectedExpirationTime.getTime() - 1000,
|
||||
);
|
||||
expect(actualExpirationTime.getTime()).toBeLessThanOrEqual(
|
||||
expectedExpirationTime.getTime() + 1000,
|
||||
);
|
||||
});
|
||||
|
||||
it('should not update expiredAt on message update', async () => {
|
||||
// First save a temporary message
|
||||
mockCtx.interfaceConfig = { temporaryChatRetention: 24 };
|
||||
|
||||
mockCtx.isTemporary = true;
|
||||
const savedMessage = await saveMessage(mockCtx, mockMessageData);
|
||||
const originalExpiredAt = savedMessage?.expiredAt;
|
||||
|
||||
// Now update the message without isTemporary flag
|
||||
mockCtx.isTemporary = undefined;
|
||||
const updatedMessage = await updateMessage(mockCtx.userId, {
|
||||
messageId: 'msg123',
|
||||
text: 'Updated text',
|
||||
});
|
||||
|
||||
// expiredAt should not be in the returned updated message object
|
||||
expect(updatedMessage?.expiredAt).toBeUndefined();
|
||||
|
||||
// Verify in database that expiredAt wasn't changed
|
||||
const dbMessage = await Message.findOne({ messageId: 'msg123', user: 'user123' });
|
||||
expect(dbMessage?.expiredAt).toEqual(originalExpiredAt);
|
||||
});
|
||||
|
||||
it('should preserve expiredAt when saving existing temporary message', async () => {
|
||||
// First save a temporary message
|
||||
mockCtx.interfaceConfig = { temporaryChatRetention: 24 };
|
||||
|
||||
mockCtx.isTemporary = true;
|
||||
const firstSave = await saveMessage(mockCtx, mockMessageData);
|
||||
const originalExpiredAt = firstSave?.expiredAt;
|
||||
|
||||
// Wait a bit to ensure time difference
|
||||
await new Promise((resolve) => setTimeout(resolve, 100));
|
||||
|
||||
// Save again with same messageId but different text
|
||||
const updatedData = { ...mockMessageData, text: 'Updated text' };
|
||||
const secondSave = await saveMessage(mockCtx, updatedData);
|
||||
|
||||
// Should update text but create new expiredAt
|
||||
expect(secondSave?.text).toBe('Updated text');
|
||||
expect(secondSave?.expiredAt).toBeDefined();
|
||||
expect(new Date(secondSave?.expiredAt ?? 0).getTime()).toBeGreaterThan(
|
||||
new Date(originalExpiredAt ?? 0).getTime(),
|
||||
);
|
||||
});
|
||||
|
||||
it('should handle bulk operations with temporary messages', async () => {
|
||||
// This test verifies bulkSaveMessages doesn't interfere with expiredAt
|
||||
const messages = [
|
||||
{
|
||||
messageId: 'bulk1',
|
||||
conversationId: uuidv4(),
|
||||
text: 'Bulk message 1',
|
||||
user: 'user123',
|
||||
expiredAt: new Date(Date.now() + 24 * 60 * 60 * 1000),
|
||||
},
|
||||
{
|
||||
messageId: 'bulk2',
|
||||
conversationId: uuidv4(),
|
||||
text: 'Bulk message 2',
|
||||
user: 'user123',
|
||||
expiredAt: null,
|
||||
},
|
||||
];
|
||||
|
||||
await bulkSaveMessages(messages);
|
||||
|
||||
const savedMessages = await Message.find({
|
||||
messageId: { $in: ['bulk1', 'bulk2'] },
|
||||
}).lean();
|
||||
|
||||
expect(savedMessages).toHaveLength(2);
|
||||
|
||||
const bulk1 = savedMessages.find((m) => m.messageId === 'bulk1');
|
||||
const bulk2 = savedMessages.find((m) => m.messageId === 'bulk2');
|
||||
|
||||
expect(bulk1?.expiredAt).toBeDefined();
|
||||
expect(bulk2?.expiredAt).toBeNull();
|
||||
});
|
||||
});
|
||||
|
||||
describe('Message cursor pagination', () => {
|
||||
/**
|
||||
* Helper to create messages with specific timestamps
|
||||
* Uses collection.insertOne to bypass Mongoose timestamps
|
||||
*/
|
||||
const createMessageWithTimestamp = async (
|
||||
index: number,
|
||||
conversationId: string,
|
||||
createdAt: Date,
|
||||
) => {
|
||||
const messageId = uuidv4();
|
||||
await Message.collection.insertOne({
|
||||
messageId,
|
||||
conversationId,
|
||||
user: 'user123',
|
||||
text: `Message ${index}`,
|
||||
isCreatedByUser: index % 2 === 0,
|
||||
createdAt,
|
||||
updatedAt: createdAt,
|
||||
});
|
||||
return Message.findOne({ messageId }).lean();
|
||||
};
|
||||
|
||||
/**
|
||||
* Simulates the pagination logic from api/server/routes/messages.js
|
||||
* This tests the exact query pattern used in the route
|
||||
*/
|
||||
const getMessagesByCursor = async ({
|
||||
conversationId,
|
||||
user,
|
||||
pageSize = 25,
|
||||
cursor = null as string | null,
|
||||
sortBy = 'createdAt',
|
||||
sortDirection = 'desc',
|
||||
}: {
|
||||
conversationId: string;
|
||||
user: string;
|
||||
pageSize?: number;
|
||||
cursor?: string | null;
|
||||
sortBy?: string;
|
||||
sortDirection?: string;
|
||||
}) => {
|
||||
const sortOrder = sortDirection === 'asc' ? 1 : -1;
|
||||
const sortField = ['createdAt', 'updatedAt'].includes(sortBy) ? sortBy : 'createdAt';
|
||||
const cursorOperator = sortDirection === 'asc' ? '$gt' : '$lt';
|
||||
|
||||
const filter: Record<string, unknown> = { conversationId, user };
|
||||
if (cursor) {
|
||||
filter[sortField] = { [cursorOperator]: new Date(cursor) };
|
||||
}
|
||||
|
||||
const messages = await Message.find(filter)
|
||||
.sort({ [sortField]: sortOrder })
|
||||
.limit(pageSize + 1)
|
||||
.lean();
|
||||
|
||||
let nextCursor: string | null = null;
|
||||
if (messages.length > pageSize) {
|
||||
messages.pop(); // Remove extra item used to detect next page
|
||||
// Create cursor from the last RETURNED item (not the popped one)
|
||||
nextCursor = (messages[messages.length - 1] as Record<string, unknown>)[
|
||||
sortField
|
||||
] as string;
|
||||
}
|
||||
|
||||
return { messages, nextCursor };
|
||||
};
|
||||
|
||||
it('should return messages for a conversation with pagination', async () => {
|
||||
const conversationId = uuidv4();
|
||||
const baseTime = new Date('2026-01-01T00:00:00.000Z');
|
||||
|
||||
// Create 30 messages to test pagination
|
||||
for (let i = 0; i < 30; i++) {
|
||||
const createdAt = new Date(baseTime.getTime() - i * 60000); // Each 1 minute apart
|
||||
await createMessageWithTimestamp(i, conversationId, createdAt);
|
||||
}
|
||||
|
||||
// Fetch first page (pageSize 25)
|
||||
const page1 = await getMessagesByCursor({
|
||||
conversationId,
|
||||
user: 'user123',
|
||||
pageSize: 25,
|
||||
});
|
||||
|
||||
expect(page1.messages).toHaveLength(25);
|
||||
expect(page1.nextCursor).toBeTruthy();
|
||||
|
||||
// Fetch second page using cursor
|
||||
const page2 = await getMessagesByCursor({
|
||||
conversationId,
|
||||
user: 'user123',
|
||||
pageSize: 25,
|
||||
cursor: page1.nextCursor,
|
||||
});
|
||||
|
||||
// Should get remaining 5 messages
|
||||
expect(page2.messages).toHaveLength(5);
|
||||
expect(page2.nextCursor).toBeNull();
|
||||
|
||||
// Verify no duplicates and no gaps
|
||||
const allMessageIds = [
|
||||
...page1.messages.map((m) => m.messageId),
|
||||
...page2.messages.map((m) => m.messageId),
|
||||
];
|
||||
const uniqueIds = new Set(allMessageIds);
|
||||
|
||||
expect(uniqueIds.size).toBe(30); // All 30 messages accounted for
|
||||
expect(allMessageIds.length).toBe(30); // No duplicates
|
||||
});
|
||||
|
||||
it('should not skip message at page boundary (item 26 bug fix)', async () => {
|
||||
const conversationId = uuidv4();
|
||||
const baseTime = new Date('2026-01-01T12:00:00.000Z');
|
||||
|
||||
// Create exactly 26 messages
|
||||
const messages: (IMessage | null)[] = [];
|
||||
for (let i = 0; i < 26; i++) {
|
||||
const createdAt = new Date(baseTime.getTime() - i * 60000);
|
||||
const msg = await createMessageWithTimestamp(i, conversationId, createdAt);
|
||||
messages.push(msg);
|
||||
}
|
||||
|
||||
// The 26th message (index 25) should be on page 2
|
||||
const item26 = messages[25];
|
||||
|
||||
// Fetch first page with pageSize 25
|
||||
const page1 = await getMessagesByCursor({
|
||||
conversationId,
|
||||
user: 'user123',
|
||||
pageSize: 25,
|
||||
});
|
||||
|
||||
expect(page1.messages).toHaveLength(25);
|
||||
expect(page1.nextCursor).toBeTruthy();
|
||||
|
||||
// Item 26 should NOT be in page 1
|
||||
const page1Ids = page1.messages.map((m) => m.messageId);
|
||||
expect(page1Ids).not.toContain(item26!.messageId);
|
||||
|
||||
// Fetch second page
|
||||
const page2 = await getMessagesByCursor({
|
||||
conversationId,
|
||||
user: 'user123',
|
||||
pageSize: 25,
|
||||
cursor: page1.nextCursor,
|
||||
});
|
||||
|
||||
// Item 26 MUST be in page 2 (this was the bug - it was being skipped)
|
||||
expect(page2.messages).toHaveLength(1);
|
||||
expect((page2.messages[0] as { messageId: string }).messageId).toBe(item26!.messageId);
|
||||
});
|
||||
|
||||
it('should sort by createdAt DESC by default', async () => {
|
||||
const conversationId = uuidv4();
|
||||
|
||||
// Create messages with specific timestamps
|
||||
const msg1 = await createMessageWithTimestamp(
|
||||
1,
|
||||
conversationId,
|
||||
new Date('2026-01-01T00:00:00.000Z'),
|
||||
);
|
||||
const msg2 = await createMessageWithTimestamp(
|
||||
2,
|
||||
conversationId,
|
||||
new Date('2026-01-02T00:00:00.000Z'),
|
||||
);
|
||||
const msg3 = await createMessageWithTimestamp(
|
||||
3,
|
||||
conversationId,
|
||||
new Date('2026-01-03T00:00:00.000Z'),
|
||||
);
|
||||
|
||||
const result = await getMessagesByCursor({
|
||||
conversationId,
|
||||
user: 'user123',
|
||||
});
|
||||
|
||||
// Should be sorted by createdAt DESC (newest first) by default
|
||||
expect(result?.messages).toHaveLength(3);
|
||||
expect((result?.messages[0] as { messageId: string }).messageId).toBe(msg3!.messageId);
|
||||
expect((result?.messages[1] as { messageId: string }).messageId).toBe(msg2!.messageId);
|
||||
expect((result?.messages[2] as { messageId: string }).messageId).toBe(msg1!.messageId);
|
||||
});
|
||||
|
||||
it('should support ascending sort direction', async () => {
|
||||
const conversationId = uuidv4();
|
||||
|
||||
const msg1 = await createMessageWithTimestamp(
|
||||
1,
|
||||
conversationId,
|
||||
new Date('2026-01-01T00:00:00.000Z'),
|
||||
);
|
||||
const msg2 = await createMessageWithTimestamp(
|
||||
2,
|
||||
conversationId,
|
||||
new Date('2026-01-02T00:00:00.000Z'),
|
||||
);
|
||||
|
||||
const result = await getMessagesByCursor({
|
||||
conversationId,
|
||||
user: 'user123',
|
||||
sortDirection: 'asc',
|
||||
});
|
||||
|
||||
// Should be sorted by createdAt ASC (oldest first)
|
||||
expect(result?.messages).toHaveLength(2);
|
||||
expect((result?.messages[0] as { messageId: string }).messageId).toBe(msg1!.messageId);
|
||||
expect((result?.messages[1] as { messageId: string }).messageId).toBe(msg2!.messageId);
|
||||
});
|
||||
|
||||
it('should handle empty conversation', async () => {
|
||||
const conversationId = uuidv4();
|
||||
|
||||
const result = await getMessagesByCursor({
|
||||
conversationId,
|
||||
user: 'user123',
|
||||
});
|
||||
|
||||
expect(result?.messages).toHaveLength(0);
|
||||
expect(result?.nextCursor).toBeNull();
|
||||
});
|
||||
|
||||
it('should only return messages for the specified user', async () => {
|
||||
const conversationId = uuidv4();
|
||||
const createdAt = new Date();
|
||||
|
||||
// Create a message for user123
|
||||
await Message.collection.insertOne({
|
||||
messageId: uuidv4(),
|
||||
conversationId,
|
||||
user: 'user123',
|
||||
text: 'User message',
|
||||
createdAt,
|
||||
updatedAt: createdAt,
|
||||
});
|
||||
|
||||
// Create a message for a different user
|
||||
await Message.collection.insertOne({
|
||||
messageId: uuidv4(),
|
||||
conversationId,
|
||||
user: 'otherUser',
|
||||
text: 'Other user message',
|
||||
createdAt,
|
||||
updatedAt: createdAt,
|
||||
});
|
||||
|
||||
const result = await getMessagesByCursor({
|
||||
conversationId,
|
||||
user: 'user123',
|
||||
});
|
||||
|
||||
// Should only return user123's message
|
||||
expect(result?.messages).toHaveLength(1);
|
||||
expect((result?.messages[0] as { user: string }).user).toBe('user123');
|
||||
});
|
||||
|
||||
it('should handle exactly pageSize number of messages (no next page)', async () => {
|
||||
const conversationId = uuidv4();
|
||||
const baseTime = new Date('2026-01-01T00:00:00.000Z');
|
||||
|
||||
// Create exactly 25 messages (equal to default pageSize)
|
||||
for (let i = 0; i < 25; i++) {
|
||||
const createdAt = new Date(baseTime.getTime() - i * 60000);
|
||||
await createMessageWithTimestamp(i, conversationId, createdAt);
|
||||
}
|
||||
|
||||
const result = await getMessagesByCursor({
|
||||
conversationId,
|
||||
user: 'user123',
|
||||
pageSize: 25,
|
||||
});
|
||||
|
||||
expect(result?.messages).toHaveLength(25);
|
||||
expect(result?.nextCursor).toBeNull(); // No next page
|
||||
});
|
||||
|
||||
it('should handle pageSize of 1', async () => {
|
||||
const conversationId = uuidv4();
|
||||
const baseTime = new Date('2026-01-01T00:00:00.000Z');
|
||||
|
||||
// Create 3 messages
|
||||
for (let i = 0; i < 3; i++) {
|
||||
const createdAt = new Date(baseTime.getTime() - i * 60000);
|
||||
await createMessageWithTimestamp(i, conversationId, createdAt);
|
||||
}
|
||||
|
||||
// Fetch with pageSize 1
|
||||
let cursor: string | null = null;
|
||||
const allMessages: unknown[] = [];
|
||||
|
||||
for (let page = 0; page < 5; page++) {
|
||||
const result = await getMessagesByCursor({
|
||||
conversationId,
|
||||
user: 'user123',
|
||||
pageSize: 1,
|
||||
cursor,
|
||||
});
|
||||
|
||||
allMessages.push(...(result?.messages ?? []));
|
||||
cursor = result?.nextCursor;
|
||||
|
||||
if (!cursor) {
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
// Should get all 3 messages without duplicates
|
||||
expect(allMessages).toHaveLength(3);
|
||||
const uniqueIds = new Set(allMessages.map((m) => (m as { messageId: string }).messageId));
|
||||
expect(uniqueIds.size).toBe(3);
|
||||
});
|
||||
|
||||
it('should handle messages with same createdAt timestamp', async () => {
|
||||
const conversationId = uuidv4();
|
||||
const sameTime = new Date('2026-01-01T12:00:00.000Z');
|
||||
|
||||
// Create multiple messages with the exact same timestamp
|
||||
const messages: (IMessage | null)[] = [];
|
||||
for (let i = 0; i < 5; i++) {
|
||||
const msg = await createMessageWithTimestamp(i, conversationId, sameTime);
|
||||
messages.push(msg);
|
||||
}
|
||||
|
||||
const result = await getMessagesByCursor({
|
||||
conversationId,
|
||||
user: 'user123',
|
||||
pageSize: 10,
|
||||
});
|
||||
|
||||
// All messages should be returned
|
||||
expect(result?.messages).toHaveLength(5);
|
||||
});
|
||||
});
|
||||
});
|
||||
399
packages/data-schemas/src/methods/message.ts
Normal file
399
packages/data-schemas/src/methods/message.ts
Normal file
|
|
@ -0,0 +1,399 @@
|
|||
import type { DeleteResult, FilterQuery, Model } from 'mongoose';
|
||||
import logger from '~/config/winston';
|
||||
import { createTempChatExpirationDate } from '~/utils/tempChatRetention';
|
||||
import type { AppConfig, IMessage } from '~/types';
|
||||
|
||||
/** Simple UUID v4 regex to replace zod validation */
|
||||
const UUID_REGEX = /^[0-9a-f]{8}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{12}$/i;
|
||||
|
||||
export interface MessageMethods {
|
||||
saveMessage(
|
||||
ctx: { userId: string; isTemporary?: boolean; interfaceConfig?: AppConfig['interfaceConfig'] },
|
||||
params: Partial<IMessage> & { newMessageId?: string },
|
||||
metadata?: { context?: string },
|
||||
): Promise<IMessage | null | undefined>;
|
||||
bulkSaveMessages(
|
||||
messages: Array<Partial<IMessage>>,
|
||||
overrideTimestamp?: boolean,
|
||||
): Promise<unknown>;
|
||||
recordMessage(params: {
|
||||
user: string;
|
||||
endpoint?: string;
|
||||
messageId: string;
|
||||
conversationId?: string;
|
||||
parentMessageId?: string;
|
||||
[key: string]: unknown;
|
||||
}): Promise<IMessage | null>;
|
||||
updateMessageText(userId: string, params: { messageId: string; text: string }): Promise<void>;
|
||||
updateMessage(
|
||||
userId: string,
|
||||
message: Partial<IMessage> & { newMessageId?: string },
|
||||
metadata?: { context?: string },
|
||||
): Promise<Partial<IMessage>>;
|
||||
deleteMessagesSince(
|
||||
userId: string,
|
||||
params: { messageId: string; conversationId: string },
|
||||
): Promise<DeleteResult>;
|
||||
getMessages(filter: FilterQuery<IMessage>, select?: string): Promise<IMessage[]>;
|
||||
getMessage(params: { user: string; messageId: string }): Promise<IMessage | null>;
|
||||
getMessagesByCursor(
|
||||
filter: FilterQuery<IMessage>,
|
||||
options?: {
|
||||
sortField?: string;
|
||||
sortOrder?: 1 | -1;
|
||||
limit?: number;
|
||||
cursor?: string | null;
|
||||
},
|
||||
): Promise<{ messages: IMessage[]; nextCursor: string | null }>;
|
||||
searchMessages(
|
||||
query: string,
|
||||
searchOptions: Partial<IMessage>,
|
||||
hydrate?: boolean,
|
||||
): Promise<unknown>;
|
||||
deleteMessages(filter: FilterQuery<IMessage>): Promise<DeleteResult>;
|
||||
}
|
||||
|
||||
export function createMessageMethods(mongoose: typeof import('mongoose')): MessageMethods {
|
||||
/**
|
||||
* Saves a message in the database.
|
||||
*/
|
||||
async function saveMessage(
|
||||
{
|
||||
userId,
|
||||
isTemporary,
|
||||
interfaceConfig,
|
||||
}: {
|
||||
userId: string;
|
||||
isTemporary?: boolean;
|
||||
interfaceConfig?: AppConfig['interfaceConfig'];
|
||||
},
|
||||
params: Partial<IMessage> & { newMessageId?: string },
|
||||
metadata?: { context?: string },
|
||||
) {
|
||||
if (!userId) {
|
||||
throw new Error('User not authenticated');
|
||||
}
|
||||
|
||||
const conversationId = params.conversationId as string | undefined;
|
||||
if (!conversationId || !UUID_REGEX.test(conversationId)) {
|
||||
logger.warn(`Invalid conversation ID: ${conversationId}`);
|
||||
logger.info(`---\`saveMessage\` context: ${metadata?.context}`);
|
||||
logger.info(`---Invalid conversation ID Params: ${JSON.stringify(params, null, 2)}`);
|
||||
return;
|
||||
}
|
||||
|
||||
try {
|
||||
const Message = mongoose.models.Message as Model<IMessage>;
|
||||
const update: Record<string, unknown> = {
|
||||
...params,
|
||||
user: userId,
|
||||
messageId: params.newMessageId || params.messageId,
|
||||
};
|
||||
|
||||
if (isTemporary) {
|
||||
try {
|
||||
update.expiredAt = createTempChatExpirationDate(interfaceConfig);
|
||||
} catch (err) {
|
||||
logger.error('Error creating temporary chat expiration date:', err);
|
||||
logger.info(`---\`saveMessage\` context: ${metadata?.context}`);
|
||||
update.expiredAt = null;
|
||||
}
|
||||
} else {
|
||||
update.expiredAt = null;
|
||||
}
|
||||
|
||||
if (update.tokenCount != null && isNaN(update.tokenCount as number)) {
|
||||
logger.warn(
|
||||
`Resetting invalid \`tokenCount\` for message \`${params.messageId}\`: ${update.tokenCount}`,
|
||||
);
|
||||
logger.info(`---\`saveMessage\` context: ${metadata?.context}`);
|
||||
update.tokenCount = 0;
|
||||
}
|
||||
const message = await Message.findOneAndUpdate(
|
||||
{ messageId: params.messageId, user: userId },
|
||||
update,
|
||||
{ upsert: true, new: true },
|
||||
);
|
||||
|
||||
return message.toObject();
|
||||
} catch (err: unknown) {
|
||||
logger.error('Error saving message:', err);
|
||||
logger.info(`---\`saveMessage\` context: ${metadata?.context}`);
|
||||
|
||||
const mongoErr = err as { code?: number; message?: string };
|
||||
if (mongoErr.code === 11000 && mongoErr.message?.includes('duplicate key error')) {
|
||||
logger.warn(`Duplicate messageId detected: ${params.messageId}. Continuing execution.`);
|
||||
|
||||
try {
|
||||
const Message = mongoose.models.Message as Model<IMessage>;
|
||||
const existingMessage = await Message.findOne({
|
||||
messageId: params.messageId,
|
||||
user: userId,
|
||||
});
|
||||
|
||||
if (existingMessage) {
|
||||
return existingMessage.toObject();
|
||||
}
|
||||
|
||||
return undefined;
|
||||
} catch (findError) {
|
||||
logger.warn(
|
||||
`Could not retrieve existing message with ID ${params.messageId}: ${(findError as Error).message}`,
|
||||
);
|
||||
return undefined;
|
||||
}
|
||||
}
|
||||
|
||||
throw err;
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Saves multiple messages in bulk.
|
||||
*/
|
||||
async function bulkSaveMessages(
|
||||
messages: Array<Record<string, unknown>>,
|
||||
overrideTimestamp = false,
|
||||
) {
|
||||
try {
|
||||
const Message = mongoose.models.Message as Model<IMessage>;
|
||||
const bulkOps = messages.map((message) => ({
|
||||
updateOne: {
|
||||
filter: { messageId: message.messageId },
|
||||
update: message,
|
||||
timestamps: !overrideTimestamp,
|
||||
upsert: true,
|
||||
},
|
||||
}));
|
||||
const result = await Message.bulkWrite(bulkOps);
|
||||
return result;
|
||||
} catch (err) {
|
||||
logger.error('Error saving messages in bulk:', err);
|
||||
throw err;
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Records a message in the database (no UUID validation).
|
||||
*/
|
||||
async function recordMessage({
|
||||
user,
|
||||
endpoint,
|
||||
messageId,
|
||||
conversationId,
|
||||
parentMessageId,
|
||||
...rest
|
||||
}: {
|
||||
user: string;
|
||||
endpoint?: string;
|
||||
messageId: string;
|
||||
conversationId?: string;
|
||||
parentMessageId?: string;
|
||||
[key: string]: unknown;
|
||||
}) {
|
||||
try {
|
||||
const Message = mongoose.models.Message as Model<IMessage>;
|
||||
const message = {
|
||||
user,
|
||||
endpoint,
|
||||
messageId,
|
||||
conversationId,
|
||||
parentMessageId,
|
||||
...rest,
|
||||
};
|
||||
|
||||
return await Message.findOneAndUpdate({ user, messageId }, message, {
|
||||
upsert: true,
|
||||
new: true,
|
||||
});
|
||||
} catch (err) {
|
||||
logger.error('Error recording message:', err);
|
||||
throw err;
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Updates the text of a message.
|
||||
*/
|
||||
async function updateMessageText(
|
||||
userId: string,
|
||||
{ messageId, text }: { messageId: string; text: string },
|
||||
) {
|
||||
try {
|
||||
const Message = mongoose.models.Message as Model<IMessage>;
|
||||
await Message.updateOne({ messageId, user: userId }, { text });
|
||||
} catch (err) {
|
||||
logger.error('Error updating message text:', err);
|
||||
throw err;
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Updates a message and returns sanitized fields.
|
||||
*/
|
||||
async function updateMessage(
|
||||
userId: string,
|
||||
message: { messageId: string; [key: string]: unknown },
|
||||
metadata?: { context?: string },
|
||||
) {
|
||||
try {
|
||||
const Message = mongoose.models.Message as Model<IMessage>;
|
||||
const { messageId, ...update } = message;
|
||||
const updatedMessage = await Message.findOneAndUpdate({ messageId, user: userId }, update, {
|
||||
new: true,
|
||||
});
|
||||
|
||||
if (!updatedMessage) {
|
||||
throw new Error('Message not found or user not authorized.');
|
||||
}
|
||||
|
||||
return {
|
||||
messageId: updatedMessage.messageId,
|
||||
conversationId: updatedMessage.conversationId,
|
||||
parentMessageId: updatedMessage.parentMessageId,
|
||||
sender: updatedMessage.sender,
|
||||
text: updatedMessage.text,
|
||||
isCreatedByUser: updatedMessage.isCreatedByUser,
|
||||
tokenCount: updatedMessage.tokenCount,
|
||||
feedback: updatedMessage.feedback,
|
||||
};
|
||||
} catch (err) {
|
||||
logger.error('Error updating message:', err);
|
||||
if (metadata?.context) {
|
||||
logger.info(`---\`updateMessage\` context: ${metadata.context}`);
|
||||
}
|
||||
throw err;
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Deletes messages in a conversation since a specific message.
|
||||
*/
|
||||
async function deleteMessagesSince(
|
||||
userId: string,
|
||||
{ messageId, conversationId }: { messageId: string; conversationId: string },
|
||||
) {
|
||||
try {
|
||||
const Message = mongoose.models.Message as Model<IMessage>;
|
||||
const message = await Message.findOne({ messageId, user: userId }).lean();
|
||||
|
||||
if (message) {
|
||||
const query = Message.find({ conversationId, user: userId });
|
||||
return await query.deleteMany({
|
||||
createdAt: { $gt: message.createdAt },
|
||||
});
|
||||
}
|
||||
return undefined;
|
||||
} catch (err) {
|
||||
logger.error('Error deleting messages:', err);
|
||||
throw err;
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Retrieves messages from the database.
|
||||
*/
|
||||
async function getMessages(filter: FilterQuery<IMessage>, select?: string) {
|
||||
try {
|
||||
const Message = mongoose.models.Message as Model<IMessage>;
|
||||
if (select) {
|
||||
return await Message.find(filter).select(select).sort({ createdAt: 1 }).lean();
|
||||
}
|
||||
|
||||
return await Message.find(filter).sort({ createdAt: 1 }).lean();
|
||||
} catch (err) {
|
||||
logger.error('Error getting messages:', err);
|
||||
throw err;
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Retrieves a single message from the database.
|
||||
*/
|
||||
async function getMessage({ user, messageId }: { user: string; messageId: string }) {
|
||||
try {
|
||||
const Message = mongoose.models.Message as Model<IMessage>;
|
||||
return await Message.findOne({ user, messageId }).lean();
|
||||
} catch (err) {
|
||||
logger.error('Error getting message:', err);
|
||||
throw err;
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Deletes messages from the database.
|
||||
*/
|
||||
async function deleteMessages(filter: FilterQuery<IMessage>) {
|
||||
try {
|
||||
const Message = mongoose.models.Message as Model<IMessage>;
|
||||
return await Message.deleteMany(filter);
|
||||
} catch (err) {
|
||||
logger.error('Error deleting messages:', err);
|
||||
throw err;
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Retrieves paginated messages with custom sorting and cursor support.
|
||||
*/
|
||||
async function getMessagesByCursor(
|
||||
filter: FilterQuery<IMessage>,
|
||||
options: {
|
||||
sortField?: string;
|
||||
sortOrder?: 1 | -1;
|
||||
limit?: number;
|
||||
cursor?: string | null;
|
||||
} = {},
|
||||
) {
|
||||
const Message = mongoose.models.Message as Model<IMessage>;
|
||||
const { sortField = 'createdAt', sortOrder = -1, limit = 25, cursor } = options;
|
||||
const queryFilter = { ...filter };
|
||||
if (cursor) {
|
||||
queryFilter[sortField] = sortOrder === 1 ? { $gt: cursor } : { $lt: cursor };
|
||||
}
|
||||
const messages = await Message.find(queryFilter)
|
||||
.sort({ [sortField]: sortOrder })
|
||||
.limit(limit + 1)
|
||||
.lean();
|
||||
|
||||
let nextCursor: string | null = null;
|
||||
if (messages.length > limit) {
|
||||
messages.pop();
|
||||
const last = messages[messages.length - 1] as Record<string, unknown>;
|
||||
nextCursor = String(last[sortField] ?? '');
|
||||
}
|
||||
return { messages, nextCursor };
|
||||
}
|
||||
|
||||
/**
|
||||
* Performs a MeiliSearch query on the Message collection.
|
||||
* Requires the meilisearch plugin to be registered on the Message model.
|
||||
*/
|
||||
async function searchMessages(
|
||||
query: string,
|
||||
searchOptions: Record<string, unknown>,
|
||||
hydrate?: boolean,
|
||||
) {
|
||||
const Message = mongoose.models.Message as Model<IMessage> & {
|
||||
meiliSearch?: (q: string, opts: Record<string, unknown>, h?: boolean) => Promise<unknown>;
|
||||
};
|
||||
if (typeof Message.meiliSearch !== 'function') {
|
||||
throw new Error('MeiliSearch plugin not registered on Message model');
|
||||
}
|
||||
return Message.meiliSearch(query, searchOptions, hydrate);
|
||||
}
|
||||
|
||||
return {
|
||||
saveMessage,
|
||||
bulkSaveMessages,
|
||||
recordMessage,
|
||||
updateMessageText,
|
||||
updateMessage,
|
||||
deleteMessagesSince,
|
||||
getMessages,
|
||||
getMessage,
|
||||
getMessagesByCursor,
|
||||
searchMessages,
|
||||
deleteMessages,
|
||||
};
|
||||
}
|
||||
132
packages/data-schemas/src/methods/preset.ts
Normal file
132
packages/data-schemas/src/methods/preset.ts
Normal file
|
|
@ -0,0 +1,132 @@
|
|||
import type { Model } from 'mongoose';
|
||||
import logger from '~/config/winston';
|
||||
|
||||
interface IPreset {
|
||||
user?: string;
|
||||
presetId?: string;
|
||||
order?: number;
|
||||
defaultPreset?: boolean;
|
||||
tools?: (string | { pluginKey?: string })[];
|
||||
updatedAt?: Date;
|
||||
[key: string]: unknown;
|
||||
}
|
||||
|
||||
export function createPresetMethods(mongoose: typeof import('mongoose')) {
|
||||
/**
|
||||
* Retrieves a single preset by user and presetId.
|
||||
*/
|
||||
async function getPreset(user: string, presetId: string) {
|
||||
try {
|
||||
const Preset = mongoose.models.Preset as Model<IPreset>;
|
||||
return await Preset.findOne({ user, presetId }).lean();
|
||||
} catch (error) {
|
||||
logger.error('[getPreset] Error getting single preset', error);
|
||||
return { message: 'Error getting single preset' };
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Retrieves all presets for a user, sorted by order then updatedAt.
|
||||
*/
|
||||
async function getPresets(user: string, filter: Record<string, unknown> = {}) {
|
||||
try {
|
||||
const Preset = mongoose.models.Preset as Model<IPreset>;
|
||||
const presets = await Preset.find({ ...filter, user }).lean();
|
||||
const defaultValue = 10000;
|
||||
|
||||
presets.sort((a, b) => {
|
||||
const orderA = a.order !== undefined ? a.order : defaultValue;
|
||||
const orderB = b.order !== undefined ? b.order : defaultValue;
|
||||
|
||||
if (orderA !== orderB) {
|
||||
return orderA - orderB;
|
||||
}
|
||||
|
||||
return new Date(b.updatedAt ?? 0).getTime() - new Date(a.updatedAt ?? 0).getTime();
|
||||
});
|
||||
|
||||
return presets;
|
||||
} catch (error) {
|
||||
logger.error('[getPresets] Error getting presets', error);
|
||||
return { message: 'Error retrieving presets' };
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Saves a preset. Handles default preset logic and tool normalization.
|
||||
*/
|
||||
async function savePreset(
|
||||
user: string,
|
||||
{
|
||||
presetId,
|
||||
newPresetId,
|
||||
defaultPreset,
|
||||
...preset
|
||||
}: {
|
||||
presetId?: string;
|
||||
newPresetId?: string;
|
||||
defaultPreset?: boolean;
|
||||
[key: string]: unknown;
|
||||
},
|
||||
) {
|
||||
try {
|
||||
const Preset = mongoose.models.Preset as Model<IPreset>;
|
||||
const setter: Record<string, unknown> = { $set: {} };
|
||||
const { user: _unusedUser, ...cleanPreset } = preset;
|
||||
const update: Record<string, unknown> = { presetId, ...cleanPreset };
|
||||
if (preset.tools && Array.isArray(preset.tools)) {
|
||||
update.tools =
|
||||
(preset.tools as Array<string | { pluginKey?: string }>)
|
||||
.map((tool) => (typeof tool === 'object' && tool?.pluginKey ? tool.pluginKey : tool))
|
||||
.filter((toolName) => typeof toolName === 'string') ?? [];
|
||||
}
|
||||
if (newPresetId) {
|
||||
update.presetId = newPresetId;
|
||||
}
|
||||
|
||||
if (defaultPreset) {
|
||||
update.defaultPreset = defaultPreset;
|
||||
update.order = 0;
|
||||
|
||||
const currentDefault = await Preset.findOne({ defaultPreset: true, user });
|
||||
|
||||
if (currentDefault && currentDefault.presetId !== presetId) {
|
||||
await Preset.findByIdAndUpdate(currentDefault._id, {
|
||||
$unset: { defaultPreset: '', order: '' },
|
||||
});
|
||||
}
|
||||
} else if (defaultPreset === false) {
|
||||
update.defaultPreset = undefined;
|
||||
update.order = undefined;
|
||||
setter['$unset'] = { defaultPreset: '', order: '' };
|
||||
}
|
||||
|
||||
setter.$set = update;
|
||||
return await Preset.findOneAndUpdate({ presetId, user }, setter, {
|
||||
new: true,
|
||||
upsert: true,
|
||||
});
|
||||
} catch (error) {
|
||||
logger.error('[savePreset] Error saving preset', error);
|
||||
return { message: 'Error saving preset' };
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Deletes presets matching the given filter for a user.
|
||||
*/
|
||||
async function deletePresets(user: string, filter: Record<string, unknown> = {}) {
|
||||
const Preset = mongoose.models.Preset as Model<IPreset>;
|
||||
const deleteCount = await Preset.deleteMany({ ...filter, user });
|
||||
return deleteCount;
|
||||
}
|
||||
|
||||
return {
|
||||
getPreset,
|
||||
getPresets,
|
||||
savePreset,
|
||||
deletePresets,
|
||||
};
|
||||
}
|
||||
|
||||
export type PresetMethods = ReturnType<typeof createPresetMethods>;
|
||||
627
packages/data-schemas/src/methods/prompt.spec.ts
Normal file
627
packages/data-schemas/src/methods/prompt.spec.ts
Normal file
|
|
@ -0,0 +1,627 @@
|
|||
import mongoose from 'mongoose';
|
||||
import { ObjectId } from 'mongodb';
|
||||
import { MongoMemoryServer } from 'mongodb-memory-server';
|
||||
import {
|
||||
SystemRoles,
|
||||
ResourceType,
|
||||
AccessRoleIds,
|
||||
PrincipalType,
|
||||
PermissionBits,
|
||||
} from 'librechat-data-provider';
|
||||
import type { IPromptGroup, AccessRole as TAccessRole, AclEntry as TAclEntry } from '..';
|
||||
import { createAclEntryMethods } from './aclEntry';
|
||||
import { logger, createModels } from '..';
|
||||
import { createMethods } from './index';
|
||||
|
||||
// Disable console for tests
|
||||
logger.silent = true;
|
||||
|
||||
/** Lean user object from .toObject() */
|
||||
type LeanUser = {
|
||||
_id: mongoose.Types.ObjectId | string;
|
||||
name?: string;
|
||||
email: string;
|
||||
role?: string;
|
||||
};
|
||||
|
||||
/** Lean group object from .toObject() */
|
||||
type LeanGroup = {
|
||||
_id: mongoose.Types.ObjectId | string;
|
||||
name: string;
|
||||
description?: string;
|
||||
};
|
||||
|
||||
/** Lean access role object from .toObject() / .lean() */
|
||||
type LeanAccessRole = TAccessRole & { _id: mongoose.Types.ObjectId | string };
|
||||
|
||||
/** Lean ACL entry from .lean() */
|
||||
type LeanAclEntry = TAclEntry & { _id: mongoose.Types.ObjectId | string };
|
||||
|
||||
/** Lean prompt group from .toObject() */
|
||||
type LeanPromptGroup = IPromptGroup & { _id: mongoose.Types.ObjectId | string };
|
||||
|
||||
let Prompt: mongoose.Model<unknown>;
|
||||
let PromptGroup: mongoose.Model<unknown>;
|
||||
let AclEntry: mongoose.Model<unknown>;
|
||||
let AccessRole: mongoose.Model<unknown>;
|
||||
let User: mongoose.Model<unknown>;
|
||||
let Group: mongoose.Model<unknown>;
|
||||
let methods: ReturnType<typeof createMethods>;
|
||||
let aclMethods: ReturnType<typeof createAclEntryMethods>;
|
||||
let testUsers: Record<string, LeanUser>;
|
||||
let testGroups: Record<string, LeanGroup>;
|
||||
let testRoles: Record<string, LeanAccessRole>;
|
||||
|
||||
let mongoServer: MongoMemoryServer;
|
||||
|
||||
beforeAll(async () => {
|
||||
mongoServer = await MongoMemoryServer.create();
|
||||
const mongoUri = mongoServer.getUri();
|
||||
await mongoose.connect(mongoUri);
|
||||
|
||||
createModels(mongoose);
|
||||
Prompt = mongoose.models.Prompt;
|
||||
PromptGroup = mongoose.models.PromptGroup;
|
||||
AclEntry = mongoose.models.AclEntry;
|
||||
AccessRole = mongoose.models.AccessRole;
|
||||
User = mongoose.models.User;
|
||||
Group = mongoose.models.Group;
|
||||
|
||||
methods = createMethods(mongoose, {
|
||||
removeAllPermissions: async ({ resourceType, resourceId }) => {
|
||||
await AclEntry.deleteMany({ resourceType, resourceId });
|
||||
},
|
||||
});
|
||||
aclMethods = createAclEntryMethods(mongoose);
|
||||
|
||||
await setupTestData();
|
||||
});
|
||||
|
||||
afterAll(async () => {
|
||||
await mongoose.disconnect();
|
||||
await mongoServer.stop();
|
||||
});
|
||||
|
||||
async function setupTestData() {
|
||||
testRoles = {
|
||||
viewer: (
|
||||
await AccessRole.create({
|
||||
accessRoleId: AccessRoleIds.PROMPTGROUP_VIEWER,
|
||||
name: 'Viewer',
|
||||
description: 'Can view promptGroups',
|
||||
resourceType: ResourceType.PROMPTGROUP,
|
||||
permBits: PermissionBits.VIEW,
|
||||
})
|
||||
).toObject() as unknown as LeanAccessRole,
|
||||
editor: (
|
||||
await AccessRole.create({
|
||||
accessRoleId: AccessRoleIds.PROMPTGROUP_EDITOR,
|
||||
name: 'Editor',
|
||||
description: 'Can view and edit promptGroups',
|
||||
resourceType: ResourceType.PROMPTGROUP,
|
||||
permBits: PermissionBits.VIEW | PermissionBits.EDIT,
|
||||
})
|
||||
).toObject() as unknown as LeanAccessRole,
|
||||
owner: (
|
||||
await AccessRole.create({
|
||||
accessRoleId: AccessRoleIds.PROMPTGROUP_OWNER,
|
||||
name: 'Owner',
|
||||
description: 'Full control over promptGroups',
|
||||
resourceType: ResourceType.PROMPTGROUP,
|
||||
permBits:
|
||||
PermissionBits.VIEW | PermissionBits.EDIT | PermissionBits.DELETE | PermissionBits.SHARE,
|
||||
})
|
||||
).toObject() as unknown as LeanAccessRole,
|
||||
};
|
||||
|
||||
testUsers = {
|
||||
owner: (
|
||||
await User.create({
|
||||
name: 'Prompt Owner',
|
||||
email: 'owner@example.com',
|
||||
role: SystemRoles.USER,
|
||||
})
|
||||
).toObject() as unknown as LeanUser,
|
||||
editor: (
|
||||
await User.create({
|
||||
name: 'Prompt Editor',
|
||||
email: 'editor@example.com',
|
||||
role: SystemRoles.USER,
|
||||
})
|
||||
).toObject() as unknown as LeanUser,
|
||||
viewer: (
|
||||
await User.create({
|
||||
name: 'Prompt Viewer',
|
||||
email: 'viewer@example.com',
|
||||
role: SystemRoles.USER,
|
||||
})
|
||||
).toObject() as unknown as LeanUser,
|
||||
admin: (
|
||||
await User.create({
|
||||
name: 'Admin User',
|
||||
email: 'admin@example.com',
|
||||
role: SystemRoles.ADMIN,
|
||||
})
|
||||
).toObject() as unknown as LeanUser,
|
||||
noAccess: (
|
||||
await User.create({
|
||||
name: 'No Access User',
|
||||
email: 'noaccess@example.com',
|
||||
role: SystemRoles.USER,
|
||||
})
|
||||
).toObject() as unknown as LeanUser,
|
||||
};
|
||||
|
||||
testGroups = {
|
||||
editors: (
|
||||
await Group.create({
|
||||
name: 'Prompt Editors',
|
||||
description: 'Group with editor access',
|
||||
})
|
||||
).toObject() as unknown as LeanGroup,
|
||||
viewers: (
|
||||
await Group.create({
|
||||
name: 'Prompt Viewers',
|
||||
description: 'Group with viewer access',
|
||||
})
|
||||
).toObject() as unknown as LeanGroup,
|
||||
};
|
||||
}
|
||||
|
||||
/** Helper: grant permission via direct AclEntry.create */
|
||||
async function grantPermission(params: {
|
||||
principalType: string;
|
||||
principalId: mongoose.Types.ObjectId | string;
|
||||
resourceType: string;
|
||||
resourceId: mongoose.Types.ObjectId | string;
|
||||
accessRoleId: string;
|
||||
grantedBy: mongoose.Types.ObjectId | string;
|
||||
}) {
|
||||
const role = (await AccessRole.findOne({
|
||||
accessRoleId: params.accessRoleId,
|
||||
}).lean()) as LeanAccessRole | null;
|
||||
if (!role) {
|
||||
throw new Error(`AccessRole ${params.accessRoleId} not found`);
|
||||
}
|
||||
return aclMethods.grantPermission(
|
||||
params.principalType,
|
||||
params.principalId,
|
||||
params.resourceType,
|
||||
params.resourceId,
|
||||
role.permBits,
|
||||
params.grantedBy,
|
||||
undefined,
|
||||
role._id,
|
||||
);
|
||||
}
|
||||
|
||||
/** Helper: check permission via getUserPrincipals + hasPermission */
|
||||
async function checkPermission(params: {
|
||||
userId: mongoose.Types.ObjectId | string;
|
||||
resourceType: string;
|
||||
resourceId: mongoose.Types.ObjectId | string;
|
||||
requiredPermission: number;
|
||||
includePublic?: boolean;
|
||||
}) {
|
||||
// getUserPrincipals already includes user, role, groups, and public
|
||||
const principals = await methods.getUserPrincipals({
|
||||
userId: params.userId,
|
||||
});
|
||||
|
||||
// If not including public, filter it out
|
||||
const filteredPrincipals = params.includePublic
|
||||
? principals
|
||||
: principals.filter((p) => p.principalType !== PrincipalType.PUBLIC);
|
||||
|
||||
return aclMethods.hasPermission(
|
||||
filteredPrincipals,
|
||||
params.resourceType,
|
||||
params.resourceId,
|
||||
params.requiredPermission,
|
||||
);
|
||||
}
|
||||
|
||||
describe('Prompt ACL Permissions', () => {
|
||||
describe('Creating Prompts with Permissions', () => {
|
||||
it('should grant owner permissions when creating a prompt', async () => {
|
||||
const testGroup = (
|
||||
await PromptGroup.create({
|
||||
name: 'Test Group',
|
||||
category: 'testing',
|
||||
author: testUsers.owner._id,
|
||||
authorName: testUsers.owner.name,
|
||||
productionId: new mongoose.Types.ObjectId(),
|
||||
})
|
||||
).toObject() as unknown as LeanPromptGroup;
|
||||
|
||||
const promptData = {
|
||||
prompt: {
|
||||
prompt: 'Test prompt content',
|
||||
name: 'Test Prompt',
|
||||
type: 'text',
|
||||
groupId: testGroup._id,
|
||||
},
|
||||
author: testUsers.owner._id,
|
||||
};
|
||||
|
||||
await methods.savePrompt(promptData);
|
||||
|
||||
// Grant owner permission
|
||||
await grantPermission({
|
||||
principalType: PrincipalType.USER,
|
||||
principalId: testUsers.owner._id,
|
||||
resourceType: ResourceType.PROMPTGROUP,
|
||||
resourceId: testGroup._id,
|
||||
accessRoleId: AccessRoleIds.PROMPTGROUP_OWNER,
|
||||
grantedBy: testUsers.owner._id,
|
||||
});
|
||||
|
||||
// Check ACL entry
|
||||
const aclEntry = (await AclEntry.findOne({
|
||||
resourceType: ResourceType.PROMPTGROUP,
|
||||
resourceId: testGroup._id,
|
||||
principalType: PrincipalType.USER,
|
||||
principalId: testUsers.owner._id,
|
||||
}).lean()) as LeanAclEntry | null;
|
||||
|
||||
expect(aclEntry).toBeTruthy();
|
||||
expect(aclEntry!.permBits).toBe(testRoles.owner.permBits);
|
||||
});
|
||||
});
|
||||
|
||||
describe('Accessing Prompts', () => {
|
||||
let testPromptGroup: LeanPromptGroup;
|
||||
|
||||
beforeEach(async () => {
|
||||
testPromptGroup = (
|
||||
await PromptGroup.create({
|
||||
name: 'Test Group',
|
||||
author: testUsers.owner._id,
|
||||
authorName: testUsers.owner.name,
|
||||
productionId: new ObjectId(),
|
||||
})
|
||||
).toObject() as unknown as LeanPromptGroup;
|
||||
|
||||
await Prompt.create({
|
||||
prompt: 'Test prompt for access control',
|
||||
name: 'Access Test Prompt',
|
||||
author: testUsers.owner._id,
|
||||
groupId: testPromptGroup._id,
|
||||
type: 'text',
|
||||
});
|
||||
|
||||
await grantPermission({
|
||||
principalType: PrincipalType.USER,
|
||||
principalId: testUsers.owner._id,
|
||||
resourceType: ResourceType.PROMPTGROUP,
|
||||
resourceId: testPromptGroup._id,
|
||||
accessRoleId: AccessRoleIds.PROMPTGROUP_OWNER,
|
||||
grantedBy: testUsers.owner._id,
|
||||
});
|
||||
});
|
||||
|
||||
afterEach(async () => {
|
||||
await Prompt.deleteMany({});
|
||||
await PromptGroup.deleteMany({});
|
||||
await AclEntry.deleteMany({});
|
||||
});
|
||||
|
||||
it('owner should have full access to their prompt', async () => {
|
||||
const hasAccess = await checkPermission({
|
||||
userId: testUsers.owner._id,
|
||||
resourceType: ResourceType.PROMPTGROUP,
|
||||
resourceId: testPromptGroup._id,
|
||||
requiredPermission: PermissionBits.VIEW,
|
||||
});
|
||||
|
||||
expect(hasAccess).toBe(true);
|
||||
|
||||
const canEdit = await checkPermission({
|
||||
userId: testUsers.owner._id,
|
||||
resourceType: ResourceType.PROMPTGROUP,
|
||||
resourceId: testPromptGroup._id,
|
||||
requiredPermission: PermissionBits.EDIT,
|
||||
});
|
||||
|
||||
expect(canEdit).toBe(true);
|
||||
});
|
||||
|
||||
it('user with viewer role should only have view access', async () => {
|
||||
await grantPermission({
|
||||
principalType: PrincipalType.USER,
|
||||
principalId: testUsers.viewer._id,
|
||||
resourceType: ResourceType.PROMPTGROUP,
|
||||
resourceId: testPromptGroup._id,
|
||||
accessRoleId: AccessRoleIds.PROMPTGROUP_VIEWER,
|
||||
grantedBy: testUsers.owner._id,
|
||||
});
|
||||
|
||||
const canView = await checkPermission({
|
||||
userId: testUsers.viewer._id,
|
||||
resourceType: ResourceType.PROMPTGROUP,
|
||||
resourceId: testPromptGroup._id,
|
||||
requiredPermission: PermissionBits.VIEW,
|
||||
});
|
||||
|
||||
const canEdit = await checkPermission({
|
||||
userId: testUsers.viewer._id,
|
||||
resourceType: ResourceType.PROMPTGROUP,
|
||||
resourceId: testPromptGroup._id,
|
||||
requiredPermission: PermissionBits.EDIT,
|
||||
});
|
||||
|
||||
expect(canView).toBe(true);
|
||||
expect(canEdit).toBe(false);
|
||||
});
|
||||
|
||||
it('user without permissions should have no access', async () => {
|
||||
const hasAccess = await checkPermission({
|
||||
userId: testUsers.noAccess._id,
|
||||
resourceType: ResourceType.PROMPTGROUP,
|
||||
resourceId: testPromptGroup._id,
|
||||
requiredPermission: PermissionBits.VIEW,
|
||||
});
|
||||
|
||||
expect(hasAccess).toBe(false);
|
||||
});
|
||||
|
||||
it('admin should have access regardless of permissions', async () => {
|
||||
// Admin users should work through normal permission system
|
||||
// The middleware layer handles admin bypass, not the permission service
|
||||
const hasAccess = await checkPermission({
|
||||
userId: testUsers.admin._id,
|
||||
resourceType: ResourceType.PROMPTGROUP,
|
||||
resourceId: testPromptGroup._id,
|
||||
requiredPermission: PermissionBits.VIEW,
|
||||
});
|
||||
|
||||
// Without explicit permissions, even admin won't have access at this layer
|
||||
expect(hasAccess).toBe(false);
|
||||
|
||||
// The actual admin bypass happens in the middleware layer
|
||||
});
|
||||
});
|
||||
|
||||
describe('Group-based Access', () => {
|
||||
afterEach(async () => {
|
||||
await Prompt.deleteMany({});
|
||||
await AclEntry.deleteMany({});
|
||||
await User.updateMany({}, { $set: { groups: [] } });
|
||||
});
|
||||
|
||||
it('group members should inherit group permissions', async () => {
|
||||
const testPromptGroup = (
|
||||
await PromptGroup.create({
|
||||
name: 'Group Test Group',
|
||||
author: testUsers.owner._id,
|
||||
authorName: testUsers.owner.name,
|
||||
productionId: new ObjectId(),
|
||||
})
|
||||
).toObject() as unknown as LeanPromptGroup;
|
||||
|
||||
// Add user to group
|
||||
await methods.addUserToGroup(testUsers.editor._id, testGroups.editors._id);
|
||||
|
||||
await methods.savePrompt({
|
||||
author: testUsers.owner._id,
|
||||
prompt: {
|
||||
prompt: 'Group test prompt',
|
||||
name: 'Group Test',
|
||||
groupId: testPromptGroup._id,
|
||||
type: 'text',
|
||||
},
|
||||
});
|
||||
|
||||
// Grant edit permissions to the group
|
||||
await grantPermission({
|
||||
principalType: PrincipalType.GROUP,
|
||||
principalId: testGroups.editors._id,
|
||||
resourceType: ResourceType.PROMPTGROUP,
|
||||
resourceId: testPromptGroup._id,
|
||||
accessRoleId: AccessRoleIds.PROMPTGROUP_EDITOR,
|
||||
grantedBy: testUsers.owner._id,
|
||||
});
|
||||
|
||||
// Check if group member has access
|
||||
const hasAccess = await checkPermission({
|
||||
userId: testUsers.editor._id,
|
||||
resourceType: ResourceType.PROMPTGROUP,
|
||||
resourceId: testPromptGroup._id,
|
||||
requiredPermission: PermissionBits.EDIT,
|
||||
});
|
||||
|
||||
expect(hasAccess).toBe(true);
|
||||
|
||||
// Check that non-member doesn't have access
|
||||
const nonMemberAccess = await checkPermission({
|
||||
userId: testUsers.viewer._id,
|
||||
resourceType: ResourceType.PROMPTGROUP,
|
||||
resourceId: testPromptGroup._id,
|
||||
requiredPermission: PermissionBits.EDIT,
|
||||
});
|
||||
|
||||
expect(nonMemberAccess).toBe(false);
|
||||
});
|
||||
});
|
||||
|
||||
describe('Public Access', () => {
|
||||
let publicPromptGroup: LeanPromptGroup;
|
||||
let privatePromptGroup: LeanPromptGroup;
|
||||
|
||||
beforeEach(async () => {
|
||||
publicPromptGroup = (
|
||||
await PromptGroup.create({
|
||||
name: 'Public Access Test Group',
|
||||
author: testUsers.owner._id,
|
||||
authorName: testUsers.owner.name,
|
||||
productionId: new ObjectId(),
|
||||
})
|
||||
).toObject() as unknown as LeanPromptGroup;
|
||||
|
||||
privatePromptGroup = (
|
||||
await PromptGroup.create({
|
||||
name: 'Private Access Test Group',
|
||||
author: testUsers.owner._id,
|
||||
authorName: testUsers.owner.name,
|
||||
productionId: new ObjectId(),
|
||||
})
|
||||
).toObject() as unknown as LeanPromptGroup;
|
||||
|
||||
await Prompt.create({
|
||||
prompt: 'Public prompt',
|
||||
name: 'Public',
|
||||
author: testUsers.owner._id,
|
||||
groupId: publicPromptGroup._id,
|
||||
type: 'text',
|
||||
});
|
||||
|
||||
await Prompt.create({
|
||||
prompt: 'Private prompt',
|
||||
name: 'Private',
|
||||
author: testUsers.owner._id,
|
||||
groupId: privatePromptGroup._id,
|
||||
type: 'text',
|
||||
});
|
||||
|
||||
// Grant public view access
|
||||
await aclMethods.grantPermission(
|
||||
PrincipalType.PUBLIC,
|
||||
null,
|
||||
ResourceType.PROMPTGROUP,
|
||||
publicPromptGroup._id,
|
||||
PermissionBits.VIEW,
|
||||
testUsers.owner._id,
|
||||
);
|
||||
|
||||
// Grant only owner access to private
|
||||
await grantPermission({
|
||||
principalType: PrincipalType.USER,
|
||||
principalId: testUsers.owner._id,
|
||||
resourceType: ResourceType.PROMPTGROUP,
|
||||
resourceId: privatePromptGroup._id,
|
||||
accessRoleId: AccessRoleIds.PROMPTGROUP_OWNER,
|
||||
grantedBy: testUsers.owner._id,
|
||||
});
|
||||
});
|
||||
|
||||
afterEach(async () => {
|
||||
await Prompt.deleteMany({});
|
||||
await PromptGroup.deleteMany({});
|
||||
await AclEntry.deleteMany({});
|
||||
});
|
||||
|
||||
it('public prompt should be accessible to any user', async () => {
|
||||
const hasAccess = await checkPermission({
|
||||
userId: testUsers.noAccess._id,
|
||||
resourceType: ResourceType.PROMPTGROUP,
|
||||
resourceId: publicPromptGroup._id,
|
||||
requiredPermission: PermissionBits.VIEW,
|
||||
includePublic: true,
|
||||
});
|
||||
|
||||
expect(hasAccess).toBe(true);
|
||||
});
|
||||
|
||||
it('private prompt should not be accessible to unauthorized users', async () => {
|
||||
const hasAccess = await checkPermission({
|
||||
userId: testUsers.noAccess._id,
|
||||
resourceType: ResourceType.PROMPTGROUP,
|
||||
resourceId: privatePromptGroup._id,
|
||||
requiredPermission: PermissionBits.VIEW,
|
||||
includePublic: true,
|
||||
});
|
||||
|
||||
expect(hasAccess).toBe(false);
|
||||
});
|
||||
});
|
||||
|
||||
describe('Prompt Deletion', () => {
|
||||
it('should remove ACL entries when prompt is deleted', async () => {
|
||||
const testPromptGroup = (
|
||||
await PromptGroup.create({
|
||||
name: 'Deletion Test Group',
|
||||
author: testUsers.owner._id,
|
||||
authorName: testUsers.owner.name,
|
||||
productionId: new ObjectId(),
|
||||
})
|
||||
).toObject() as unknown as LeanPromptGroup;
|
||||
|
||||
const result = await methods.savePrompt({
|
||||
author: testUsers.owner._id,
|
||||
prompt: {
|
||||
prompt: 'To be deleted',
|
||||
name: 'Delete Test',
|
||||
groupId: testPromptGroup._id,
|
||||
type: 'text',
|
||||
},
|
||||
});
|
||||
|
||||
const savedPrompt = result as { prompt?: { _id: mongoose.Types.ObjectId } } | null;
|
||||
if (!savedPrompt?.prompt) {
|
||||
throw new Error('Failed to save prompt');
|
||||
}
|
||||
const testPromptId = savedPrompt.prompt._id;
|
||||
|
||||
await grantPermission({
|
||||
principalType: PrincipalType.USER,
|
||||
principalId: testUsers.owner._id,
|
||||
resourceType: ResourceType.PROMPTGROUP,
|
||||
resourceId: testPromptGroup._id,
|
||||
accessRoleId: AccessRoleIds.PROMPTGROUP_OWNER,
|
||||
grantedBy: testUsers.owner._id,
|
||||
});
|
||||
|
||||
// Verify ACL entry exists
|
||||
const beforeDelete = await AclEntry.find({
|
||||
resourceType: ResourceType.PROMPTGROUP,
|
||||
resourceId: testPromptGroup._id,
|
||||
});
|
||||
expect(beforeDelete).toHaveLength(1);
|
||||
|
||||
// Delete the prompt
|
||||
await methods.deletePrompt({
|
||||
promptId: testPromptId,
|
||||
groupId: testPromptGroup._id,
|
||||
author: testUsers.owner._id,
|
||||
role: SystemRoles.USER,
|
||||
});
|
||||
|
||||
// Verify ACL entries are removed
|
||||
const aclEntries = await AclEntry.find({
|
||||
resourceType: ResourceType.PROMPTGROUP,
|
||||
resourceId: testPromptGroup._id,
|
||||
});
|
||||
|
||||
expect(aclEntries).toHaveLength(0);
|
||||
});
|
||||
});
|
||||
|
||||
describe('Backwards Compatibility', () => {
|
||||
it('should handle prompts without ACL entries gracefully', async () => {
|
||||
const promptGroup = (
|
||||
await PromptGroup.create({
|
||||
name: 'Legacy Test Group',
|
||||
author: testUsers.owner._id,
|
||||
authorName: testUsers.owner.name,
|
||||
productionId: new ObjectId(),
|
||||
})
|
||||
).toObject() as unknown as LeanPromptGroup;
|
||||
|
||||
const legacyPrompt = (
|
||||
await Prompt.create({
|
||||
prompt: 'Legacy prompt without ACL',
|
||||
name: 'Legacy',
|
||||
author: testUsers.owner._id,
|
||||
groupId: promptGroup._id,
|
||||
type: 'text',
|
||||
})
|
||||
).toObject() as { _id: mongoose.Types.ObjectId };
|
||||
|
||||
const prompt = (await methods.getPrompt({ _id: legacyPrompt._id })) as {
|
||||
_id: mongoose.Types.ObjectId;
|
||||
} | null;
|
||||
expect(prompt).toBeTruthy();
|
||||
expect(String(prompt!._id)).toBe(String(legacyPrompt._id));
|
||||
});
|
||||
});
|
||||
});
|
||||
659
packages/data-schemas/src/methods/prompt.ts
Normal file
659
packages/data-schemas/src/methods/prompt.ts
Normal file
|
|
@ -0,0 +1,659 @@
|
|||
import type { Model, Types } from 'mongoose';
|
||||
import { SystemRoles, ResourceType, SystemCategories } from 'librechat-data-provider';
|
||||
import type { IPrompt, IPromptGroup, IPromptGroupDocument } from '~/types';
|
||||
import { escapeRegExp } from '~/utils/string';
|
||||
import logger from '~/config/winston';
|
||||
|
||||
export interface PromptDeps {
|
||||
/** Removes all ACL permissions for a resource. Injected from PermissionService. */
|
||||
removeAllPermissions: (params: { resourceType: string; resourceId: unknown }) => Promise<void>;
|
||||
}
|
||||
|
||||
export function createPromptMethods(mongoose: typeof import('mongoose'), deps: PromptDeps) {
|
||||
const { ObjectId } = mongoose.Types;
|
||||
|
||||
/**
|
||||
* Batch-fetches production prompts for an array of prompt groups
|
||||
* and attaches them as `productionPrompt` field.
|
||||
*/
|
||||
async function attachProductionPrompts(
|
||||
groups: Array<Record<string, unknown>>,
|
||||
): Promise<Array<Record<string, unknown>>> {
|
||||
const Prompt = mongoose.models.Prompt as Model<IPrompt>;
|
||||
const uniqueIds = [
|
||||
...new Set(groups.map((g) => (g.productionId as Types.ObjectId)?.toString()).filter(Boolean)),
|
||||
];
|
||||
if (uniqueIds.length === 0) {
|
||||
return groups.map((g) => ({ ...g, productionPrompt: null }));
|
||||
}
|
||||
|
||||
const prompts = await Prompt.find({ _id: { $in: uniqueIds } })
|
||||
.select('prompt')
|
||||
.lean();
|
||||
const promptMap = new Map(prompts.map((p) => [p._id.toString(), p]));
|
||||
|
||||
return groups.map((g) => ({
|
||||
...g,
|
||||
productionPrompt: g.productionId
|
||||
? (promptMap.get((g.productionId as Types.ObjectId).toString()) ?? null)
|
||||
: null,
|
||||
}));
|
||||
}
|
||||
|
||||
/**
|
||||
* Get all prompt groups with filters (no pagination).
|
||||
*/
|
||||
async function getAllPromptGroups(filter: Record<string, unknown>) {
|
||||
try {
|
||||
const PromptGroup = mongoose.models.PromptGroup as Model<IPromptGroupDocument>;
|
||||
const { name, ...query } = filter as {
|
||||
name?: string;
|
||||
category?: string;
|
||||
[key: string]: unknown;
|
||||
};
|
||||
|
||||
if (name) {
|
||||
(query as Record<string, unknown>).name = new RegExp(escapeRegExp(name), 'i');
|
||||
}
|
||||
if (!query.category) {
|
||||
delete query.category;
|
||||
} else if (query.category === SystemCategories.MY_PROMPTS) {
|
||||
delete query.category;
|
||||
} else if (query.category === SystemCategories.NO_CATEGORY) {
|
||||
query.category = '';
|
||||
} else if (query.category === SystemCategories.SHARED_PROMPTS) {
|
||||
delete query.category;
|
||||
}
|
||||
|
||||
const groups = await PromptGroup.find(query)
|
||||
.sort({ createdAt: -1 })
|
||||
.select('name oneliner category author authorName createdAt updatedAt command productionId')
|
||||
.lean();
|
||||
return await attachProductionPrompts(groups as unknown as Array<Record<string, unknown>>);
|
||||
} catch (error) {
|
||||
console.error('Error getting all prompt groups', error);
|
||||
return { message: 'Error getting all prompt groups' };
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Get prompt groups with pagination and filters.
|
||||
*/
|
||||
async function getPromptGroups(filter: Record<string, unknown>) {
|
||||
try {
|
||||
const PromptGroup = mongoose.models.PromptGroup as Model<IPromptGroupDocument>;
|
||||
const {
|
||||
pageNumber = 1,
|
||||
pageSize = 10,
|
||||
name,
|
||||
...query
|
||||
} = filter as {
|
||||
pageNumber?: number | string;
|
||||
pageSize?: number | string;
|
||||
name?: string;
|
||||
category?: string;
|
||||
[key: string]: unknown;
|
||||
};
|
||||
|
||||
const validatedPageNumber = Math.max(parseInt(String(pageNumber), 10), 1);
|
||||
const validatedPageSize = Math.max(parseInt(String(pageSize), 10), 1);
|
||||
|
||||
if (name) {
|
||||
(query as Record<string, unknown>).name = new RegExp(escapeRegExp(name), 'i');
|
||||
}
|
||||
if (!query.category) {
|
||||
delete query.category;
|
||||
} else if (query.category === SystemCategories.MY_PROMPTS) {
|
||||
delete query.category;
|
||||
} else if (query.category === SystemCategories.NO_CATEGORY) {
|
||||
query.category = '';
|
||||
} else if (query.category === SystemCategories.SHARED_PROMPTS) {
|
||||
delete query.category;
|
||||
}
|
||||
|
||||
const skip = (validatedPageNumber - 1) * validatedPageSize;
|
||||
const limit = validatedPageSize;
|
||||
|
||||
const [groups, totalPromptGroups] = await Promise.all([
|
||||
PromptGroup.find(query)
|
||||
.sort({ createdAt: -1 })
|
||||
.skip(skip)
|
||||
.limit(limit)
|
||||
.select(
|
||||
'name numberOfGenerations oneliner category productionId author authorName createdAt updatedAt',
|
||||
)
|
||||
.lean(),
|
||||
PromptGroup.countDocuments(query),
|
||||
]);
|
||||
|
||||
const promptGroups = await attachProductionPrompts(
|
||||
groups as unknown as Array<Record<string, unknown>>,
|
||||
);
|
||||
|
||||
return {
|
||||
promptGroups,
|
||||
pageNumber: validatedPageNumber.toString(),
|
||||
pageSize: validatedPageSize.toString(),
|
||||
pages: Math.ceil(totalPromptGroups / validatedPageSize).toString(),
|
||||
};
|
||||
} catch (error) {
|
||||
console.error('Error getting prompt groups', error);
|
||||
return { message: 'Error getting prompt groups' };
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Delete a prompt group and its prompts, cleaning up ACL permissions.
|
||||
*/
|
||||
async function deletePromptGroup({
|
||||
_id,
|
||||
author,
|
||||
role,
|
||||
}: {
|
||||
_id: string;
|
||||
author?: string;
|
||||
role?: string;
|
||||
}) {
|
||||
const PromptGroup = mongoose.models.PromptGroup as Model<IPromptGroupDocument>;
|
||||
const Prompt = mongoose.models.Prompt as Model<IPrompt>;
|
||||
|
||||
const query: Record<string, unknown> = { _id };
|
||||
const groupQuery: Record<string, unknown> = { groupId: new ObjectId(_id) };
|
||||
|
||||
if (author && role !== SystemRoles.ADMIN) {
|
||||
query.author = author;
|
||||
groupQuery.author = author;
|
||||
}
|
||||
|
||||
const response = await PromptGroup.deleteOne(query);
|
||||
|
||||
if (!response || response.deletedCount === 0) {
|
||||
throw new Error('Prompt group not found');
|
||||
}
|
||||
|
||||
await Prompt.deleteMany(groupQuery);
|
||||
|
||||
try {
|
||||
await deps.removeAllPermissions({
|
||||
resourceType: ResourceType.PROMPTGROUP,
|
||||
resourceId: _id,
|
||||
});
|
||||
} catch (error) {
|
||||
logger.error('Error removing promptGroup permissions:', error);
|
||||
}
|
||||
|
||||
return { message: 'Prompt group deleted successfully' };
|
||||
}
|
||||
|
||||
/**
|
||||
* Get prompt groups by accessible IDs with optional cursor-based pagination.
|
||||
*/
|
||||
async function getListPromptGroupsByAccess({
|
||||
accessibleIds = [],
|
||||
otherParams = {},
|
||||
limit = null,
|
||||
after = null,
|
||||
}: {
|
||||
accessibleIds?: Types.ObjectId[];
|
||||
otherParams?: Record<string, unknown>;
|
||||
limit?: number | null;
|
||||
after?: string | null;
|
||||
}) {
|
||||
const PromptGroup = mongoose.models.PromptGroup as Model<IPromptGroupDocument>;
|
||||
const isPaginated = limit !== null && limit !== undefined;
|
||||
const normalizedLimit = isPaginated
|
||||
? Math.min(Math.max(1, parseInt(String(limit)) || 20), 100)
|
||||
: null;
|
||||
|
||||
const baseQuery: Record<string, unknown> = {
|
||||
...otherParams,
|
||||
_id: { $in: accessibleIds },
|
||||
};
|
||||
|
||||
if (after && typeof after === 'string' && after !== 'undefined' && after !== 'null') {
|
||||
try {
|
||||
const cursor = JSON.parse(Buffer.from(after, 'base64').toString('utf8'));
|
||||
const { updatedAt, _id } = cursor;
|
||||
|
||||
const cursorCondition = {
|
||||
$or: [
|
||||
{ updatedAt: { $lt: new Date(updatedAt) } },
|
||||
{ updatedAt: new Date(updatedAt), _id: { $gt: new ObjectId(_id) } },
|
||||
],
|
||||
};
|
||||
|
||||
if (Object.keys(baseQuery).length > 0) {
|
||||
baseQuery.$and = [{ ...baseQuery }, cursorCondition];
|
||||
Object.keys(baseQuery).forEach((key) => {
|
||||
if (key !== '$and') {
|
||||
delete baseQuery[key];
|
||||
}
|
||||
});
|
||||
} else {
|
||||
Object.assign(baseQuery, cursorCondition);
|
||||
}
|
||||
} catch (error) {
|
||||
logger.warn('Invalid cursor:', (error as Error).message);
|
||||
}
|
||||
}
|
||||
|
||||
const findQuery = PromptGroup.find(baseQuery)
|
||||
.sort({ updatedAt: -1, _id: 1 })
|
||||
.select(
|
||||
'name numberOfGenerations oneliner category productionId author authorName createdAt updatedAt',
|
||||
);
|
||||
|
||||
if (isPaginated && normalizedLimit) {
|
||||
findQuery.limit(normalizedLimit + 1);
|
||||
}
|
||||
|
||||
const groups = await findQuery.lean();
|
||||
const promptGroups = await attachProductionPrompts(
|
||||
groups as unknown as Array<Record<string, unknown>>,
|
||||
);
|
||||
|
||||
const hasMore = isPaginated && normalizedLimit ? promptGroups.length > normalizedLimit : false;
|
||||
const data = (
|
||||
isPaginated && normalizedLimit ? promptGroups.slice(0, normalizedLimit) : promptGroups
|
||||
).map((group) => {
|
||||
if (group.author) {
|
||||
group.author = (group.author as Types.ObjectId).toString();
|
||||
}
|
||||
return group;
|
||||
});
|
||||
|
||||
let nextCursor: string | null = null;
|
||||
if (isPaginated && hasMore && data.length > 0 && normalizedLimit) {
|
||||
const lastGroup = promptGroups[normalizedLimit - 1] as Record<string, unknown>;
|
||||
nextCursor = Buffer.from(
|
||||
JSON.stringify({
|
||||
updatedAt: (lastGroup.updatedAt as Date).toISOString(),
|
||||
_id: (lastGroup._id as Types.ObjectId).toString(),
|
||||
}),
|
||||
).toString('base64');
|
||||
}
|
||||
|
||||
return {
|
||||
object: 'list' as const,
|
||||
data,
|
||||
first_id: data.length > 0 ? (data[0]._id as Types.ObjectId).toString() : null,
|
||||
last_id: data.length > 0 ? (data[data.length - 1]._id as Types.ObjectId).toString() : null,
|
||||
has_more: hasMore,
|
||||
after: nextCursor,
|
||||
};
|
||||
}
|
||||
|
||||
/**
|
||||
* Create a prompt and its respective group.
|
||||
*/
|
||||
async function createPromptGroup(saveData: {
|
||||
prompt: Record<string, unknown>;
|
||||
group: Record<string, unknown>;
|
||||
author: string;
|
||||
authorName: string;
|
||||
}) {
|
||||
try {
|
||||
const PromptGroup = mongoose.models.PromptGroup as Model<IPromptGroupDocument>;
|
||||
const Prompt = mongoose.models.Prompt as Model<IPrompt>;
|
||||
const { prompt, group, author, authorName } = saveData;
|
||||
|
||||
let newPromptGroup = await PromptGroup.findOneAndUpdate(
|
||||
{ ...group, author, authorName, productionId: null },
|
||||
{ $setOnInsert: { ...group, author, authorName, productionId: null } },
|
||||
{ new: true, upsert: true },
|
||||
)
|
||||
.lean()
|
||||
.select('-__v')
|
||||
.exec();
|
||||
|
||||
const newPrompt = await Prompt.findOneAndUpdate(
|
||||
{ ...prompt, author, groupId: newPromptGroup!._id },
|
||||
{ $setOnInsert: { ...prompt, author, groupId: newPromptGroup!._id } },
|
||||
{ new: true, upsert: true },
|
||||
)
|
||||
.lean()
|
||||
.select('-__v')
|
||||
.exec();
|
||||
|
||||
newPromptGroup = (await PromptGroup.findByIdAndUpdate(
|
||||
newPromptGroup!._id,
|
||||
{ productionId: newPrompt!._id },
|
||||
{ new: true },
|
||||
)
|
||||
.lean()
|
||||
.select('-__v')
|
||||
.exec())!;
|
||||
|
||||
return {
|
||||
prompt: newPrompt,
|
||||
group: {
|
||||
...newPromptGroup,
|
||||
productionPrompt: { prompt: (newPrompt as unknown as IPrompt).prompt },
|
||||
},
|
||||
};
|
||||
} catch (error) {
|
||||
logger.error('Error saving prompt group', error);
|
||||
throw new Error('Error saving prompt group');
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Save a prompt.
|
||||
*/
|
||||
async function savePrompt(saveData: {
|
||||
prompt: Record<string, unknown>;
|
||||
author: string | Types.ObjectId;
|
||||
}) {
|
||||
try {
|
||||
const Prompt = mongoose.models.Prompt as Model<IPrompt>;
|
||||
const { prompt, author } = saveData;
|
||||
const newPromptData = { ...prompt, author };
|
||||
|
||||
let newPrompt;
|
||||
try {
|
||||
newPrompt = await Prompt.create(newPromptData);
|
||||
} catch (error: unknown) {
|
||||
if ((error as Error)?.message?.includes('groupId_1_version_1')) {
|
||||
await Prompt.db.collection('prompts').dropIndex('groupId_1_version_1');
|
||||
} else {
|
||||
throw error;
|
||||
}
|
||||
newPrompt = await Prompt.create(newPromptData);
|
||||
}
|
||||
|
||||
return { prompt: newPrompt };
|
||||
} catch (error) {
|
||||
logger.error('Error saving prompt', error);
|
||||
return { message: 'Error saving prompt' };
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Get prompts by filter.
|
||||
*/
|
||||
async function getPrompts(filter: Record<string, unknown>) {
|
||||
try {
|
||||
const Prompt = mongoose.models.Prompt as Model<IPrompt>;
|
||||
return await Prompt.find(filter).sort({ createdAt: -1 }).lean();
|
||||
} catch (error) {
|
||||
logger.error('Error getting prompts', error);
|
||||
return { message: 'Error getting prompts' };
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Get a single prompt by filter.
|
||||
*/
|
||||
async function getPrompt(filter: Record<string, unknown>) {
|
||||
try {
|
||||
const Prompt = mongoose.models.Prompt as Model<IPrompt>;
|
||||
if (filter.groupId) {
|
||||
filter.groupId = new ObjectId(filter.groupId as string);
|
||||
}
|
||||
return await Prompt.findOne(filter).lean();
|
||||
} catch (error) {
|
||||
logger.error('Error getting prompt', error);
|
||||
return { message: 'Error getting prompt' };
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Get random prompt groups from distinct categories.
|
||||
*/
|
||||
async function getRandomPromptGroups(filter: { skip: number | string; limit: number | string }) {
|
||||
try {
|
||||
const PromptGroup = mongoose.models.PromptGroup as Model<IPromptGroupDocument>;
|
||||
const categories = await PromptGroup.distinct('category', { category: { $ne: '' } });
|
||||
|
||||
for (let i = categories.length - 1; i > 0; i--) {
|
||||
const j = Math.floor(Math.random() * (i + 1));
|
||||
[categories[i], categories[j]] = [categories[j], categories[i]];
|
||||
}
|
||||
|
||||
const skip = +filter.skip;
|
||||
const limit = +filter.limit;
|
||||
const selectedCategories = categories.slice(skip, skip + limit);
|
||||
|
||||
if (selectedCategories.length === 0) {
|
||||
return { prompts: [] };
|
||||
}
|
||||
|
||||
const groups = await PromptGroup.find({ category: { $in: selectedCategories } }).lean();
|
||||
|
||||
const groupByCategory = new Map<string, unknown>();
|
||||
for (const group of groups) {
|
||||
if (!groupByCategory.has(group.category)) {
|
||||
groupByCategory.set(group.category, group);
|
||||
}
|
||||
}
|
||||
|
||||
const prompts = selectedCategories
|
||||
.map((cat: string) => groupByCategory.get(cat))
|
||||
.filter(Boolean);
|
||||
|
||||
return { prompts };
|
||||
} catch (error) {
|
||||
logger.error('Error getting prompt groups', error);
|
||||
return { message: 'Error getting prompt groups' };
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Get prompt groups with populated prompts.
|
||||
*/
|
||||
async function getPromptGroupsWithPrompts(filter: Record<string, unknown>) {
|
||||
try {
|
||||
const PromptGroup = mongoose.models.PromptGroup as Model<IPromptGroupDocument>;
|
||||
return await PromptGroup.findOne(filter)
|
||||
.populate({
|
||||
path: 'prompts',
|
||||
select: '-_id -__v -user',
|
||||
})
|
||||
.select('-_id -__v -user')
|
||||
.lean();
|
||||
} catch (error) {
|
||||
logger.error('Error getting prompt groups', error);
|
||||
return { message: 'Error getting prompt groups' };
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Get a single prompt group by filter.
|
||||
*/
|
||||
async function getPromptGroup(filter: Record<string, unknown>) {
|
||||
try {
|
||||
const PromptGroup = mongoose.models.PromptGroup as Model<IPromptGroupDocument>;
|
||||
return await PromptGroup.findOne(filter).lean();
|
||||
} catch (error) {
|
||||
logger.error('Error getting prompt group', error);
|
||||
return { message: 'Error getting prompt group' };
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Delete a prompt, potentially removing the group if it's the last prompt.
|
||||
*/
|
||||
async function deletePrompt({
|
||||
promptId,
|
||||
groupId,
|
||||
author,
|
||||
role,
|
||||
}: {
|
||||
promptId: string | Types.ObjectId;
|
||||
groupId: string | Types.ObjectId;
|
||||
author: string | Types.ObjectId;
|
||||
role?: string;
|
||||
}) {
|
||||
const Prompt = mongoose.models.Prompt as Model<IPrompt>;
|
||||
const PromptGroup = mongoose.models.PromptGroup as Model<IPromptGroupDocument>;
|
||||
|
||||
const query: Record<string, unknown> = { _id: promptId, groupId, author };
|
||||
if (role === SystemRoles.ADMIN) {
|
||||
delete query.author;
|
||||
}
|
||||
const { deletedCount } = await Prompt.deleteOne(query);
|
||||
if (deletedCount === 0) {
|
||||
throw new Error('Failed to delete the prompt');
|
||||
}
|
||||
|
||||
const remainingPrompts = await Prompt.find({ groupId })
|
||||
.select('_id')
|
||||
.sort({ createdAt: 1 })
|
||||
.lean();
|
||||
|
||||
if (remainingPrompts.length === 0) {
|
||||
try {
|
||||
await deps.removeAllPermissions({
|
||||
resourceType: ResourceType.PROMPTGROUP,
|
||||
resourceId: groupId,
|
||||
});
|
||||
} catch (error) {
|
||||
logger.error('Error removing promptGroup permissions:', error);
|
||||
}
|
||||
|
||||
await PromptGroup.deleteOne({ _id: groupId });
|
||||
|
||||
return {
|
||||
prompt: 'Prompt deleted successfully',
|
||||
promptGroup: {
|
||||
message: 'Prompt group deleted successfully',
|
||||
id: groupId,
|
||||
},
|
||||
};
|
||||
} else {
|
||||
const promptGroup = (await PromptGroup.findById(
|
||||
groupId,
|
||||
).lean()) as unknown as IPromptGroup | null;
|
||||
if (promptGroup && promptGroup.productionId?.toString() === promptId.toString()) {
|
||||
await PromptGroup.updateOne(
|
||||
{ _id: groupId },
|
||||
{ productionId: remainingPrompts[remainingPrompts.length - 1]._id },
|
||||
);
|
||||
}
|
||||
|
||||
return { prompt: 'Prompt deleted successfully' };
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Delete all prompts and prompt groups created by a specific user.
|
||||
*/
|
||||
async function deleteUserPrompts(userId: string) {
|
||||
try {
|
||||
const PromptGroup = mongoose.models.PromptGroup as Model<IPromptGroupDocument>;
|
||||
const Prompt = mongoose.models.Prompt as Model<IPrompt>;
|
||||
const AclEntry = mongoose.models.AclEntry;
|
||||
|
||||
const promptGroups = (await getAllPromptGroups({ author: new ObjectId(userId) })) as Array<
|
||||
Record<string, unknown>
|
||||
>;
|
||||
|
||||
if (!Array.isArray(promptGroups) || promptGroups.length === 0) {
|
||||
return;
|
||||
}
|
||||
|
||||
const groupIds = promptGroups.map((group) => group._id as Types.ObjectId);
|
||||
|
||||
await AclEntry.deleteMany({
|
||||
resourceType: ResourceType.PROMPTGROUP,
|
||||
resourceId: { $in: groupIds },
|
||||
});
|
||||
|
||||
await PromptGroup.deleteMany({ author: new ObjectId(userId) });
|
||||
await Prompt.deleteMany({ author: new ObjectId(userId) });
|
||||
} catch (error) {
|
||||
logger.error('[deleteUserPrompts] General error:', error);
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Update a prompt group.
|
||||
*/
|
||||
async function updatePromptGroup(filter: Record<string, unknown>, data: Record<string, unknown>) {
|
||||
try {
|
||||
const PromptGroup = mongoose.models.PromptGroup as Model<IPromptGroupDocument>;
|
||||
const updateOps = {};
|
||||
const updateData = { ...data, ...updateOps };
|
||||
const updatedDoc = await PromptGroup.findOneAndUpdate(filter, updateData, {
|
||||
new: true,
|
||||
upsert: false,
|
||||
});
|
||||
|
||||
if (!updatedDoc) {
|
||||
throw new Error('Prompt group not found');
|
||||
}
|
||||
|
||||
return updatedDoc;
|
||||
} catch (error) {
|
||||
logger.error('Error updating prompt group', error);
|
||||
return { message: 'Error updating prompt group' };
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Make a prompt the production prompt for its group.
|
||||
*/
|
||||
async function makePromptProduction(promptId: string) {
|
||||
try {
|
||||
const Prompt = mongoose.models.Prompt as Model<IPrompt>;
|
||||
const PromptGroup = mongoose.models.PromptGroup as Model<IPromptGroupDocument>;
|
||||
|
||||
const prompt = await Prompt.findById(promptId).lean();
|
||||
|
||||
if (!prompt) {
|
||||
throw new Error('Prompt not found');
|
||||
}
|
||||
|
||||
await PromptGroup.findByIdAndUpdate(
|
||||
prompt.groupId,
|
||||
{ productionId: prompt._id },
|
||||
{ new: true },
|
||||
)
|
||||
.lean()
|
||||
.exec();
|
||||
|
||||
return { message: 'Prompt production made successfully' };
|
||||
} catch (error) {
|
||||
logger.error('Error making prompt production', error);
|
||||
return { message: 'Error making prompt production' };
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Update prompt labels.
|
||||
*/
|
||||
async function updatePromptLabels(_id: string, labels: unknown) {
|
||||
try {
|
||||
const Prompt = mongoose.models.Prompt as Model<IPrompt>;
|
||||
const response = await Prompt.updateOne({ _id }, { $set: { labels } });
|
||||
if (response.matchedCount === 0) {
|
||||
return { message: 'Prompt not found' };
|
||||
}
|
||||
return { message: 'Prompt labels updated successfully' };
|
||||
} catch (error) {
|
||||
logger.error('Error updating prompt labels', error);
|
||||
return { message: 'Error updating prompt labels' };
|
||||
}
|
||||
}
|
||||
|
||||
return {
|
||||
getPromptGroups,
|
||||
deletePromptGroup,
|
||||
getAllPromptGroups,
|
||||
getListPromptGroupsByAccess,
|
||||
createPromptGroup,
|
||||
savePrompt,
|
||||
getPrompts,
|
||||
getPrompt,
|
||||
getRandomPromptGroups,
|
||||
getPromptGroupsWithPrompts,
|
||||
getPromptGroup,
|
||||
deletePrompt,
|
||||
deleteUserPrompts,
|
||||
updatePromptGroup,
|
||||
makePromptProduction,
|
||||
updatePromptLabels,
|
||||
};
|
||||
}
|
||||
|
||||
export type PromptMethods = ReturnType<typeof createPromptMethods>;
|
||||
411
packages/data-schemas/src/methods/role.methods.spec.ts
Normal file
411
packages/data-schemas/src/methods/role.methods.spec.ts
Normal file
|
|
@ -0,0 +1,411 @@
|
|||
import mongoose from 'mongoose';
|
||||
import { MongoMemoryServer } from 'mongodb-memory-server';
|
||||
import { SystemRoles, Permissions, roleDefaults, PermissionTypes } from 'librechat-data-provider';
|
||||
import type { IRole, RolePermissions } from '..';
|
||||
import { createRoleMethods } from './role';
|
||||
import { createModels } from '../models';
|
||||
|
||||
const mockCache = {
|
||||
get: jest.fn(),
|
||||
set: jest.fn(),
|
||||
del: jest.fn(),
|
||||
};
|
||||
|
||||
const mockGetCache = jest.fn().mockReturnValue(mockCache);
|
||||
|
||||
let Role: mongoose.Model<IRole>;
|
||||
let getRoleByName: ReturnType<typeof createRoleMethods>['getRoleByName'];
|
||||
let updateAccessPermissions: ReturnType<typeof createRoleMethods>['updateAccessPermissions'];
|
||||
let initializeRoles: ReturnType<typeof createRoleMethods>['initializeRoles'];
|
||||
let mongoServer: MongoMemoryServer;
|
||||
|
||||
beforeAll(async () => {
|
||||
mongoServer = await MongoMemoryServer.create();
|
||||
const mongoUri = mongoServer.getUri();
|
||||
await mongoose.connect(mongoUri);
|
||||
createModels(mongoose);
|
||||
Role = mongoose.models.Role;
|
||||
const methods = createRoleMethods(mongoose, { getCache: mockGetCache });
|
||||
getRoleByName = methods.getRoleByName;
|
||||
updateAccessPermissions = methods.updateAccessPermissions;
|
||||
initializeRoles = methods.initializeRoles;
|
||||
});
|
||||
|
||||
afterAll(async () => {
|
||||
await mongoose.disconnect();
|
||||
await mongoServer.stop();
|
||||
});
|
||||
|
||||
beforeEach(async () => {
|
||||
await Role.deleteMany({});
|
||||
mockGetCache.mockClear();
|
||||
mockCache.get.mockClear();
|
||||
mockCache.set.mockClear();
|
||||
mockCache.del.mockClear();
|
||||
});
|
||||
|
||||
describe('updateAccessPermissions', () => {
|
||||
it('should update permissions when changes are needed', async () => {
|
||||
await new Role({
|
||||
name: SystemRoles.USER,
|
||||
permissions: {
|
||||
[PermissionTypes.PROMPTS]: {
|
||||
CREATE: true,
|
||||
USE: true,
|
||||
SHARE: false,
|
||||
},
|
||||
},
|
||||
}).save();
|
||||
|
||||
await updateAccessPermissions(SystemRoles.USER, {
|
||||
[PermissionTypes.PROMPTS]: {
|
||||
CREATE: true,
|
||||
USE: true,
|
||||
SHARE: true,
|
||||
},
|
||||
});
|
||||
|
||||
const updatedRole = await getRoleByName(SystemRoles.USER);
|
||||
expect(updatedRole.permissions[PermissionTypes.PROMPTS]).toEqual({
|
||||
CREATE: true,
|
||||
USE: true,
|
||||
SHARE: true,
|
||||
});
|
||||
});
|
||||
|
||||
it('should not update permissions when no changes are needed', async () => {
|
||||
await new Role({
|
||||
name: SystemRoles.USER,
|
||||
permissions: {
|
||||
[PermissionTypes.PROMPTS]: {
|
||||
CREATE: true,
|
||||
USE: true,
|
||||
SHARE: false,
|
||||
},
|
||||
},
|
||||
}).save();
|
||||
|
||||
await updateAccessPermissions(SystemRoles.USER, {
|
||||
[PermissionTypes.PROMPTS]: {
|
||||
CREATE: true,
|
||||
USE: true,
|
||||
SHARE: false,
|
||||
},
|
||||
});
|
||||
|
||||
const updatedRole = await getRoleByName(SystemRoles.USER);
|
||||
expect(updatedRole.permissions[PermissionTypes.PROMPTS]).toEqual({
|
||||
CREATE: true,
|
||||
USE: true,
|
||||
SHARE: false,
|
||||
});
|
||||
});
|
||||
|
||||
it('should handle non-existent roles', async () => {
|
||||
await updateAccessPermissions('NON_EXISTENT_ROLE', {
|
||||
[PermissionTypes.PROMPTS]: { CREATE: true },
|
||||
});
|
||||
const role = await Role.findOne({ name: 'NON_EXISTENT_ROLE' });
|
||||
expect(role).toBeNull();
|
||||
});
|
||||
|
||||
it('should update only specified permissions', async () => {
|
||||
await new Role({
|
||||
name: SystemRoles.USER,
|
||||
permissions: {
|
||||
[PermissionTypes.PROMPTS]: {
|
||||
CREATE: true,
|
||||
USE: true,
|
||||
SHARE: false,
|
||||
},
|
||||
},
|
||||
}).save();
|
||||
|
||||
await updateAccessPermissions(SystemRoles.USER, {
|
||||
[PermissionTypes.PROMPTS]: { SHARE: true },
|
||||
});
|
||||
|
||||
const updatedRole = await getRoleByName(SystemRoles.USER);
|
||||
expect(updatedRole.permissions[PermissionTypes.PROMPTS]).toEqual({
|
||||
CREATE: true,
|
||||
USE: true,
|
||||
SHARE: true,
|
||||
});
|
||||
});
|
||||
|
||||
it('should handle partial updates', async () => {
|
||||
await new Role({
|
||||
name: SystemRoles.USER,
|
||||
permissions: {
|
||||
[PermissionTypes.PROMPTS]: {
|
||||
CREATE: true,
|
||||
USE: true,
|
||||
SHARE: false,
|
||||
},
|
||||
},
|
||||
}).save();
|
||||
|
||||
await updateAccessPermissions(SystemRoles.USER, {
|
||||
[PermissionTypes.PROMPTS]: { USE: false },
|
||||
});
|
||||
|
||||
const updatedRole = await getRoleByName(SystemRoles.USER);
|
||||
expect(updatedRole.permissions[PermissionTypes.PROMPTS]).toEqual({
|
||||
CREATE: true,
|
||||
USE: false,
|
||||
SHARE: false,
|
||||
});
|
||||
});
|
||||
|
||||
it('should update multiple permission types at once', async () => {
|
||||
await new Role({
|
||||
name: SystemRoles.USER,
|
||||
permissions: {
|
||||
[PermissionTypes.PROMPTS]: { CREATE: true, USE: true, SHARE: false },
|
||||
[PermissionTypes.BOOKMARKS]: { USE: true },
|
||||
},
|
||||
}).save();
|
||||
|
||||
await updateAccessPermissions(SystemRoles.USER, {
|
||||
[PermissionTypes.PROMPTS]: { USE: false, SHARE: true },
|
||||
[PermissionTypes.BOOKMARKS]: { USE: false },
|
||||
});
|
||||
|
||||
const updatedRole = await getRoleByName(SystemRoles.USER);
|
||||
expect(updatedRole.permissions[PermissionTypes.PROMPTS]).toEqual({
|
||||
CREATE: true,
|
||||
USE: false,
|
||||
SHARE: true,
|
||||
});
|
||||
expect(updatedRole.permissions[PermissionTypes.BOOKMARKS]).toEqual({ USE: false });
|
||||
});
|
||||
|
||||
it('should handle updates for a single permission type', async () => {
|
||||
await new Role({
|
||||
name: SystemRoles.USER,
|
||||
permissions: {
|
||||
[PermissionTypes.PROMPTS]: { CREATE: true, USE: true, SHARE: false },
|
||||
},
|
||||
}).save();
|
||||
|
||||
await updateAccessPermissions(SystemRoles.USER, {
|
||||
[PermissionTypes.PROMPTS]: { USE: false, SHARE: true },
|
||||
});
|
||||
|
||||
const updatedRole = await getRoleByName(SystemRoles.USER);
|
||||
expect(updatedRole.permissions[PermissionTypes.PROMPTS]).toEqual({
|
||||
CREATE: true,
|
||||
USE: false,
|
||||
SHARE: true,
|
||||
});
|
||||
});
|
||||
|
||||
it('should update MULTI_CONVO permissions', async () => {
|
||||
await new Role({
|
||||
name: SystemRoles.USER,
|
||||
permissions: {
|
||||
[PermissionTypes.MULTI_CONVO]: { USE: false },
|
||||
},
|
||||
}).save();
|
||||
|
||||
await updateAccessPermissions(SystemRoles.USER, {
|
||||
[PermissionTypes.MULTI_CONVO]: { USE: true },
|
||||
});
|
||||
|
||||
const updatedRole = await getRoleByName(SystemRoles.USER);
|
||||
expect(updatedRole.permissions[PermissionTypes.MULTI_CONVO]).toEqual({ USE: true });
|
||||
});
|
||||
|
||||
it('should update MULTI_CONVO permissions along with other permission types', async () => {
|
||||
await new Role({
|
||||
name: SystemRoles.USER,
|
||||
permissions: {
|
||||
[PermissionTypes.PROMPTS]: { CREATE: true, USE: true, SHARE: false },
|
||||
[PermissionTypes.MULTI_CONVO]: { USE: false },
|
||||
},
|
||||
}).save();
|
||||
|
||||
await updateAccessPermissions(SystemRoles.USER, {
|
||||
[PermissionTypes.PROMPTS]: { SHARE: true },
|
||||
[PermissionTypes.MULTI_CONVO]: { USE: true },
|
||||
});
|
||||
|
||||
const updatedRole = await getRoleByName(SystemRoles.USER);
|
||||
expect(updatedRole.permissions[PermissionTypes.PROMPTS]).toEqual({
|
||||
CREATE: true,
|
||||
USE: true,
|
||||
SHARE: true,
|
||||
});
|
||||
expect(updatedRole.permissions[PermissionTypes.MULTI_CONVO]).toEqual({ USE: true });
|
||||
});
|
||||
|
||||
it('should not update MULTI_CONVO permissions when no changes are needed', async () => {
|
||||
await new Role({
|
||||
name: SystemRoles.USER,
|
||||
permissions: {
|
||||
[PermissionTypes.MULTI_CONVO]: { USE: true },
|
||||
},
|
||||
}).save();
|
||||
|
||||
await updateAccessPermissions(SystemRoles.USER, {
|
||||
[PermissionTypes.MULTI_CONVO]: { USE: true },
|
||||
});
|
||||
|
||||
const updatedRole = await getRoleByName(SystemRoles.USER);
|
||||
expect(updatedRole.permissions[PermissionTypes.MULTI_CONVO]).toEqual({ USE: true });
|
||||
});
|
||||
});
|
||||
|
||||
describe('initializeRoles', () => {
|
||||
beforeEach(async () => {
|
||||
await Role.deleteMany({});
|
||||
});
|
||||
|
||||
it('should create default roles if they do not exist', async () => {
|
||||
await initializeRoles();
|
||||
|
||||
const adminRole = await getRoleByName(SystemRoles.ADMIN);
|
||||
const userRole = await getRoleByName(SystemRoles.USER);
|
||||
|
||||
expect(adminRole).toBeTruthy();
|
||||
expect(userRole).toBeTruthy();
|
||||
|
||||
// Check if all permission types exist in the permissions field
|
||||
Object.values(PermissionTypes).forEach((permType) => {
|
||||
expect(adminRole.permissions[permType]).toBeDefined();
|
||||
expect(userRole.permissions[permType]).toBeDefined();
|
||||
});
|
||||
|
||||
// Example: Check default values for ADMIN role
|
||||
expect(adminRole.permissions[PermissionTypes.PROMPTS]?.SHARE).toBe(true);
|
||||
expect(adminRole.permissions[PermissionTypes.BOOKMARKS]?.USE).toBe(true);
|
||||
expect(adminRole.permissions[PermissionTypes.AGENTS]?.CREATE).toBe(true);
|
||||
});
|
||||
|
||||
it('should not modify existing permissions for existing roles', async () => {
|
||||
const customUserRole = {
|
||||
name: SystemRoles.USER,
|
||||
permissions: {
|
||||
[PermissionTypes.PROMPTS]: {
|
||||
[Permissions.USE]: false,
|
||||
[Permissions.CREATE]: true,
|
||||
[Permissions.SHARE]: true,
|
||||
},
|
||||
[PermissionTypes.BOOKMARKS]: { [Permissions.USE]: false },
|
||||
},
|
||||
};
|
||||
|
||||
await new Role(customUserRole).save();
|
||||
await initializeRoles();
|
||||
|
||||
const userRole = await getRoleByName(SystemRoles.USER);
|
||||
expect(userRole.permissions[PermissionTypes.PROMPTS]).toEqual(
|
||||
customUserRole.permissions[PermissionTypes.PROMPTS],
|
||||
);
|
||||
expect(userRole.permissions[PermissionTypes.BOOKMARKS]).toEqual(
|
||||
customUserRole.permissions[PermissionTypes.BOOKMARKS],
|
||||
);
|
||||
expect(userRole.permissions[PermissionTypes.AGENTS]).toBeDefined();
|
||||
});
|
||||
|
||||
it('should add new permission types to existing roles', async () => {
|
||||
const partialUserRole = {
|
||||
name: SystemRoles.USER,
|
||||
permissions: {
|
||||
[PermissionTypes.PROMPTS]:
|
||||
roleDefaults[SystemRoles.USER].permissions[PermissionTypes.PROMPTS],
|
||||
[PermissionTypes.BOOKMARKS]:
|
||||
roleDefaults[SystemRoles.USER].permissions[PermissionTypes.BOOKMARKS],
|
||||
},
|
||||
};
|
||||
|
||||
await new Role(partialUserRole).save();
|
||||
await initializeRoles();
|
||||
|
||||
const userRole = await getRoleByName(SystemRoles.USER);
|
||||
expect(userRole.permissions[PermissionTypes.AGENTS]).toBeDefined();
|
||||
expect(userRole.permissions[PermissionTypes.AGENTS]?.CREATE).toBeDefined();
|
||||
expect(userRole.permissions[PermissionTypes.AGENTS]?.USE).toBeDefined();
|
||||
expect(userRole.permissions[PermissionTypes.AGENTS]?.SHARE).toBeDefined();
|
||||
});
|
||||
|
||||
it('should handle multiple runs without duplicating or modifying data', async () => {
|
||||
await initializeRoles();
|
||||
await initializeRoles();
|
||||
|
||||
const adminRoles = await Role.find({ name: SystemRoles.ADMIN });
|
||||
const userRoles = await Role.find({ name: SystemRoles.USER });
|
||||
|
||||
expect(adminRoles).toHaveLength(1);
|
||||
expect(userRoles).toHaveLength(1);
|
||||
|
||||
const adminPerms = adminRoles[0].toObject().permissions as RolePermissions;
|
||||
const userPerms = userRoles[0].toObject().permissions as RolePermissions;
|
||||
Object.values(PermissionTypes).forEach((permType) => {
|
||||
expect(adminPerms[permType]).toBeDefined();
|
||||
expect(userPerms[permType]).toBeDefined();
|
||||
});
|
||||
});
|
||||
|
||||
it('should update roles with missing permission types from roleDefaults', async () => {
|
||||
const partialAdminRole = {
|
||||
name: SystemRoles.ADMIN,
|
||||
permissions: {
|
||||
[PermissionTypes.PROMPTS]: {
|
||||
[Permissions.USE]: false,
|
||||
[Permissions.CREATE]: false,
|
||||
[Permissions.SHARE]: false,
|
||||
},
|
||||
[PermissionTypes.BOOKMARKS]:
|
||||
roleDefaults[SystemRoles.ADMIN].permissions[PermissionTypes.BOOKMARKS],
|
||||
},
|
||||
};
|
||||
|
||||
await new Role(partialAdminRole).save();
|
||||
await initializeRoles();
|
||||
|
||||
const adminRole = await getRoleByName(SystemRoles.ADMIN);
|
||||
expect(adminRole.permissions[PermissionTypes.PROMPTS]).toEqual(
|
||||
partialAdminRole.permissions[PermissionTypes.PROMPTS],
|
||||
);
|
||||
expect(adminRole.permissions[PermissionTypes.AGENTS]).toBeDefined();
|
||||
expect(adminRole.permissions[PermissionTypes.AGENTS]?.CREATE).toBeDefined();
|
||||
expect(adminRole.permissions[PermissionTypes.AGENTS]?.USE).toBeDefined();
|
||||
expect(adminRole.permissions[PermissionTypes.AGENTS]?.SHARE).toBeDefined();
|
||||
});
|
||||
|
||||
it('should include MULTI_CONVO permissions when creating default roles', async () => {
|
||||
await initializeRoles();
|
||||
|
||||
const adminRole = await getRoleByName(SystemRoles.ADMIN);
|
||||
const userRole = await getRoleByName(SystemRoles.USER);
|
||||
|
||||
expect(adminRole.permissions[PermissionTypes.MULTI_CONVO]).toBeDefined();
|
||||
expect(userRole.permissions[PermissionTypes.MULTI_CONVO]).toBeDefined();
|
||||
expect(adminRole.permissions[PermissionTypes.MULTI_CONVO]?.USE).toBe(
|
||||
roleDefaults[SystemRoles.ADMIN].permissions[PermissionTypes.MULTI_CONVO].USE,
|
||||
);
|
||||
expect(userRole.permissions[PermissionTypes.MULTI_CONVO]?.USE).toBe(
|
||||
roleDefaults[SystemRoles.USER].permissions[PermissionTypes.MULTI_CONVO].USE,
|
||||
);
|
||||
});
|
||||
|
||||
it('should add MULTI_CONVO permissions to existing roles without them', async () => {
|
||||
const partialUserRole = {
|
||||
name: SystemRoles.USER,
|
||||
permissions: {
|
||||
[PermissionTypes.PROMPTS]:
|
||||
roleDefaults[SystemRoles.USER].permissions[PermissionTypes.PROMPTS],
|
||||
[PermissionTypes.BOOKMARKS]:
|
||||
roleDefaults[SystemRoles.USER].permissions[PermissionTypes.BOOKMARKS],
|
||||
},
|
||||
};
|
||||
|
||||
await new Role(partialUserRole).save();
|
||||
await initializeRoles();
|
||||
|
||||
const userRole = await getRoleByName(SystemRoles.USER);
|
||||
expect(userRole.permissions[PermissionTypes.MULTI_CONVO]).toBeDefined();
|
||||
expect(userRole.permissions[PermissionTypes.MULTI_CONVO]?.USE).toBeDefined();
|
||||
});
|
||||
});
|
||||
|
|
@ -1,7 +1,22 @@
|
|||
import { roleDefaults, SystemRoles } from 'librechat-data-provider';
|
||||
import {
|
||||
CacheKeys,
|
||||
SystemRoles,
|
||||
roleDefaults,
|
||||
permissionsSchema,
|
||||
removeNullishValues,
|
||||
} from 'librechat-data-provider';
|
||||
import type { IRole } from '~/types';
|
||||
import logger from '~/config/winston';
|
||||
|
||||
// Factory function that takes mongoose instance and returns the methods
|
||||
export function createRoleMethods(mongoose: typeof import('mongoose')) {
|
||||
export interface RoleDeps {
|
||||
/** Returns a cache store for the given key. Injected from getLogStores. */
|
||||
getCache?: (key: string) => {
|
||||
get: (k: string) => Promise<unknown>;
|
||||
set: (k: string, v: unknown) => Promise<void>;
|
||||
};
|
||||
}
|
||||
|
||||
export function createRoleMethods(mongoose: typeof import('mongoose'), deps: RoleDeps = {}) {
|
||||
/**
|
||||
* Initialize default roles in the system.
|
||||
* Creates the default roles (ADMIN, USER) if they don't exist in the database.
|
||||
|
|
@ -30,18 +45,262 @@ export function createRoleMethods(mongoose: typeof import('mongoose')) {
|
|||
}
|
||||
|
||||
/**
|
||||
* List all roles in the system (for testing purposes)
|
||||
* Returns an array of all roles with their names and permissions
|
||||
* List all roles in the system.
|
||||
*/
|
||||
async function listRoles() {
|
||||
const Role = mongoose.models.Role;
|
||||
return await Role.find({}).select('name permissions').lean();
|
||||
}
|
||||
|
||||
// Return all methods you want to expose
|
||||
/**
|
||||
* Retrieve a role by name and convert the found role document to a plain object.
|
||||
* If the role with the given name doesn't exist and the name is a system defined role,
|
||||
* create it and return the lean version.
|
||||
*/
|
||||
async function getRoleByName(roleName: string, fieldsToSelect: string | string[] | null = null) {
|
||||
const cache = deps.getCache?.(CacheKeys.ROLES);
|
||||
try {
|
||||
if (cache) {
|
||||
const cachedRole = await cache.get(roleName);
|
||||
if (cachedRole) {
|
||||
return cachedRole as IRole;
|
||||
}
|
||||
}
|
||||
const Role = mongoose.models.Role;
|
||||
let query = Role.findOne({ name: roleName });
|
||||
if (fieldsToSelect) {
|
||||
query = query.select(fieldsToSelect);
|
||||
}
|
||||
const role = await query.lean().exec();
|
||||
|
||||
if (!role && SystemRoles[roleName as keyof typeof SystemRoles]) {
|
||||
const newRole = await new Role(roleDefaults[roleName as keyof typeof roleDefaults]).save();
|
||||
if (cache) {
|
||||
await cache.set(roleName, newRole);
|
||||
}
|
||||
return newRole.toObject() as IRole;
|
||||
}
|
||||
if (cache) {
|
||||
await cache.set(roleName, role);
|
||||
}
|
||||
return role as unknown as IRole;
|
||||
} catch (error) {
|
||||
throw new Error(`Failed to retrieve or create role: ${(error as Error).message}`);
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Update role values by name.
|
||||
*/
|
||||
async function updateRoleByName(roleName: string, updates: Partial<IRole>) {
|
||||
const cache = deps.getCache?.(CacheKeys.ROLES);
|
||||
try {
|
||||
const Role = mongoose.models.Role;
|
||||
const role = await Role.findOneAndUpdate(
|
||||
{ name: roleName },
|
||||
{ $set: updates },
|
||||
{ new: true, lean: true },
|
||||
)
|
||||
.select('-__v')
|
||||
.lean()
|
||||
.exec();
|
||||
if (cache) {
|
||||
await cache.set(roleName, role);
|
||||
}
|
||||
return role as unknown as IRole;
|
||||
} catch (error) {
|
||||
throw new Error(`Failed to update role: ${(error as Error).message}`);
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Updates access permissions for a specific role and multiple permission types.
|
||||
*/
|
||||
async function updateAccessPermissions(
|
||||
roleName: string,
|
||||
permissionsUpdate: Record<string, Record<string, boolean>>,
|
||||
roleData?: IRole,
|
||||
) {
|
||||
const updates: Record<string, Record<string, boolean>> = {};
|
||||
for (const [permissionType, permissions] of Object.entries(permissionsUpdate)) {
|
||||
if (
|
||||
permissionsSchema.shape &&
|
||||
permissionsSchema.shape[permissionType as keyof typeof permissionsSchema.shape]
|
||||
) {
|
||||
updates[permissionType] = removeNullishValues(permissions) as Record<string, boolean>;
|
||||
}
|
||||
}
|
||||
if (!Object.keys(updates).length) {
|
||||
return;
|
||||
}
|
||||
|
||||
try {
|
||||
const role = roleData ?? (await getRoleByName(roleName));
|
||||
if (!role) {
|
||||
return;
|
||||
}
|
||||
|
||||
const currentPermissions =
|
||||
((role as unknown as Record<string, unknown>).permissions as Record<
|
||||
string,
|
||||
Record<string, boolean>
|
||||
>) || {};
|
||||
const updatedPermissions: Record<string, Record<string, boolean>> = { ...currentPermissions };
|
||||
let hasChanges = false;
|
||||
|
||||
const unsetFields: Record<string, number> = {};
|
||||
const permissionTypes = Object.keys(permissionsSchema.shape || {});
|
||||
for (const permType of permissionTypes) {
|
||||
if (
|
||||
(role as unknown as Record<string, unknown>)[permType] &&
|
||||
typeof (role as unknown as Record<string, unknown>)[permType] === 'object'
|
||||
) {
|
||||
logger.info(
|
||||
`Migrating '${roleName}' role from old schema: found '${permType}' at top level`,
|
||||
);
|
||||
|
||||
updatedPermissions[permType] = {
|
||||
...updatedPermissions[permType],
|
||||
...((role as unknown as Record<string, unknown>)[permType] as Record<string, boolean>),
|
||||
};
|
||||
|
||||
unsetFields[permType] = 1;
|
||||
hasChanges = true;
|
||||
}
|
||||
}
|
||||
|
||||
for (const [permissionType, permissions] of Object.entries(updates)) {
|
||||
const currentTypePermissions = currentPermissions[permissionType] || {};
|
||||
updatedPermissions[permissionType] = { ...currentTypePermissions };
|
||||
|
||||
for (const [permission, value] of Object.entries(permissions)) {
|
||||
if (currentTypePermissions[permission] !== value) {
|
||||
updatedPermissions[permissionType][permission] = value;
|
||||
hasChanges = true;
|
||||
logger.info(
|
||||
`Updating '${roleName}' role permission '${permissionType}' '${permission}' from ${currentTypePermissions[permission]} to: ${value}`,
|
||||
);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if (hasChanges) {
|
||||
const Role = mongoose.models.Role;
|
||||
const updateObj = { permissions: updatedPermissions };
|
||||
|
||||
if (Object.keys(unsetFields).length > 0) {
|
||||
logger.info(
|
||||
`Unsetting old schema fields for '${roleName}' role: ${Object.keys(unsetFields).join(', ')}`,
|
||||
);
|
||||
|
||||
try {
|
||||
await Role.updateOne(
|
||||
{ name: roleName },
|
||||
{
|
||||
$set: updateObj,
|
||||
$unset: unsetFields,
|
||||
},
|
||||
);
|
||||
|
||||
const cache = deps.getCache?.(CacheKeys.ROLES);
|
||||
const updatedRole = await Role.findOne({ name: roleName }).select('-__v').lean().exec();
|
||||
if (cache) {
|
||||
await cache.set(roleName, updatedRole);
|
||||
}
|
||||
|
||||
logger.info(`Updated role '${roleName}' and removed old schema fields`);
|
||||
} catch (updateError) {
|
||||
logger.error(`Error during role migration update: ${(updateError as Error).message}`);
|
||||
throw updateError;
|
||||
}
|
||||
} else {
|
||||
await updateRoleByName(roleName, updateObj as unknown as Partial<IRole>);
|
||||
}
|
||||
|
||||
logger.info(`Updated '${roleName}' role permissions`);
|
||||
} else {
|
||||
logger.info(`No changes needed for '${roleName}' role permissions`);
|
||||
}
|
||||
} catch (error) {
|
||||
logger.error(`Failed to update ${roleName} role permissions:`, error);
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Migrates roles from old schema to new schema structure.
|
||||
*/
|
||||
async function migrateRoleSchema(roleName?: string): Promise<number> {
|
||||
try {
|
||||
const Role = mongoose.models.Role;
|
||||
let roles;
|
||||
if (roleName) {
|
||||
const role = await Role.findOne({ name: roleName });
|
||||
roles = role ? [role] : [];
|
||||
} else {
|
||||
roles = await Role.find({});
|
||||
}
|
||||
|
||||
logger.info(`Migrating ${roles.length} roles to new schema structure`);
|
||||
let migratedCount = 0;
|
||||
|
||||
for (const role of roles) {
|
||||
const permissionTypes = Object.keys(permissionsSchema.shape || {});
|
||||
const unsetFields: Record<string, number> = {};
|
||||
let hasOldSchema = false;
|
||||
|
||||
for (const permType of permissionTypes) {
|
||||
if (role[permType] && typeof role[permType] === 'object') {
|
||||
hasOldSchema = true;
|
||||
role.permissions = role.permissions || {};
|
||||
role.permissions[permType] = {
|
||||
...role.permissions[permType],
|
||||
...role[permType],
|
||||
};
|
||||
unsetFields[permType] = 1;
|
||||
}
|
||||
}
|
||||
|
||||
if (hasOldSchema) {
|
||||
try {
|
||||
logger.info(`Migrating role '${role.name}' from old schema structure`);
|
||||
|
||||
await Role.updateOne(
|
||||
{ _id: role._id },
|
||||
{
|
||||
$set: { permissions: role.permissions },
|
||||
$unset: unsetFields,
|
||||
},
|
||||
);
|
||||
|
||||
const cache = deps.getCache?.(CacheKeys.ROLES);
|
||||
if (cache) {
|
||||
const updatedRole = await Role.findById(role._id).lean().exec();
|
||||
await cache.set(role.name, updatedRole);
|
||||
}
|
||||
|
||||
migratedCount++;
|
||||
logger.info(`Migrated role '${role.name}'`);
|
||||
} catch (error) {
|
||||
logger.error(`Failed to migrate role '${role.name}': ${(error as Error).message}`);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
logger.info(`Migration complete: ${migratedCount} roles migrated`);
|
||||
return migratedCount;
|
||||
} catch (error) {
|
||||
logger.error(`Role schema migration failed: ${(error as Error).message}`);
|
||||
throw error;
|
||||
}
|
||||
}
|
||||
|
||||
return {
|
||||
listRoles,
|
||||
initializeRoles,
|
||||
getRoleByName,
|
||||
updateRoleByName,
|
||||
updateAccessPermissions,
|
||||
migrateRoleSchema,
|
||||
};
|
||||
}
|
||||
|
||||
|
|
|
|||
1093
packages/data-schemas/src/methods/spendTokens.spec.ts
Normal file
1093
packages/data-schemas/src/methods/spendTokens.spec.ts
Normal file
File diff suppressed because it is too large
Load diff
145
packages/data-schemas/src/methods/spendTokens.ts
Normal file
145
packages/data-schemas/src/methods/spendTokens.ts
Normal file
|
|
@ -0,0 +1,145 @@
|
|||
import logger from '~/config/winston';
|
||||
import type { TxData, TransactionResult } from './transaction';
|
||||
|
||||
/** Base transaction context passed by callers — does not include fields added internally */
|
||||
export interface SpendTxData {
|
||||
user: string | import('mongoose').Types.ObjectId;
|
||||
conversationId?: string;
|
||||
model?: string;
|
||||
context?: string;
|
||||
endpointTokenConfig?: Record<string, Record<string, number>> | null;
|
||||
balance?: { enabled?: boolean };
|
||||
transactions?: { enabled?: boolean };
|
||||
valueKey?: string;
|
||||
}
|
||||
|
||||
export function createSpendTokensMethods(
|
||||
_mongoose: typeof import('mongoose'),
|
||||
transactionMethods: {
|
||||
createTransaction: (txData: TxData) => Promise<TransactionResult | undefined>;
|
||||
createStructuredTransaction: (txData: TxData) => Promise<TransactionResult | undefined>;
|
||||
},
|
||||
) {
|
||||
/**
|
||||
* Creates up to two transactions to record the spending of tokens.
|
||||
*/
|
||||
async function spendTokens(
|
||||
txData: SpendTxData,
|
||||
tokenUsage: { promptTokens?: number; completionTokens?: number },
|
||||
) {
|
||||
const { promptTokens, completionTokens } = tokenUsage;
|
||||
logger.debug(
|
||||
`[spendTokens] conversationId: ${txData.conversationId}${
|
||||
txData?.context ? ` | Context: ${txData?.context}` : ''
|
||||
} | Token usage: `,
|
||||
{ promptTokens, completionTokens },
|
||||
);
|
||||
let prompt: TransactionResult | undefined, completion: TransactionResult | undefined;
|
||||
const normalizedPromptTokens = Math.max(promptTokens ?? 0, 0);
|
||||
try {
|
||||
if (promptTokens !== undefined) {
|
||||
prompt = await transactionMethods.createTransaction({
|
||||
...txData,
|
||||
tokenType: 'prompt',
|
||||
rawAmount: promptTokens === 0 ? 0 : -normalizedPromptTokens,
|
||||
inputTokenCount: normalizedPromptTokens,
|
||||
});
|
||||
}
|
||||
|
||||
if (completionTokens !== undefined) {
|
||||
completion = await transactionMethods.createTransaction({
|
||||
...txData,
|
||||
tokenType: 'completion',
|
||||
rawAmount: completionTokens === 0 ? 0 : -Math.max(completionTokens, 0),
|
||||
inputTokenCount: normalizedPromptTokens,
|
||||
});
|
||||
}
|
||||
|
||||
if (prompt || completion) {
|
||||
logger.debug('[spendTokens] Transaction data record against balance:', {
|
||||
user: txData.user,
|
||||
prompt: prompt?.prompt,
|
||||
promptRate: prompt?.rate,
|
||||
completion: completion?.completion,
|
||||
completionRate: completion?.rate,
|
||||
balance: completion?.balance ?? prompt?.balance,
|
||||
});
|
||||
} else {
|
||||
logger.debug('[spendTokens] No transactions incurred against balance');
|
||||
}
|
||||
} catch (err) {
|
||||
logger.error('[spendTokens]', err);
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Creates transactions to record the spending of structured tokens.
|
||||
*/
|
||||
async function spendStructuredTokens(
|
||||
txData: SpendTxData,
|
||||
tokenUsage: {
|
||||
promptTokens?: { input?: number; write?: number; read?: number };
|
||||
completionTokens?: number;
|
||||
},
|
||||
) {
|
||||
const { promptTokens, completionTokens } = tokenUsage;
|
||||
logger.debug(
|
||||
`[spendStructuredTokens] conversationId: ${txData.conversationId}${
|
||||
txData?.context ? ` | Context: ${txData?.context}` : ''
|
||||
} | Token usage: `,
|
||||
{ promptTokens, completionTokens },
|
||||
);
|
||||
let prompt: TransactionResult | undefined, completion: TransactionResult | undefined;
|
||||
try {
|
||||
if (promptTokens) {
|
||||
const input = Math.max(promptTokens.input ?? 0, 0);
|
||||
const write = Math.max(promptTokens.write ?? 0, 0);
|
||||
const read = Math.max(promptTokens.read ?? 0, 0);
|
||||
const totalInputTokens = input + write + read;
|
||||
prompt = await transactionMethods.createStructuredTransaction({
|
||||
...txData,
|
||||
tokenType: 'prompt',
|
||||
inputTokens: -input,
|
||||
writeTokens: -write,
|
||||
readTokens: -read,
|
||||
inputTokenCount: totalInputTokens,
|
||||
});
|
||||
}
|
||||
|
||||
if (completionTokens) {
|
||||
const totalInputTokens = promptTokens
|
||||
? Math.max(promptTokens.input ?? 0, 0) +
|
||||
Math.max(promptTokens.write ?? 0, 0) +
|
||||
Math.max(promptTokens.read ?? 0, 0)
|
||||
: undefined;
|
||||
completion = await transactionMethods.createTransaction({
|
||||
...txData,
|
||||
tokenType: 'completion',
|
||||
rawAmount: -Math.max(completionTokens, 0),
|
||||
inputTokenCount: totalInputTokens,
|
||||
});
|
||||
}
|
||||
|
||||
if (prompt || completion) {
|
||||
logger.debug('[spendStructuredTokens] Transaction data record against balance:', {
|
||||
user: txData.user,
|
||||
prompt: prompt?.prompt,
|
||||
promptRate: prompt?.rate,
|
||||
completion: completion?.completion,
|
||||
completionRate: completion?.rate,
|
||||
balance: completion?.balance ?? prompt?.balance,
|
||||
});
|
||||
} else {
|
||||
logger.debug('[spendStructuredTokens] No transactions incurred against balance');
|
||||
}
|
||||
} catch (err) {
|
||||
logger.error('[spendStructuredTokens]', err);
|
||||
}
|
||||
|
||||
return { prompt, completion };
|
||||
}
|
||||
|
||||
return { spendTokens, spendStructuredTokens };
|
||||
}
|
||||
|
||||
export type SpendTokensMethods = ReturnType<typeof createSpendTokensMethods>;
|
||||
38
packages/data-schemas/src/methods/test-helpers.ts
Normal file
38
packages/data-schemas/src/methods/test-helpers.ts
Normal file
|
|
@ -0,0 +1,38 @@
|
|||
/**
|
||||
* Inlined utility functions previously imported from @librechat/api.
|
||||
* These are used only by test files in data-schemas.
|
||||
*/
|
||||
|
||||
/**
|
||||
* Finds the first matching pattern in a tokens/values map by reverse-iterating
|
||||
* and checking if the model name (lowercased) includes the key.
|
||||
*
|
||||
* Inlined from @librechat/api findMatchingPattern
|
||||
*/
|
||||
export function findMatchingPattern(
|
||||
modelName: string,
|
||||
tokensMap: Record<string, unknown>,
|
||||
): string | undefined {
|
||||
const keys = Object.keys(tokensMap);
|
||||
const lowerModelName = modelName.toLowerCase();
|
||||
for (let i = keys.length - 1; i >= 0; i--) {
|
||||
const modelKey = keys[i];
|
||||
if (lowerModelName.includes(modelKey)) {
|
||||
return modelKey;
|
||||
}
|
||||
}
|
||||
return undefined;
|
||||
}
|
||||
|
||||
/**
|
||||
* Matches a model name to a canonical key. When no maxTokensMap is available
|
||||
* (as in data-schemas tests), returns the model name as-is.
|
||||
*
|
||||
* Inlined from @librechat/api matchModelName (simplified for test use)
|
||||
*/
|
||||
export function matchModelName(modelName: string, _endpoint?: string): string | undefined {
|
||||
if (typeof modelName !== 'string') {
|
||||
return undefined;
|
||||
}
|
||||
return modelName;
|
||||
}
|
||||
97
packages/data-schemas/src/methods/toolCall.ts
Normal file
97
packages/data-schemas/src/methods/toolCall.ts
Normal file
|
|
@ -0,0 +1,97 @@
|
|||
import type { Model } from 'mongoose';
|
||||
|
||||
interface IToolCallData {
|
||||
messageId?: string;
|
||||
conversationId?: string;
|
||||
user?: string;
|
||||
[key: string]: unknown;
|
||||
}
|
||||
|
||||
export function createToolCallMethods(mongoose: typeof import('mongoose')) {
|
||||
/**
|
||||
* Create a new tool call
|
||||
*/
|
||||
async function createToolCall(toolCallData: IToolCallData) {
|
||||
try {
|
||||
const ToolCall = mongoose.models.ToolCall as Model<IToolCallData>;
|
||||
return await ToolCall.create(toolCallData);
|
||||
} catch (error) {
|
||||
throw new Error(`Error creating tool call: ${(error as Error).message}`);
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Get a tool call by ID
|
||||
*/
|
||||
async function getToolCallById(id: string) {
|
||||
try {
|
||||
const ToolCall = mongoose.models.ToolCall as Model<IToolCallData>;
|
||||
return await ToolCall.findById(id).lean();
|
||||
} catch (error) {
|
||||
throw new Error(`Error fetching tool call: ${(error as Error).message}`);
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Get tool calls by message ID and user
|
||||
*/
|
||||
async function getToolCallsByMessage(messageId: string, userId: string) {
|
||||
try {
|
||||
const ToolCall = mongoose.models.ToolCall as Model<IToolCallData>;
|
||||
return await ToolCall.find({ messageId, user: userId }).lean();
|
||||
} catch (error) {
|
||||
throw new Error(`Error fetching tool calls: ${(error as Error).message}`);
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Get tool calls by conversation ID and user
|
||||
*/
|
||||
async function getToolCallsByConvo(conversationId: string, userId: string) {
|
||||
try {
|
||||
const ToolCall = mongoose.models.ToolCall as Model<IToolCallData>;
|
||||
return await ToolCall.find({ conversationId, user: userId }).lean();
|
||||
} catch (error) {
|
||||
throw new Error(`Error fetching tool calls: ${(error as Error).message}`);
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Update a tool call
|
||||
*/
|
||||
async function updateToolCall(id: string, updateData: Partial<IToolCallData>) {
|
||||
try {
|
||||
const ToolCall = mongoose.models.ToolCall as Model<IToolCallData>;
|
||||
return await ToolCall.findByIdAndUpdate(id, updateData, { new: true }).lean();
|
||||
} catch (error) {
|
||||
throw new Error(`Error updating tool call: ${(error as Error).message}`);
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Delete tool calls by user and optionally conversation
|
||||
*/
|
||||
async function deleteToolCalls(userId: string, conversationId?: string) {
|
||||
try {
|
||||
const ToolCall = mongoose.models.ToolCall as Model<IToolCallData>;
|
||||
const query: Record<string, string> = { user: userId };
|
||||
if (conversationId) {
|
||||
query.conversationId = conversationId;
|
||||
}
|
||||
return await ToolCall.deleteMany(query);
|
||||
} catch (error) {
|
||||
throw new Error(`Error deleting tool call: ${(error as Error).message}`);
|
||||
}
|
||||
}
|
||||
|
||||
return {
|
||||
createToolCall,
|
||||
updateToolCall,
|
||||
deleteToolCalls,
|
||||
getToolCallById,
|
||||
getToolCallsByConvo,
|
||||
getToolCallsByMessage,
|
||||
};
|
||||
}
|
||||
|
||||
export type ToolCallMethods = ReturnType<typeof createToolCallMethods>;
|
||||
905
packages/data-schemas/src/methods/transaction.spec.ts
Normal file
905
packages/data-schemas/src/methods/transaction.spec.ts
Normal file
|
|
@ -0,0 +1,905 @@
|
|||
import mongoose from 'mongoose';
|
||||
import { MongoMemoryServer } from 'mongodb-memory-server';
|
||||
import type { ITransaction } from '~/schema/transaction';
|
||||
import type { TxData } from './transaction';
|
||||
import type { IBalance } from '..';
|
||||
import { createTxMethods, tokenValues, premiumTokenValues } from './tx';
|
||||
import { matchModelName, findMatchingPattern } from './test-helpers';
|
||||
import { createSpendTokensMethods } from './spendTokens';
|
||||
import { createTransactionMethods } from './transaction';
|
||||
import { createModels } from '~/models';
|
||||
|
||||
jest.mock('~/config/winston', () => ({
|
||||
error: jest.fn(),
|
||||
warn: jest.fn(),
|
||||
info: jest.fn(),
|
||||
debug: jest.fn(),
|
||||
}));
|
||||
|
||||
let mongoServer: InstanceType<typeof MongoMemoryServer>;
|
||||
let Balance: mongoose.Model<IBalance>;
|
||||
let Transaction: mongoose.Model<ITransaction>;
|
||||
let spendTokens: ReturnType<typeof createSpendTokensMethods>['spendTokens'];
|
||||
let spendStructuredTokens: ReturnType<typeof createSpendTokensMethods>['spendStructuredTokens'];
|
||||
let createTransaction: ReturnType<typeof createTransactionMethods>['createTransaction'];
|
||||
let createStructuredTransaction: ReturnType<
|
||||
typeof createTransactionMethods
|
||||
>['createStructuredTransaction'];
|
||||
let getMultiplier: ReturnType<typeof createTxMethods>['getMultiplier'];
|
||||
let getCacheMultiplier: ReturnType<typeof createTxMethods>['getCacheMultiplier'];
|
||||
|
||||
beforeAll(async () => {
|
||||
mongoServer = await MongoMemoryServer.create();
|
||||
const mongoUri = mongoServer.getUri();
|
||||
|
||||
// Register models
|
||||
const models = createModels(mongoose);
|
||||
Object.assign(mongoose.models, models);
|
||||
|
||||
Balance = mongoose.models.Balance;
|
||||
Transaction = mongoose.models.Transaction;
|
||||
|
||||
// Create methods from factories (following the chain in methods/index.ts)
|
||||
const txMethods = createTxMethods(mongoose, { matchModelName, findMatchingPattern });
|
||||
getMultiplier = txMethods.getMultiplier;
|
||||
getCacheMultiplier = txMethods.getCacheMultiplier;
|
||||
|
||||
const transactionMethods = createTransactionMethods(mongoose, {
|
||||
getMultiplier: txMethods.getMultiplier,
|
||||
getCacheMultiplier: txMethods.getCacheMultiplier,
|
||||
});
|
||||
createTransaction = transactionMethods.createTransaction;
|
||||
createStructuredTransaction = transactionMethods.createStructuredTransaction;
|
||||
|
||||
const spendMethods = createSpendTokensMethods(mongoose, {
|
||||
createTransaction: transactionMethods.createTransaction,
|
||||
createStructuredTransaction: transactionMethods.createStructuredTransaction,
|
||||
});
|
||||
spendTokens = spendMethods.spendTokens;
|
||||
spendStructuredTokens = spendMethods.spendStructuredTokens;
|
||||
|
||||
await mongoose.connect(mongoUri);
|
||||
});
|
||||
|
||||
afterAll(async () => {
|
||||
await mongoose.disconnect();
|
||||
await mongoServer.stop();
|
||||
});
|
||||
|
||||
beforeEach(async () => {
|
||||
await mongoose.connection.dropDatabase();
|
||||
});
|
||||
|
||||
describe('Regular Token Spending Tests', () => {
|
||||
test('Balance should decrease when spending tokens with spendTokens', async () => {
|
||||
// Arrange
|
||||
const userId = new mongoose.Types.ObjectId();
|
||||
const initialBalance = 10000000; // $10.00
|
||||
await Balance.create({ user: userId, tokenCredits: initialBalance });
|
||||
|
||||
const model = 'gpt-3.5-turbo';
|
||||
const txData = {
|
||||
user: userId,
|
||||
conversationId: 'test-conversation-id',
|
||||
model,
|
||||
context: 'test',
|
||||
endpointTokenConfig: null,
|
||||
balance: { enabled: true },
|
||||
};
|
||||
|
||||
const tokenUsage = {
|
||||
promptTokens: 100,
|
||||
completionTokens: 50,
|
||||
};
|
||||
|
||||
// Act
|
||||
await spendTokens(txData, tokenUsage);
|
||||
|
||||
// Assert
|
||||
const updatedBalance = await Balance.findOne({ user: userId });
|
||||
const promptMultiplier = getMultiplier({ model, tokenType: 'prompt' });
|
||||
const completionMultiplier = getMultiplier({ model, tokenType: 'completion' });
|
||||
const expectedTotalCost = 100 * promptMultiplier + 50 * completionMultiplier;
|
||||
const expectedBalance = initialBalance - expectedTotalCost;
|
||||
|
||||
expect(updatedBalance?.tokenCredits).toBeCloseTo(expectedBalance, 0);
|
||||
});
|
||||
|
||||
test('spendTokens should handle zero completion tokens', async () => {
|
||||
// Arrange
|
||||
const userId = new mongoose.Types.ObjectId();
|
||||
const initialBalance = 10000000;
|
||||
await Balance.create({ user: userId, tokenCredits: initialBalance });
|
||||
|
||||
const model = 'gpt-3.5-turbo';
|
||||
const txData = {
|
||||
user: userId,
|
||||
conversationId: 'test-conversation-id',
|
||||
model,
|
||||
context: 'test',
|
||||
endpointTokenConfig: null,
|
||||
balance: { enabled: true },
|
||||
};
|
||||
|
||||
const tokenUsage = {
|
||||
promptTokens: 100,
|
||||
completionTokens: 0,
|
||||
};
|
||||
|
||||
// Act
|
||||
await spendTokens(txData, tokenUsage);
|
||||
|
||||
// Assert
|
||||
const updatedBalance = await Balance.findOne({ user: userId });
|
||||
const promptMultiplier = getMultiplier({ model, tokenType: 'prompt' });
|
||||
const expectedCost = 100 * promptMultiplier;
|
||||
expect(updatedBalance?.tokenCredits).toBeCloseTo(initialBalance - expectedCost, 0);
|
||||
});
|
||||
|
||||
test('spendTokens should handle undefined token counts', async () => {
|
||||
// Arrange
|
||||
const userId = new mongoose.Types.ObjectId();
|
||||
const initialBalance = 10000000;
|
||||
await Balance.create({ user: userId, tokenCredits: initialBalance });
|
||||
|
||||
const model = 'gpt-3.5-turbo';
|
||||
const txData = {
|
||||
user: userId,
|
||||
conversationId: 'test-conversation-id',
|
||||
model,
|
||||
context: 'test',
|
||||
endpointTokenConfig: null,
|
||||
balance: { enabled: true },
|
||||
};
|
||||
|
||||
const tokenUsage = {};
|
||||
|
||||
// Act
|
||||
const result = await spendTokens(txData, tokenUsage);
|
||||
|
||||
// Assert: No transaction should be created
|
||||
expect(result).toBeUndefined();
|
||||
});
|
||||
|
||||
test('spendTokens should handle only prompt tokens', async () => {
|
||||
// Arrange
|
||||
const userId = new mongoose.Types.ObjectId();
|
||||
const initialBalance = 10000000;
|
||||
await Balance.create({ user: userId, tokenCredits: initialBalance });
|
||||
|
||||
const model = 'gpt-3.5-turbo';
|
||||
const txData = {
|
||||
user: userId,
|
||||
conversationId: 'test-conversation-id',
|
||||
model,
|
||||
context: 'test',
|
||||
endpointTokenConfig: null,
|
||||
balance: { enabled: true },
|
||||
};
|
||||
|
||||
const tokenUsage = { promptTokens: 100 };
|
||||
|
||||
// Act
|
||||
await spendTokens(txData, tokenUsage);
|
||||
|
||||
// Assert
|
||||
const updatedBalance = await Balance.findOne({ user: userId });
|
||||
const promptMultiplier = getMultiplier({ model, tokenType: 'prompt' });
|
||||
const expectedCost = 100 * promptMultiplier;
|
||||
expect(updatedBalance?.tokenCredits).toBeCloseTo(initialBalance - expectedCost, 0);
|
||||
});
|
||||
|
||||
test('spendTokens should not update balance when balance feature is disabled', async () => {
|
||||
// Arrange: Balance config is now passed directly in txData
|
||||
const userId = new mongoose.Types.ObjectId();
|
||||
const initialBalance = 10000000;
|
||||
await Balance.create({ user: userId, tokenCredits: initialBalance });
|
||||
|
||||
const model = 'gpt-3.5-turbo';
|
||||
const txData = {
|
||||
user: userId,
|
||||
conversationId: 'test-conversation-id',
|
||||
model,
|
||||
context: 'test',
|
||||
endpointTokenConfig: null,
|
||||
balance: { enabled: false },
|
||||
};
|
||||
|
||||
const tokenUsage = {
|
||||
promptTokens: 100,
|
||||
completionTokens: 50,
|
||||
};
|
||||
|
||||
// Act
|
||||
await spendTokens(txData, tokenUsage);
|
||||
|
||||
// Assert: Balance should remain unchanged.
|
||||
const updatedBalance = await Balance.findOne({ user: userId });
|
||||
expect(updatedBalance?.tokenCredits).toBe(initialBalance);
|
||||
});
|
||||
});
|
||||
|
||||
describe('Structured Token Spending Tests', () => {
|
||||
test('Balance should decrease and rawAmount should be set when spending a large number of structured tokens', async () => {
|
||||
// Arrange
|
||||
const userId = new mongoose.Types.ObjectId();
|
||||
const initialBalance = 17613154.55; // $17.61
|
||||
await Balance.create({ user: userId, tokenCredits: initialBalance });
|
||||
|
||||
const model = 'claude-3-5-sonnet';
|
||||
const txData = {
|
||||
user: userId,
|
||||
conversationId: 'c23a18da-706c-470a-ac28-ec87ed065199',
|
||||
model,
|
||||
context: 'message',
|
||||
endpointTokenConfig: null,
|
||||
balance: { enabled: true },
|
||||
};
|
||||
|
||||
const tokenUsage = {
|
||||
promptTokens: {
|
||||
input: 11,
|
||||
write: 140522,
|
||||
read: 0,
|
||||
},
|
||||
completionTokens: 5,
|
||||
};
|
||||
|
||||
const promptMultiplier = getMultiplier({ model, tokenType: 'prompt' });
|
||||
const completionMultiplier = getMultiplier({ model, tokenType: 'completion' });
|
||||
const writeMultiplier = getCacheMultiplier({ model, cacheType: 'write' });
|
||||
const readMultiplier = getCacheMultiplier({ model, cacheType: 'read' });
|
||||
|
||||
// Act
|
||||
const result = await spendStructuredTokens(txData, tokenUsage);
|
||||
|
||||
// Calculate expected costs.
|
||||
const expectedPromptCost =
|
||||
tokenUsage.promptTokens.input * promptMultiplier +
|
||||
tokenUsage.promptTokens.write * (writeMultiplier ?? 0) +
|
||||
tokenUsage.promptTokens.read * (readMultiplier ?? 0);
|
||||
const expectedCompletionCost = tokenUsage.completionTokens * completionMultiplier;
|
||||
const expectedTotalCost = expectedPromptCost + expectedCompletionCost;
|
||||
const expectedBalance = initialBalance - expectedTotalCost;
|
||||
|
||||
// Assert
|
||||
expect(result?.completion?.balance).toBeLessThan(initialBalance);
|
||||
const allowedDifference = 100;
|
||||
expect(Math.abs((result?.completion?.balance ?? 0) - expectedBalance)).toBeLessThan(
|
||||
allowedDifference,
|
||||
);
|
||||
const balanceDecrease = initialBalance - (result?.completion?.balance ?? 0);
|
||||
expect(balanceDecrease).toBeCloseTo(expectedTotalCost, 0);
|
||||
|
||||
const expectedPromptTokenValue = -expectedPromptCost;
|
||||
const expectedCompletionTokenValue = -expectedCompletionCost;
|
||||
expect(result?.prompt?.prompt).toBeCloseTo(expectedPromptTokenValue, 1);
|
||||
expect(result?.completion?.completion).toBe(expectedCompletionTokenValue);
|
||||
});
|
||||
|
||||
test('should handle zero completion tokens in structured spending', async () => {
|
||||
// Arrange
|
||||
const userId = new mongoose.Types.ObjectId();
|
||||
const initialBalance = 17613154.55;
|
||||
await Balance.create({ user: userId, tokenCredits: initialBalance });
|
||||
|
||||
const model = 'claude-3-5-sonnet';
|
||||
const txData = {
|
||||
user: userId,
|
||||
conversationId: 'test-convo',
|
||||
model,
|
||||
context: 'message',
|
||||
balance: { enabled: true },
|
||||
};
|
||||
|
||||
const tokenUsage = {
|
||||
promptTokens: {
|
||||
input: 10,
|
||||
write: 100,
|
||||
read: 5,
|
||||
},
|
||||
completionTokens: 0,
|
||||
};
|
||||
|
||||
// Act
|
||||
const result = await spendStructuredTokens(txData, tokenUsage);
|
||||
|
||||
// Assert
|
||||
expect(result.prompt).toBeDefined();
|
||||
expect(result.completion).toBeUndefined();
|
||||
expect(result?.prompt?.prompt).toBeLessThan(0);
|
||||
});
|
||||
|
||||
test('should handle only prompt tokens in structured spending', async () => {
|
||||
// Arrange
|
||||
const userId = new mongoose.Types.ObjectId();
|
||||
const initialBalance = 17613154.55;
|
||||
await Balance.create({ user: userId, tokenCredits: initialBalance });
|
||||
|
||||
const model = 'claude-3-5-sonnet';
|
||||
const txData = {
|
||||
user: userId,
|
||||
conversationId: 'test-convo',
|
||||
model,
|
||||
context: 'message',
|
||||
balance: { enabled: true },
|
||||
};
|
||||
|
||||
const tokenUsage = {
|
||||
promptTokens: {
|
||||
input: 10,
|
||||
write: 100,
|
||||
read: 5,
|
||||
},
|
||||
};
|
||||
|
||||
// Act
|
||||
const result = await spendStructuredTokens(txData, tokenUsage);
|
||||
|
||||
// Assert
|
||||
expect(result.prompt).toBeDefined();
|
||||
expect(result.completion).toBeUndefined();
|
||||
expect(result?.prompt?.prompt).toBeLessThan(0);
|
||||
});
|
||||
|
||||
test('should handle undefined token counts in structured spending', async () => {
|
||||
// Arrange
|
||||
const userId = new mongoose.Types.ObjectId();
|
||||
const initialBalance = 17613154.55;
|
||||
await Balance.create({ user: userId, tokenCredits: initialBalance });
|
||||
|
||||
const model = 'claude-3-5-sonnet';
|
||||
const txData = {
|
||||
user: userId,
|
||||
conversationId: 'test-convo',
|
||||
model,
|
||||
context: 'message',
|
||||
balance: { enabled: true },
|
||||
};
|
||||
|
||||
const tokenUsage = {};
|
||||
|
||||
// Act
|
||||
const result = await spendStructuredTokens(txData, tokenUsage);
|
||||
|
||||
// Assert
|
||||
expect(result).toEqual({
|
||||
prompt: undefined,
|
||||
completion: undefined,
|
||||
});
|
||||
});
|
||||
|
||||
test('should handle incomplete context for completion tokens', async () => {
|
||||
// Arrange
|
||||
const userId = new mongoose.Types.ObjectId();
|
||||
const initialBalance = 17613154.55;
|
||||
await Balance.create({ user: userId, tokenCredits: initialBalance });
|
||||
|
||||
const model = 'claude-3-5-sonnet';
|
||||
const txData = {
|
||||
user: userId,
|
||||
conversationId: 'test-convo',
|
||||
model,
|
||||
context: 'incomplete',
|
||||
balance: { enabled: true },
|
||||
};
|
||||
|
||||
const tokenUsage = {
|
||||
promptTokens: {
|
||||
input: 10,
|
||||
write: 100,
|
||||
read: 5,
|
||||
},
|
||||
completionTokens: 50,
|
||||
};
|
||||
|
||||
// Act
|
||||
const result = await spendStructuredTokens(txData, tokenUsage);
|
||||
|
||||
// Assert:
|
||||
// (Assuming a multiplier for completion of 15 and a cancel rate of 1.15 as noted in the original test.)
|
||||
expect(result?.completion?.completion).toBeCloseTo(-50 * 15 * 1.15, 0);
|
||||
});
|
||||
});
|
||||
|
||||
describe('NaN Handling Tests', () => {
|
||||
test('should skip transaction creation when rawAmount is NaN', async () => {
|
||||
// Arrange
|
||||
const userId = new mongoose.Types.ObjectId();
|
||||
const initialBalance = 10000000;
|
||||
await Balance.create({ user: userId, tokenCredits: initialBalance });
|
||||
|
||||
const model = 'gpt-3.5-turbo';
|
||||
const txData: TxData = {
|
||||
user: userId,
|
||||
conversationId: 'test-conversation-id',
|
||||
model,
|
||||
context: 'test',
|
||||
endpointTokenConfig: null,
|
||||
rawAmount: NaN,
|
||||
tokenType: 'prompt',
|
||||
balance: { enabled: true },
|
||||
};
|
||||
|
||||
// Act
|
||||
const result = await createTransaction(txData);
|
||||
|
||||
// Assert: No transaction should be created and balance remains unchanged.
|
||||
expect(result).toBeUndefined();
|
||||
const balance = await Balance.findOne({ user: userId });
|
||||
expect(balance?.tokenCredits).toBe(initialBalance);
|
||||
});
|
||||
});
|
||||
|
||||
describe('Transactions Config Tests', () => {
|
||||
test('createTransaction should not save when transactions.enabled is false', async () => {
|
||||
// Arrange
|
||||
const userId = new mongoose.Types.ObjectId();
|
||||
const initialBalance = 10000000;
|
||||
await Balance.create({ user: userId, tokenCredits: initialBalance });
|
||||
|
||||
const model = 'gpt-3.5-turbo';
|
||||
const txData: TxData = {
|
||||
user: userId,
|
||||
conversationId: 'test-conversation-id',
|
||||
model,
|
||||
context: 'test',
|
||||
endpointTokenConfig: null,
|
||||
rawAmount: -100,
|
||||
tokenType: 'prompt',
|
||||
transactions: { enabled: false },
|
||||
};
|
||||
|
||||
// Act
|
||||
const result = await createTransaction(txData);
|
||||
|
||||
// Assert: No transaction should be created
|
||||
expect(result).toBeUndefined();
|
||||
const transactions = await Transaction.find({ user: userId });
|
||||
expect(transactions).toHaveLength(0);
|
||||
const balance = await Balance.findOne({ user: userId });
|
||||
expect(balance?.tokenCredits).toBe(initialBalance);
|
||||
});
|
||||
|
||||
test('createTransaction should save when transactions.enabled is true', async () => {
|
||||
// Arrange
|
||||
const userId = new mongoose.Types.ObjectId();
|
||||
const initialBalance = 10000000;
|
||||
await Balance.create({ user: userId, tokenCredits: initialBalance });
|
||||
|
||||
const model = 'gpt-3.5-turbo';
|
||||
const txData: TxData = {
|
||||
user: userId,
|
||||
conversationId: 'test-conversation-id',
|
||||
model,
|
||||
context: 'test',
|
||||
endpointTokenConfig: null,
|
||||
rawAmount: -100,
|
||||
tokenType: 'prompt',
|
||||
transactions: { enabled: true },
|
||||
balance: { enabled: true },
|
||||
};
|
||||
|
||||
// Act
|
||||
const result = await createTransaction(txData);
|
||||
|
||||
// Assert: Transaction should be created
|
||||
expect(result).toBeDefined();
|
||||
expect(result?.balance).toBeLessThan(initialBalance);
|
||||
const transactions = await Transaction.find({ user: userId });
|
||||
expect(transactions).toHaveLength(1);
|
||||
expect(transactions[0].rawAmount).toBe(-100);
|
||||
});
|
||||
|
||||
test('createTransaction should save when balance.enabled is true even if transactions config is missing', async () => {
|
||||
// Arrange
|
||||
const userId = new mongoose.Types.ObjectId();
|
||||
const initialBalance = 10000000;
|
||||
await Balance.create({ user: userId, tokenCredits: initialBalance });
|
||||
|
||||
const model = 'gpt-3.5-turbo';
|
||||
const txData: TxData = {
|
||||
user: userId,
|
||||
conversationId: 'test-conversation-id',
|
||||
model,
|
||||
context: 'test',
|
||||
endpointTokenConfig: null,
|
||||
rawAmount: -100,
|
||||
tokenType: 'prompt',
|
||||
balance: { enabled: true },
|
||||
// No transactions config provided
|
||||
};
|
||||
|
||||
// Act
|
||||
const result = await createTransaction(txData);
|
||||
|
||||
// Assert: Transaction should be created (backward compatibility)
|
||||
expect(result).toBeDefined();
|
||||
expect(result?.balance).toBeLessThan(initialBalance);
|
||||
const transactions = await Transaction.find({ user: userId });
|
||||
expect(transactions).toHaveLength(1);
|
||||
});
|
||||
|
||||
test('createTransaction should save transaction but not update balance when balance is disabled but transactions enabled', async () => {
|
||||
// Arrange
|
||||
const userId = new mongoose.Types.ObjectId();
|
||||
const initialBalance = 10000000;
|
||||
await Balance.create({ user: userId, tokenCredits: initialBalance });
|
||||
|
||||
const model = 'gpt-3.5-turbo';
|
||||
const txData: TxData = {
|
||||
user: userId,
|
||||
conversationId: 'test-conversation-id',
|
||||
model,
|
||||
context: 'test',
|
||||
endpointTokenConfig: null,
|
||||
rawAmount: -100,
|
||||
tokenType: 'prompt',
|
||||
transactions: { enabled: true },
|
||||
balance: { enabled: false },
|
||||
};
|
||||
|
||||
// Act
|
||||
const result = await createTransaction(txData);
|
||||
|
||||
// Assert: Transaction should be created but balance unchanged
|
||||
expect(result).toBeUndefined();
|
||||
const transactions = await Transaction.find({ user: userId });
|
||||
expect(transactions).toHaveLength(1);
|
||||
expect(transactions[0].rawAmount).toBe(-100);
|
||||
const balance = await Balance.findOne({ user: userId });
|
||||
expect(balance?.tokenCredits).toBe(initialBalance);
|
||||
});
|
||||
|
||||
test('createStructuredTransaction should not save when transactions.enabled is false', async () => {
|
||||
// Arrange
|
||||
const userId = new mongoose.Types.ObjectId();
|
||||
const initialBalance = 10000000;
|
||||
await Balance.create({ user: userId, tokenCredits: initialBalance });
|
||||
|
||||
const model = 'claude-3-5-sonnet';
|
||||
const txData: TxData = {
|
||||
user: userId,
|
||||
conversationId: 'test-conversation-id',
|
||||
model,
|
||||
context: 'message',
|
||||
tokenType: 'prompt',
|
||||
inputTokens: -10,
|
||||
writeTokens: -100,
|
||||
readTokens: -5,
|
||||
transactions: { enabled: false },
|
||||
};
|
||||
|
||||
// Act
|
||||
const result = await createStructuredTransaction(txData);
|
||||
|
||||
// Assert: No transaction should be created
|
||||
expect(result).toBeUndefined();
|
||||
const transactions = await Transaction.find({ user: userId });
|
||||
expect(transactions).toHaveLength(0);
|
||||
const balance = await Balance.findOne({ user: userId });
|
||||
expect(balance?.tokenCredits).toBe(initialBalance);
|
||||
});
|
||||
|
||||
test('createStructuredTransaction should save transaction but not update balance when balance is disabled but transactions enabled', async () => {
|
||||
// Arrange
|
||||
const userId = new mongoose.Types.ObjectId();
|
||||
const initialBalance = 10000000;
|
||||
await Balance.create({ user: userId, tokenCredits: initialBalance });
|
||||
|
||||
const model = 'claude-3-5-sonnet';
|
||||
const txData: TxData = {
|
||||
user: userId,
|
||||
conversationId: 'test-conversation-id',
|
||||
model,
|
||||
context: 'message',
|
||||
tokenType: 'prompt',
|
||||
inputTokens: -10,
|
||||
writeTokens: -100,
|
||||
readTokens: -5,
|
||||
transactions: { enabled: true },
|
||||
balance: { enabled: false },
|
||||
};
|
||||
|
||||
// Act
|
||||
const result = await createStructuredTransaction(txData);
|
||||
|
||||
// Assert: Transaction should be created but balance unchanged
|
||||
expect(result).toBeUndefined();
|
||||
const transactions = await Transaction.find({ user: userId });
|
||||
expect(transactions).toHaveLength(1);
|
||||
expect(transactions[0].inputTokens).toBe(-10);
|
||||
expect(transactions[0].writeTokens).toBe(-100);
|
||||
expect(transactions[0].readTokens).toBe(-5);
|
||||
const balance = await Balance.findOne({ user: userId });
|
||||
expect(balance?.tokenCredits).toBe(initialBalance);
|
||||
});
|
||||
});
|
||||
|
||||
describe('calculateTokenValue Edge Cases', () => {
|
||||
test('should derive multiplier from model when valueKey is not provided', async () => {
|
||||
const userId = new mongoose.Types.ObjectId();
|
||||
const initialBalance = 100000000;
|
||||
await Balance.create({ user: userId, tokenCredits: initialBalance });
|
||||
|
||||
const model = 'gpt-4';
|
||||
const promptTokens = 1000;
|
||||
|
||||
const result = await createTransaction({
|
||||
user: userId,
|
||||
conversationId: 'test-no-valuekey',
|
||||
model,
|
||||
tokenType: 'prompt',
|
||||
rawAmount: -promptTokens,
|
||||
context: 'test',
|
||||
balance: { enabled: true },
|
||||
});
|
||||
|
||||
const expectedRate = getMultiplier({ model, tokenType: 'prompt' });
|
||||
expect(result?.rate).toBe(expectedRate);
|
||||
|
||||
const tx = await Transaction.findOne({ user: userId });
|
||||
expect(tx?.tokenValue).toBe(-promptTokens * expectedRate);
|
||||
expect(tx?.rate).toBe(expectedRate);
|
||||
});
|
||||
|
||||
test('should derive valueKey and apply correct rate for an unknown model with tokenType', async () => {
|
||||
const userId = new mongoose.Types.ObjectId();
|
||||
const initialBalance = 100000000;
|
||||
await Balance.create({ user: userId, tokenCredits: initialBalance });
|
||||
|
||||
await createTransaction({
|
||||
user: userId,
|
||||
conversationId: 'test-unknown-model',
|
||||
model: 'some-unrecognized-model-xyz',
|
||||
tokenType: 'prompt',
|
||||
rawAmount: -500,
|
||||
context: 'test',
|
||||
balance: { enabled: true },
|
||||
});
|
||||
|
||||
const tx = await Transaction.findOne({ user: userId });
|
||||
expect(tx?.rate).toBeDefined();
|
||||
expect(tx?.rate).toBeGreaterThan(0);
|
||||
expect(tx?.tokenValue).toBe((tx?.rawAmount ?? 0) * (tx?.rate ?? 0));
|
||||
});
|
||||
|
||||
test('should correctly apply model-derived multiplier without valueKey for completion', async () => {
|
||||
const userId = new mongoose.Types.ObjectId();
|
||||
const initialBalance = 100000000;
|
||||
await Balance.create({ user: userId, tokenCredits: initialBalance });
|
||||
|
||||
const model = 'claude-opus-4-6';
|
||||
const completionTokens = 500;
|
||||
|
||||
const result = await createTransaction({
|
||||
user: userId,
|
||||
conversationId: 'test-completion-no-valuekey',
|
||||
model,
|
||||
tokenType: 'completion',
|
||||
rawAmount: -completionTokens,
|
||||
context: 'test',
|
||||
balance: { enabled: true },
|
||||
});
|
||||
|
||||
const expectedRate = getMultiplier({ model, tokenType: 'completion' });
|
||||
expect(expectedRate).toBe(tokenValues[model].completion);
|
||||
expect(result?.rate).toBe(expectedRate);
|
||||
|
||||
const updatedBalance = await Balance.findOne({ user: userId });
|
||||
expect(updatedBalance?.tokenCredits).toBeCloseTo(
|
||||
initialBalance - completionTokens * expectedRate,
|
||||
0,
|
||||
);
|
||||
});
|
||||
});
|
||||
|
||||
describe('Premium Token Pricing Integration Tests', () => {
|
||||
test('spendTokens should apply standard pricing when prompt tokens are below premium threshold', async () => {
|
||||
const userId = new mongoose.Types.ObjectId();
|
||||
const initialBalance = 100000000;
|
||||
await Balance.create({ user: userId, tokenCredits: initialBalance });
|
||||
|
||||
const model = 'claude-opus-4-6';
|
||||
const promptTokens = 100000;
|
||||
const completionTokens = 500;
|
||||
|
||||
const txData = {
|
||||
user: userId,
|
||||
conversationId: 'test-premium-below',
|
||||
model,
|
||||
context: 'test',
|
||||
endpointTokenConfig: null,
|
||||
balance: { enabled: true },
|
||||
};
|
||||
|
||||
await spendTokens(txData, { promptTokens, completionTokens });
|
||||
|
||||
const standardPromptRate = tokenValues[model].prompt;
|
||||
const standardCompletionRate = tokenValues[model].completion;
|
||||
const expectedCost =
|
||||
promptTokens * standardPromptRate + completionTokens * standardCompletionRate;
|
||||
|
||||
const updatedBalance = await Balance.findOne({ user: userId });
|
||||
expect(updatedBalance?.tokenCredits).toBeCloseTo(initialBalance - expectedCost, 0);
|
||||
});
|
||||
|
||||
test('spendTokens should apply premium pricing when prompt tokens exceed premium threshold', async () => {
|
||||
const userId = new mongoose.Types.ObjectId();
|
||||
const initialBalance = 100000000;
|
||||
await Balance.create({ user: userId, tokenCredits: initialBalance });
|
||||
|
||||
const model = 'claude-opus-4-6';
|
||||
const promptTokens = 250000;
|
||||
const completionTokens = 500;
|
||||
|
||||
const txData = {
|
||||
user: userId,
|
||||
conversationId: 'test-premium-above',
|
||||
model,
|
||||
context: 'test',
|
||||
endpointTokenConfig: null,
|
||||
balance: { enabled: true },
|
||||
};
|
||||
|
||||
await spendTokens(txData, { promptTokens, completionTokens });
|
||||
|
||||
const premiumPromptRate = premiumTokenValues[model].prompt;
|
||||
const premiumCompletionRate = premiumTokenValues[model].completion;
|
||||
const expectedCost =
|
||||
promptTokens * premiumPromptRate + completionTokens * premiumCompletionRate;
|
||||
|
||||
const updatedBalance = await Balance.findOne({ user: userId });
|
||||
expect(updatedBalance?.tokenCredits).toBeCloseTo(initialBalance - expectedCost, 0);
|
||||
});
|
||||
|
||||
test('spendTokens should apply standard pricing at exactly the premium threshold', async () => {
|
||||
const userId = new mongoose.Types.ObjectId();
|
||||
const initialBalance = 100000000;
|
||||
await Balance.create({ user: userId, tokenCredits: initialBalance });
|
||||
|
||||
const model = 'claude-opus-4-6';
|
||||
const promptTokens = premiumTokenValues[model].threshold;
|
||||
const completionTokens = 500;
|
||||
|
||||
const txData = {
|
||||
user: userId,
|
||||
conversationId: 'test-premium-exact',
|
||||
model,
|
||||
context: 'test',
|
||||
endpointTokenConfig: null,
|
||||
balance: { enabled: true },
|
||||
};
|
||||
|
||||
await spendTokens(txData, { promptTokens, completionTokens });
|
||||
|
||||
const standardPromptRate = tokenValues[model].prompt;
|
||||
const standardCompletionRate = tokenValues[model].completion;
|
||||
const expectedCost =
|
||||
promptTokens * standardPromptRate + completionTokens * standardCompletionRate;
|
||||
|
||||
const updatedBalance = await Balance.findOne({ user: userId });
|
||||
expect(updatedBalance?.tokenCredits).toBeCloseTo(initialBalance - expectedCost, 0);
|
||||
});
|
||||
|
||||
test('spendStructuredTokens should apply premium pricing when total input tokens exceed threshold', async () => {
|
||||
const userId = new mongoose.Types.ObjectId();
|
||||
const initialBalance = 100000000;
|
||||
await Balance.create({ user: userId, tokenCredits: initialBalance });
|
||||
|
||||
const model = 'claude-opus-4-6';
|
||||
const txData = {
|
||||
user: userId,
|
||||
conversationId: 'test-structured-premium',
|
||||
model,
|
||||
context: 'message',
|
||||
endpointTokenConfig: null,
|
||||
balance: { enabled: true },
|
||||
};
|
||||
|
||||
const tokenUsage = {
|
||||
promptTokens: {
|
||||
input: 200000,
|
||||
write: 10000,
|
||||
read: 5000,
|
||||
},
|
||||
completionTokens: 1000,
|
||||
};
|
||||
|
||||
const totalInput =
|
||||
tokenUsage.promptTokens.input + tokenUsage.promptTokens.write + tokenUsage.promptTokens.read;
|
||||
|
||||
await spendStructuredTokens(txData, tokenUsage);
|
||||
|
||||
const premiumPromptRate = premiumTokenValues[model].prompt;
|
||||
const premiumCompletionRate = premiumTokenValues[model].completion;
|
||||
const writeMultiplier = getCacheMultiplier({ model, cacheType: 'write' });
|
||||
const readMultiplier = getCacheMultiplier({ model, cacheType: 'read' });
|
||||
|
||||
const expectedPromptCost =
|
||||
tokenUsage.promptTokens.input * premiumPromptRate +
|
||||
tokenUsage.promptTokens.write * (writeMultiplier ?? 0) +
|
||||
tokenUsage.promptTokens.read * (readMultiplier ?? 0);
|
||||
const expectedCompletionCost = tokenUsage.completionTokens * premiumCompletionRate;
|
||||
const expectedTotalCost = expectedPromptCost + expectedCompletionCost;
|
||||
|
||||
const updatedBalance = await Balance.findOne({ user: userId });
|
||||
expect(totalInput).toBeGreaterThan(premiumTokenValues[model].threshold);
|
||||
expect(updatedBalance?.tokenCredits).toBeCloseTo(initialBalance - expectedTotalCost, 0);
|
||||
});
|
||||
|
||||
test('spendStructuredTokens should apply standard pricing when total input tokens are below threshold', async () => {
|
||||
const userId = new mongoose.Types.ObjectId();
|
||||
const initialBalance = 100000000;
|
||||
await Balance.create({ user: userId, tokenCredits: initialBalance });
|
||||
|
||||
const model = 'claude-opus-4-6';
|
||||
const txData = {
|
||||
user: userId,
|
||||
conversationId: 'test-structured-standard',
|
||||
model,
|
||||
context: 'message',
|
||||
endpointTokenConfig: null,
|
||||
balance: { enabled: true },
|
||||
};
|
||||
|
||||
const tokenUsage = {
|
||||
promptTokens: {
|
||||
input: 50000,
|
||||
write: 10000,
|
||||
read: 5000,
|
||||
},
|
||||
completionTokens: 1000,
|
||||
};
|
||||
|
||||
const totalInput =
|
||||
tokenUsage.promptTokens.input + tokenUsage.promptTokens.write + tokenUsage.promptTokens.read;
|
||||
|
||||
await spendStructuredTokens(txData, tokenUsage);
|
||||
|
||||
const standardPromptRate = tokenValues[model].prompt;
|
||||
const standardCompletionRate = tokenValues[model].completion;
|
||||
const writeMultiplier = getCacheMultiplier({ model, cacheType: 'write' });
|
||||
const readMultiplier = getCacheMultiplier({ model, cacheType: 'read' });
|
||||
|
||||
const expectedPromptCost =
|
||||
tokenUsage.promptTokens.input * standardPromptRate +
|
||||
tokenUsage.promptTokens.write * (writeMultiplier ?? 0) +
|
||||
tokenUsage.promptTokens.read * (readMultiplier ?? 0);
|
||||
const expectedCompletionCost = tokenUsage.completionTokens * standardCompletionRate;
|
||||
const expectedTotalCost = expectedPromptCost + expectedCompletionCost;
|
||||
|
||||
const updatedBalance = await Balance.findOne({ user: userId });
|
||||
expect(totalInput).toBeLessThanOrEqual(premiumTokenValues[model].threshold);
|
||||
expect(updatedBalance?.tokenCredits).toBeCloseTo(initialBalance - expectedTotalCost, 0);
|
||||
});
|
||||
|
||||
test('non-premium models should not be affected by inputTokenCount regardless of prompt size', async () => {
|
||||
const userId = new mongoose.Types.ObjectId();
|
||||
const initialBalance = 100000000;
|
||||
await Balance.create({ user: userId, tokenCredits: initialBalance });
|
||||
|
||||
const model = 'claude-opus-4-5';
|
||||
const promptTokens = 300000;
|
||||
const completionTokens = 500;
|
||||
|
||||
const txData = {
|
||||
user: userId,
|
||||
conversationId: 'test-no-premium',
|
||||
model,
|
||||
context: 'test',
|
||||
endpointTokenConfig: null,
|
||||
balance: { enabled: true },
|
||||
};
|
||||
|
||||
await spendTokens(txData, { promptTokens, completionTokens });
|
||||
|
||||
const standardPromptRate = getMultiplier({ model, tokenType: 'prompt' });
|
||||
const standardCompletionRate = getMultiplier({ model, tokenType: 'completion' });
|
||||
const expectedCost =
|
||||
promptTokens * standardPromptRate + completionTokens * standardCompletionRate;
|
||||
|
||||
const updatedBalance = await Balance.findOne({ user: userId });
|
||||
expect(updatedBalance?.tokenCredits).toBeCloseTo(initialBalance - expectedCost, 0);
|
||||
});
|
||||
});
|
||||
419
packages/data-schemas/src/methods/transaction.ts
Normal file
419
packages/data-schemas/src/methods/transaction.ts
Normal file
|
|
@ -0,0 +1,419 @@
|
|||
import logger from '~/config/winston';
|
||||
import type { FilterQuery, Model, Types } from 'mongoose';
|
||||
import type { ITransaction } from '~/schema/transaction';
|
||||
import type { IBalance, IBalanceUpdate } from '~/types';
|
||||
|
||||
const cancelRate = 1.15;
|
||||
|
||||
type MultiplierParams = {
|
||||
model?: string;
|
||||
valueKey?: string;
|
||||
tokenType?: 'prompt' | 'completion';
|
||||
inputTokenCount?: number;
|
||||
endpointTokenConfig?: Record<string, Record<string, number>>;
|
||||
};
|
||||
|
||||
type CacheMultiplierParams = {
|
||||
cacheType?: 'write' | 'read';
|
||||
model?: string;
|
||||
endpointTokenConfig?: Record<string, Record<string, number>>;
|
||||
};
|
||||
|
||||
/** Fields read/written by the internal token value calculators */
|
||||
interface InternalTxDoc {
|
||||
valueKey?: string;
|
||||
tokenType?: 'prompt' | 'completion' | 'credits';
|
||||
model?: string;
|
||||
endpointTokenConfig?: Record<string, Record<string, number>> | null;
|
||||
inputTokenCount?: number;
|
||||
rawAmount?: number;
|
||||
context?: string;
|
||||
rate?: number;
|
||||
tokenValue?: number;
|
||||
rateDetail?: Record<string, number>;
|
||||
inputTokens?: number;
|
||||
writeTokens?: number;
|
||||
readTokens?: number;
|
||||
}
|
||||
|
||||
/** Input data for creating a transaction */
|
||||
export interface TxData {
|
||||
user: string | Types.ObjectId;
|
||||
conversationId?: string;
|
||||
model?: string;
|
||||
context?: string;
|
||||
tokenType?: 'prompt' | 'completion' | 'credits';
|
||||
rawAmount?: number;
|
||||
valueKey?: string;
|
||||
endpointTokenConfig?: Record<string, Record<string, number>> | null;
|
||||
inputTokenCount?: number;
|
||||
inputTokens?: number;
|
||||
writeTokens?: number;
|
||||
readTokens?: number;
|
||||
balance?: { enabled?: boolean };
|
||||
transactions?: { enabled?: boolean };
|
||||
}
|
||||
|
||||
/** Return value from a successful transaction that also updates the balance */
|
||||
export interface TransactionResult {
|
||||
rate: number;
|
||||
user: string;
|
||||
balance: number;
|
||||
prompt?: number;
|
||||
completion?: number;
|
||||
credits?: number;
|
||||
}
|
||||
|
||||
export function createTransactionMethods(
|
||||
mongoose: typeof import('mongoose'),
|
||||
txMethods: {
|
||||
getMultiplier: (params: MultiplierParams) => number;
|
||||
getCacheMultiplier: (params: CacheMultiplierParams) => number | null;
|
||||
},
|
||||
) {
|
||||
/** Calculate and set the tokenValue for a transaction */
|
||||
function calculateTokenValue(txn: InternalTxDoc) {
|
||||
const { valueKey, tokenType, model, endpointTokenConfig, inputTokenCount } = txn;
|
||||
const multiplier = Math.abs(
|
||||
txMethods.getMultiplier({
|
||||
valueKey,
|
||||
tokenType: tokenType as 'prompt' | 'completion' | undefined,
|
||||
model,
|
||||
endpointTokenConfig: endpointTokenConfig ?? undefined,
|
||||
inputTokenCount,
|
||||
}),
|
||||
);
|
||||
txn.rate = multiplier;
|
||||
txn.tokenValue = (txn.rawAmount ?? 0) * multiplier;
|
||||
if (txn.context && txn.tokenType === 'completion' && txn.context === 'incomplete') {
|
||||
txn.tokenValue = Math.ceil((txn.tokenValue ?? 0) * cancelRate);
|
||||
txn.rate = (txn.rate ?? 0) * cancelRate;
|
||||
}
|
||||
}
|
||||
|
||||
/** Calculate token value for structured tokens */
|
||||
function calculateStructuredTokenValue(txn: InternalTxDoc) {
|
||||
if (!txn.tokenType) {
|
||||
txn.tokenValue = txn.rawAmount;
|
||||
return;
|
||||
}
|
||||
|
||||
const { model, endpointTokenConfig, inputTokenCount } = txn;
|
||||
const etConfig = endpointTokenConfig ?? undefined;
|
||||
|
||||
if (txn.tokenType === 'prompt') {
|
||||
const inputMultiplier = txMethods.getMultiplier({
|
||||
tokenType: 'prompt',
|
||||
model,
|
||||
endpointTokenConfig: etConfig,
|
||||
inputTokenCount,
|
||||
});
|
||||
const writeMultiplier =
|
||||
txMethods.getCacheMultiplier({
|
||||
cacheType: 'write',
|
||||
model,
|
||||
endpointTokenConfig: etConfig,
|
||||
}) ?? inputMultiplier;
|
||||
const readMultiplier =
|
||||
txMethods.getCacheMultiplier({ cacheType: 'read', model, endpointTokenConfig: etConfig }) ??
|
||||
inputMultiplier;
|
||||
|
||||
txn.rateDetail = {
|
||||
input: inputMultiplier,
|
||||
write: writeMultiplier,
|
||||
read: readMultiplier,
|
||||
};
|
||||
|
||||
const totalPromptTokens =
|
||||
Math.abs(txn.inputTokens ?? 0) +
|
||||
Math.abs(txn.writeTokens ?? 0) +
|
||||
Math.abs(txn.readTokens ?? 0);
|
||||
|
||||
if (totalPromptTokens > 0) {
|
||||
txn.rate =
|
||||
(Math.abs(inputMultiplier * (txn.inputTokens ?? 0)) +
|
||||
Math.abs(writeMultiplier * (txn.writeTokens ?? 0)) +
|
||||
Math.abs(readMultiplier * (txn.readTokens ?? 0))) /
|
||||
totalPromptTokens;
|
||||
} else {
|
||||
txn.rate = Math.abs(inputMultiplier);
|
||||
}
|
||||
|
||||
txn.tokenValue = -(
|
||||
Math.abs(txn.inputTokens ?? 0) * inputMultiplier +
|
||||
Math.abs(txn.writeTokens ?? 0) * writeMultiplier +
|
||||
Math.abs(txn.readTokens ?? 0) * readMultiplier
|
||||
);
|
||||
|
||||
txn.rawAmount = -totalPromptTokens;
|
||||
} else if (txn.tokenType === 'completion') {
|
||||
const multiplier = txMethods.getMultiplier({
|
||||
tokenType: txn.tokenType,
|
||||
model,
|
||||
endpointTokenConfig: etConfig,
|
||||
inputTokenCount,
|
||||
});
|
||||
txn.rate = Math.abs(multiplier);
|
||||
txn.tokenValue = -Math.abs(txn.rawAmount ?? 0) * multiplier;
|
||||
txn.rawAmount = -Math.abs(txn.rawAmount ?? 0);
|
||||
}
|
||||
|
||||
if (txn.context && txn.tokenType === 'completion' && txn.context === 'incomplete') {
|
||||
txn.tokenValue = Math.ceil((txn.tokenValue ?? 0) * cancelRate);
|
||||
txn.rate = (txn.rate ?? 0) * cancelRate;
|
||||
if (txn.rateDetail) {
|
||||
txn.rateDetail = Object.fromEntries(
|
||||
Object.entries(txn.rateDetail).map(([k, v]) => [k, v * cancelRate]),
|
||||
);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Updates a user's token balance using optimistic concurrency control.
|
||||
* Always returns an IBalance or throws after exhausting retries.
|
||||
*/
|
||||
async function updateBalance({
|
||||
user,
|
||||
incrementValue,
|
||||
setValues,
|
||||
}: {
|
||||
user: string;
|
||||
incrementValue: number;
|
||||
setValues?: IBalanceUpdate;
|
||||
}): Promise<IBalance> {
|
||||
const Balance = mongoose.models.Balance as Model<IBalance>;
|
||||
const maxRetries = 10;
|
||||
let delay = 50;
|
||||
let lastError: Error | null = null;
|
||||
|
||||
for (let attempt = 1; attempt <= maxRetries; attempt++) {
|
||||
let currentBalanceDoc;
|
||||
try {
|
||||
currentBalanceDoc = await Balance.findOne({ user }).lean();
|
||||
const currentCredits = currentBalanceDoc ? currentBalanceDoc.tokenCredits : 0;
|
||||
const potentialNewCredits = currentCredits + incrementValue;
|
||||
const newCredits = Math.max(0, potentialNewCredits);
|
||||
|
||||
const updatePayload = {
|
||||
$set: {
|
||||
tokenCredits: newCredits,
|
||||
...(setValues || {}),
|
||||
},
|
||||
};
|
||||
|
||||
let updatedBalance: IBalance | null = null;
|
||||
if (currentBalanceDoc) {
|
||||
updatedBalance = await Balance.findOneAndUpdate(
|
||||
{ user, tokenCredits: currentCredits },
|
||||
updatePayload,
|
||||
{ new: true },
|
||||
).lean();
|
||||
|
||||
if (updatedBalance) {
|
||||
return updatedBalance;
|
||||
}
|
||||
lastError = new Error(`Concurrency conflict for user ${user} on attempt ${attempt}.`);
|
||||
} else {
|
||||
try {
|
||||
updatedBalance = await Balance.findOneAndUpdate({ user }, updatePayload, {
|
||||
upsert: true,
|
||||
new: true,
|
||||
}).lean();
|
||||
|
||||
if (updatedBalance) {
|
||||
return updatedBalance;
|
||||
}
|
||||
lastError = new Error(
|
||||
`Upsert race condition suspected for user ${user} on attempt ${attempt}.`,
|
||||
);
|
||||
} catch (error: unknown) {
|
||||
if ((error as { code?: number }).code === 11000) {
|
||||
lastError = error as Error;
|
||||
} else {
|
||||
throw error;
|
||||
}
|
||||
}
|
||||
}
|
||||
} catch (error) {
|
||||
logger.error(`[updateBalance] Error during attempt ${attempt} for user ${user}:`, error);
|
||||
lastError = error as Error;
|
||||
}
|
||||
|
||||
if (attempt < maxRetries) {
|
||||
const jitter = Math.random() * delay * 0.5;
|
||||
await new Promise((resolve) => setTimeout(resolve, delay + jitter));
|
||||
delay = Math.min(delay * 2, 2000);
|
||||
}
|
||||
}
|
||||
|
||||
logger.error(
|
||||
`[updateBalance] Failed to update balance for user ${user} after ${maxRetries} attempts.`,
|
||||
);
|
||||
throw (
|
||||
lastError ||
|
||||
new Error(
|
||||
`Failed to update balance for user ${user} after maximum retries due to persistent conflicts.`,
|
||||
)
|
||||
);
|
||||
}
|
||||
|
||||
/**
|
||||
* Creates an auto-refill transaction that also updates balance.
|
||||
*/
|
||||
async function createAutoRefillTransaction(txData: TxData) {
|
||||
if (txData.rawAmount != null && isNaN(txData.rawAmount)) {
|
||||
return;
|
||||
}
|
||||
const Transaction = mongoose.models.Transaction;
|
||||
const transaction = new Transaction(txData);
|
||||
transaction.endpointTokenConfig = txData.endpointTokenConfig;
|
||||
transaction.inputTokenCount = txData.inputTokenCount;
|
||||
calculateTokenValue(transaction);
|
||||
await transaction.save();
|
||||
|
||||
const balanceResponse = await updateBalance({
|
||||
user: transaction.user as string,
|
||||
incrementValue: txData.rawAmount ?? 0,
|
||||
setValues: { lastRefill: new Date() },
|
||||
});
|
||||
const result = {
|
||||
rate: transaction.rate as number,
|
||||
user: transaction.user.toString() as string,
|
||||
balance: balanceResponse.tokenCredits,
|
||||
transaction,
|
||||
};
|
||||
logger.debug('[Balance.check] Auto-refill performed', result);
|
||||
return result;
|
||||
}
|
||||
|
||||
/**
|
||||
* Creates a transaction and updates the balance.
|
||||
*/
|
||||
async function createTransaction(_txData: TxData): Promise<TransactionResult | undefined> {
|
||||
const { balance, transactions, ...txData } = _txData;
|
||||
if (txData.rawAmount != null && isNaN(txData.rawAmount)) {
|
||||
return;
|
||||
}
|
||||
|
||||
if (transactions?.enabled === false) {
|
||||
return;
|
||||
}
|
||||
|
||||
const Transaction = mongoose.models.Transaction;
|
||||
const transaction = new Transaction(txData);
|
||||
transaction.endpointTokenConfig = txData.endpointTokenConfig;
|
||||
transaction.inputTokenCount = txData.inputTokenCount;
|
||||
calculateTokenValue(transaction);
|
||||
|
||||
await transaction.save();
|
||||
if (!balance?.enabled) {
|
||||
return;
|
||||
}
|
||||
|
||||
const incrementValue = transaction.tokenValue as number;
|
||||
const balanceResponse = await updateBalance({
|
||||
user: transaction.user as string,
|
||||
incrementValue,
|
||||
});
|
||||
|
||||
return {
|
||||
rate: transaction.rate as number,
|
||||
user: transaction.user.toString() as string,
|
||||
balance: balanceResponse.tokenCredits,
|
||||
[transaction.tokenType as string]: incrementValue,
|
||||
} as TransactionResult;
|
||||
}
|
||||
|
||||
/**
|
||||
* Creates a structured transaction and updates the balance.
|
||||
*/
|
||||
async function createStructuredTransaction(
|
||||
_txData: TxData,
|
||||
): Promise<TransactionResult | undefined> {
|
||||
const { balance, transactions, ...txData } = _txData;
|
||||
if (transactions?.enabled === false) {
|
||||
return;
|
||||
}
|
||||
|
||||
const Transaction = mongoose.models.Transaction;
|
||||
const transaction = new Transaction(txData);
|
||||
transaction.endpointTokenConfig = txData.endpointTokenConfig;
|
||||
transaction.inputTokenCount = txData.inputTokenCount;
|
||||
|
||||
calculateStructuredTokenValue(transaction);
|
||||
|
||||
await transaction.save();
|
||||
|
||||
if (!balance?.enabled) {
|
||||
return;
|
||||
}
|
||||
|
||||
const incrementValue = transaction.tokenValue as number;
|
||||
|
||||
const balanceResponse = await updateBalance({
|
||||
user: transaction.user as string,
|
||||
incrementValue,
|
||||
});
|
||||
|
||||
return {
|
||||
rate: transaction.rate as number,
|
||||
user: transaction.user.toString() as string,
|
||||
balance: balanceResponse.tokenCredits,
|
||||
[transaction.tokenType as string]: incrementValue,
|
||||
} as TransactionResult;
|
||||
}
|
||||
|
||||
/**
|
||||
* Queries and retrieves transactions based on a given filter.
|
||||
*/
|
||||
async function getTransactions(filter: FilterQuery<ITransaction>) {
|
||||
try {
|
||||
const Transaction = mongoose.models.Transaction;
|
||||
return await Transaction.find(filter).lean();
|
||||
} catch (error) {
|
||||
logger.error('Error querying transactions:', error);
|
||||
throw error;
|
||||
}
|
||||
}
|
||||
|
||||
/** Retrieves a user's balance record. */
|
||||
async function findBalanceByUser(user: string): Promise<IBalance | null> {
|
||||
const Balance = mongoose.models.Balance as Model<IBalance>;
|
||||
return Balance.findOne({ user }).lean();
|
||||
}
|
||||
|
||||
/** Upserts balance fields for a user. */
|
||||
async function upsertBalanceFields(
|
||||
user: string,
|
||||
fields: IBalanceUpdate,
|
||||
): Promise<IBalance | null> {
|
||||
const Balance = mongoose.models.Balance as Model<IBalance>;
|
||||
return Balance.findOneAndUpdate({ user }, { $set: fields }, { upsert: true, new: true }).lean();
|
||||
}
|
||||
|
||||
/** Deletes transactions matching a filter. */
|
||||
async function deleteTransactions(filter: FilterQuery<ITransaction>) {
|
||||
const Transaction = mongoose.models.Transaction;
|
||||
return Transaction.deleteMany(filter);
|
||||
}
|
||||
|
||||
/** Deletes balance records matching a filter. */
|
||||
async function deleteBalances(filter: FilterQuery<IBalance>) {
|
||||
const Balance = mongoose.models.Balance as Model<IBalance>;
|
||||
return Balance.deleteMany(filter);
|
||||
}
|
||||
|
||||
return {
|
||||
findBalanceByUser,
|
||||
upsertBalanceFields,
|
||||
getTransactions,
|
||||
deleteTransactions,
|
||||
deleteBalances,
|
||||
createTransaction,
|
||||
createAutoRefillTransaction,
|
||||
createStructuredTransaction,
|
||||
};
|
||||
}
|
||||
|
||||
export type TransactionMethods = ReturnType<typeof createTransactionMethods>;
|
||||
2085
packages/data-schemas/src/methods/tx.spec.ts
Normal file
2085
packages/data-schemas/src/methods/tx.spec.ts
Normal file
File diff suppressed because it is too large
Load diff
459
packages/data-schemas/src/methods/tx.ts
Normal file
459
packages/data-schemas/src/methods/tx.ts
Normal file
|
|
@ -0,0 +1,459 @@
|
|||
/**
|
||||
* Token Pricing Configuration
|
||||
*
|
||||
* IMPORTANT: Key Ordering for Pattern Matching
|
||||
* ============================================
|
||||
* The `findMatchingPattern` function iterates through object keys in REVERSE order
|
||||
* (last-defined keys are checked first) and uses `modelName.includes(key)` for matching.
|
||||
*
|
||||
* This means:
|
||||
* 1. BASE PATTERNS must be defined FIRST (e.g., "kimi", "moonshot")
|
||||
* 2. SPECIFIC PATTERNS must be defined AFTER their base patterns (e.g., "kimi-k2", "kimi-k2.5")
|
||||
*/
|
||||
|
||||
export interface TxDeps {
|
||||
/** From @librechat/api — matches a model name to a canonical key. */
|
||||
matchModelName: (model: string, endpoint?: string) => string | undefined;
|
||||
/** From @librechat/api — finds the first key in `values` whose key is a substring of `model`. */
|
||||
findMatchingPattern: (model: string, values: Record<string, unknown>) => string | undefined;
|
||||
}
|
||||
|
||||
export const defaultRate = 6;
|
||||
|
||||
/** AWS Bedrock pricing (source: https://aws.amazon.com/bedrock/pricing/) */
|
||||
const bedrockValues: Record<string, { prompt: number; completion: number }> = {
|
||||
llama2: { prompt: 0.75, completion: 1.0 },
|
||||
'llama-2': { prompt: 0.75, completion: 1.0 },
|
||||
'llama2-13b': { prompt: 0.75, completion: 1.0 },
|
||||
'llama2:70b': { prompt: 1.95, completion: 2.56 },
|
||||
'llama2-70b': { prompt: 1.95, completion: 2.56 },
|
||||
llama3: { prompt: 0.3, completion: 0.6 },
|
||||
'llama-3': { prompt: 0.3, completion: 0.6 },
|
||||
'llama3-8b': { prompt: 0.3, completion: 0.6 },
|
||||
'llama3:8b': { prompt: 0.3, completion: 0.6 },
|
||||
'llama3-70b': { prompt: 2.65, completion: 3.5 },
|
||||
'llama3:70b': { prompt: 2.65, completion: 3.5 },
|
||||
'llama3-1': { prompt: 0.22, completion: 0.22 },
|
||||
'llama3-1-8b': { prompt: 0.22, completion: 0.22 },
|
||||
'llama3-1-70b': { prompt: 0.72, completion: 0.72 },
|
||||
'llama3-1-405b': { prompt: 2.4, completion: 2.4 },
|
||||
'llama3-2': { prompt: 0.1, completion: 0.1 },
|
||||
'llama3-2-1b': { prompt: 0.1, completion: 0.1 },
|
||||
'llama3-2-3b': { prompt: 0.15, completion: 0.15 },
|
||||
'llama3-2-11b': { prompt: 0.16, completion: 0.16 },
|
||||
'llama3-2-90b': { prompt: 0.72, completion: 0.72 },
|
||||
'llama3-3': { prompt: 2.65, completion: 3.5 },
|
||||
'llama3-3-70b': { prompt: 2.65, completion: 3.5 },
|
||||
'llama3.1': { prompt: 0.22, completion: 0.22 },
|
||||
'llama3.1:8b': { prompt: 0.22, completion: 0.22 },
|
||||
'llama3.1:70b': { prompt: 0.72, completion: 0.72 },
|
||||
'llama3.1:405b': { prompt: 2.4, completion: 2.4 },
|
||||
'llama3.2': { prompt: 0.1, completion: 0.1 },
|
||||
'llama3.2:1b': { prompt: 0.1, completion: 0.1 },
|
||||
'llama3.2:3b': { prompt: 0.15, completion: 0.15 },
|
||||
'llama3.2:11b': { prompt: 0.16, completion: 0.16 },
|
||||
'llama3.2:90b': { prompt: 0.72, completion: 0.72 },
|
||||
'llama3.3': { prompt: 2.65, completion: 3.5 },
|
||||
'llama3.3:70b': { prompt: 2.65, completion: 3.5 },
|
||||
'llama-3.1': { prompt: 0.22, completion: 0.22 },
|
||||
'llama-3.1-8b': { prompt: 0.22, completion: 0.22 },
|
||||
'llama-3.1-70b': { prompt: 0.72, completion: 0.72 },
|
||||
'llama-3.1-405b': { prompt: 2.4, completion: 2.4 },
|
||||
'llama-3.2': { prompt: 0.1, completion: 0.1 },
|
||||
'llama-3.2-1b': { prompt: 0.1, completion: 0.1 },
|
||||
'llama-3.2-3b': { prompt: 0.15, completion: 0.15 },
|
||||
'llama-3.2-11b': { prompt: 0.16, completion: 0.16 },
|
||||
'llama-3.2-90b': { prompt: 0.72, completion: 0.72 },
|
||||
'llama-3.3': { prompt: 2.65, completion: 3.5 },
|
||||
'llama-3.3-70b': { prompt: 2.65, completion: 3.5 },
|
||||
'mistral-7b': { prompt: 0.15, completion: 0.2 },
|
||||
'mistral-small': { prompt: 0.15, completion: 0.2 },
|
||||
'mixtral-8x7b': { prompt: 0.45, completion: 0.7 },
|
||||
'mistral-large-2402': { prompt: 4.0, completion: 12.0 },
|
||||
'mistral-large-2407': { prompt: 3.0, completion: 9.0 },
|
||||
'command-text': { prompt: 1.5, completion: 2.0 },
|
||||
'command-light': { prompt: 0.3, completion: 0.6 },
|
||||
'j2-mid': { prompt: 12.5, completion: 12.5 },
|
||||
'j2-ultra': { prompt: 18.8, completion: 18.8 },
|
||||
'jamba-instruct': { prompt: 0.5, completion: 0.7 },
|
||||
'titan-text-lite': { prompt: 0.15, completion: 0.2 },
|
||||
'titan-text-express': { prompt: 0.2, completion: 0.6 },
|
||||
'titan-text-premier': { prompt: 0.5, completion: 1.5 },
|
||||
'nova-micro': { prompt: 0.035, completion: 0.14 },
|
||||
'nova-lite': { prompt: 0.06, completion: 0.24 },
|
||||
'nova-pro': { prompt: 0.8, completion: 3.2 },
|
||||
'nova-premier': { prompt: 2.5, completion: 12.5 },
|
||||
'deepseek.r1': { prompt: 1.35, completion: 5.4 },
|
||||
'moonshot.kimi': { prompt: 0.6, completion: 2.5 },
|
||||
'moonshot.kimi-k2': { prompt: 0.6, completion: 2.5 },
|
||||
'moonshot.kimi-k2.5': { prompt: 0.6, completion: 3.0 },
|
||||
'moonshot.kimi-k2-thinking': { prompt: 0.6, completion: 2.5 },
|
||||
};
|
||||
|
||||
/**
|
||||
* Mapping of model token sizes to their respective multipliers for prompt and completion.
|
||||
* The rates are 1 USD per 1M tokens.
|
||||
*/
|
||||
export const tokenValues: Record<string, { prompt: number; completion: number }> = Object.assign(
|
||||
{
|
||||
'8k': { prompt: 30, completion: 60 },
|
||||
'32k': { prompt: 60, completion: 120 },
|
||||
'4k': { prompt: 1.5, completion: 2 },
|
||||
'16k': { prompt: 3, completion: 4 },
|
||||
'claude-': { prompt: 0.8, completion: 2.4 },
|
||||
deepseek: { prompt: 0.28, completion: 0.42 },
|
||||
command: { prompt: 0.38, completion: 0.38 },
|
||||
gemma: { prompt: 0.02, completion: 0.04 },
|
||||
gemini: { prompt: 0.5, completion: 1.5 },
|
||||
'gpt-oss': { prompt: 0.05, completion: 0.2 },
|
||||
'gpt-3.5-turbo-1106': { prompt: 1, completion: 2 },
|
||||
'gpt-3.5-turbo-0125': { prompt: 0.5, completion: 1.5 },
|
||||
'gpt-4-1106': { prompt: 10, completion: 30 },
|
||||
'gpt-4.1': { prompt: 2, completion: 8 },
|
||||
'gpt-4.1-nano': { prompt: 0.1, completion: 0.4 },
|
||||
'gpt-4.1-mini': { prompt: 0.4, completion: 1.6 },
|
||||
'gpt-4.5': { prompt: 75, completion: 150 },
|
||||
'gpt-4o': { prompt: 2.5, completion: 10 },
|
||||
'gpt-4o-2024-05-13': { prompt: 5, completion: 15 },
|
||||
'gpt-4o-mini': { prompt: 0.15, completion: 0.6 },
|
||||
'gpt-5': { prompt: 1.25, completion: 10 },
|
||||
'gpt-5.1': { prompt: 1.25, completion: 10 },
|
||||
'gpt-5.2': { prompt: 1.75, completion: 14 },
|
||||
'gpt-5-nano': { prompt: 0.05, completion: 0.4 },
|
||||
'gpt-5-mini': { prompt: 0.25, completion: 2 },
|
||||
'gpt-5-pro': { prompt: 15, completion: 120 },
|
||||
o1: { prompt: 15, completion: 60 },
|
||||
'o1-mini': { prompt: 1.1, completion: 4.4 },
|
||||
'o1-preview': { prompt: 15, completion: 60 },
|
||||
o3: { prompt: 2, completion: 8 },
|
||||
'o3-mini': { prompt: 1.1, completion: 4.4 },
|
||||
'o4-mini': { prompt: 1.1, completion: 4.4 },
|
||||
'claude-instant': { prompt: 0.8, completion: 2.4 },
|
||||
'claude-2': { prompt: 8, completion: 24 },
|
||||
'claude-2.1': { prompt: 8, completion: 24 },
|
||||
'claude-3-haiku': { prompt: 0.25, completion: 1.25 },
|
||||
'claude-3-sonnet': { prompt: 3, completion: 15 },
|
||||
'claude-3-opus': { prompt: 15, completion: 75 },
|
||||
'claude-3-5-haiku': { prompt: 0.8, completion: 4 },
|
||||
'claude-3.5-haiku': { prompt: 0.8, completion: 4 },
|
||||
'claude-3-5-sonnet': { prompt: 3, completion: 15 },
|
||||
'claude-3.5-sonnet': { prompt: 3, completion: 15 },
|
||||
'claude-3-7-sonnet': { prompt: 3, completion: 15 },
|
||||
'claude-3.7-sonnet': { prompt: 3, completion: 15 },
|
||||
'claude-haiku-4-5': { prompt: 1, completion: 5 },
|
||||
'claude-opus-4': { prompt: 15, completion: 75 },
|
||||
'claude-opus-4-5': { prompt: 5, completion: 25 },
|
||||
'claude-opus-4-6': { prompt: 5, completion: 25 },
|
||||
'claude-sonnet-4': { prompt: 3, completion: 15 },
|
||||
'claude-sonnet-4-6': { prompt: 3, completion: 15 },
|
||||
'command-r': { prompt: 0.5, completion: 1.5 },
|
||||
'command-r-plus': { prompt: 3, completion: 15 },
|
||||
'command-text': { prompt: 1.5, completion: 2.0 },
|
||||
'deepseek-chat': { prompt: 0.28, completion: 0.42 },
|
||||
'deepseek-reasoner': { prompt: 0.28, completion: 0.42 },
|
||||
'deepseek-r1': { prompt: 0.4, completion: 2.0 },
|
||||
'deepseek-v3': { prompt: 0.2, completion: 0.8 },
|
||||
'gemma-2': { prompt: 0.01, completion: 0.03 },
|
||||
'gemma-3': { prompt: 0.02, completion: 0.04 },
|
||||
'gemma-3-27b': { prompt: 0.09, completion: 0.16 },
|
||||
'gemini-1.5': { prompt: 2.5, completion: 10 },
|
||||
'gemini-1.5-flash': { prompt: 0.15, completion: 0.6 },
|
||||
'gemini-1.5-flash-8b': { prompt: 0.075, completion: 0.3 },
|
||||
'gemini-2.0': { prompt: 0.1, completion: 0.4 },
|
||||
'gemini-2.0-flash': { prompt: 0.1, completion: 0.4 },
|
||||
'gemini-2.0-flash-lite': { prompt: 0.075, completion: 0.3 },
|
||||
'gemini-2.5': { prompt: 0.3, completion: 2.5 },
|
||||
'gemini-2.5-flash': { prompt: 0.3, completion: 2.5 },
|
||||
'gemini-2.5-flash-lite': { prompt: 0.1, completion: 0.4 },
|
||||
'gemini-2.5-pro': { prompt: 1.25, completion: 10 },
|
||||
'gemini-2.5-flash-image': { prompt: 0.15, completion: 30 },
|
||||
'gemini-3': { prompt: 2, completion: 12 },
|
||||
'gemini-3-pro-image': { prompt: 2, completion: 120 },
|
||||
'gemini-pro-vision': { prompt: 0.5, completion: 1.5 },
|
||||
grok: { prompt: 2.0, completion: 10.0 },
|
||||
'grok-beta': { prompt: 5.0, completion: 15.0 },
|
||||
'grok-vision-beta': { prompt: 5.0, completion: 15.0 },
|
||||
'grok-2': { prompt: 2.0, completion: 10.0 },
|
||||
'grok-2-1212': { prompt: 2.0, completion: 10.0 },
|
||||
'grok-2-latest': { prompt: 2.0, completion: 10.0 },
|
||||
'grok-2-vision': { prompt: 2.0, completion: 10.0 },
|
||||
'grok-2-vision-1212': { prompt: 2.0, completion: 10.0 },
|
||||
'grok-2-vision-latest': { prompt: 2.0, completion: 10.0 },
|
||||
'grok-3': { prompt: 3.0, completion: 15.0 },
|
||||
'grok-3-fast': { prompt: 5.0, completion: 25.0 },
|
||||
'grok-3-mini': { prompt: 0.3, completion: 0.5 },
|
||||
'grok-3-mini-fast': { prompt: 0.6, completion: 4 },
|
||||
'grok-4': { prompt: 3.0, completion: 15.0 },
|
||||
'grok-4-fast': { prompt: 0.2, completion: 0.5 },
|
||||
'grok-4-1-fast': { prompt: 0.2, completion: 0.5 },
|
||||
'grok-code-fast': { prompt: 0.2, completion: 1.5 },
|
||||
codestral: { prompt: 0.3, completion: 0.9 },
|
||||
'ministral-3b': { prompt: 0.04, completion: 0.04 },
|
||||
'ministral-8b': { prompt: 0.1, completion: 0.1 },
|
||||
'mistral-nemo': { prompt: 0.15, completion: 0.15 },
|
||||
'mistral-saba': { prompt: 0.2, completion: 0.6 },
|
||||
'pixtral-large': { prompt: 2.0, completion: 6.0 },
|
||||
'mistral-large': { prompt: 2.0, completion: 6.0 },
|
||||
'mixtral-8x22b': { prompt: 0.65, completion: 0.65 },
|
||||
kimi: { prompt: 0.6, completion: 2.5 },
|
||||
moonshot: { prompt: 2.0, completion: 5.0 },
|
||||
'kimi-latest': { prompt: 0.2, completion: 2.0 },
|
||||
'kimi-k2': { prompt: 0.6, completion: 2.5 },
|
||||
'kimi-k2.5': { prompt: 0.6, completion: 3.0 },
|
||||
'kimi-k2-turbo': { prompt: 1.15, completion: 8.0 },
|
||||
'kimi-k2-turbo-preview': { prompt: 1.15, completion: 8.0 },
|
||||
'kimi-k2-0905': { prompt: 0.6, completion: 2.5 },
|
||||
'kimi-k2-0905-preview': { prompt: 0.6, completion: 2.5 },
|
||||
'kimi-k2-0711': { prompt: 0.6, completion: 2.5 },
|
||||
'kimi-k2-0711-preview': { prompt: 0.6, completion: 2.5 },
|
||||
'kimi-k2-thinking': { prompt: 0.6, completion: 2.5 },
|
||||
'kimi-k2-thinking-turbo': { prompt: 1.15, completion: 8.0 },
|
||||
'moonshot-v1': { prompt: 2.0, completion: 5.0 },
|
||||
'moonshot-v1-auto': { prompt: 2.0, completion: 5.0 },
|
||||
'moonshot-v1-8k': { prompt: 0.2, completion: 2.0 },
|
||||
'moonshot-v1-8k-vision': { prompt: 0.2, completion: 2.0 },
|
||||
'moonshot-v1-8k-vision-preview': { prompt: 0.2, completion: 2.0 },
|
||||
'moonshot-v1-32k': { prompt: 1.0, completion: 3.0 },
|
||||
'moonshot-v1-32k-vision': { prompt: 1.0, completion: 3.0 },
|
||||
'moonshot-v1-32k-vision-preview': { prompt: 1.0, completion: 3.0 },
|
||||
'moonshot-v1-128k': { prompt: 2.0, completion: 5.0 },
|
||||
'moonshot-v1-128k-vision': { prompt: 2.0, completion: 5.0 },
|
||||
'moonshot-v1-128k-vision-preview': { prompt: 2.0, completion: 5.0 },
|
||||
'gpt-oss:20b': { prompt: 0.05, completion: 0.2 },
|
||||
'gpt-oss-20b': { prompt: 0.05, completion: 0.2 },
|
||||
'gpt-oss:120b': { prompt: 0.15, completion: 0.6 },
|
||||
'gpt-oss-120b': { prompt: 0.15, completion: 0.6 },
|
||||
glm4: { prompt: 0.1, completion: 0.1 },
|
||||
'glm-4': { prompt: 0.1, completion: 0.1 },
|
||||
'glm-4-32b': { prompt: 0.1, completion: 0.1 },
|
||||
'glm-4.5': { prompt: 0.35, completion: 1.55 },
|
||||
'glm-4.5-air': { prompt: 0.14, completion: 0.86 },
|
||||
'glm-4.5v': { prompt: 0.6, completion: 1.8 },
|
||||
'glm-4.6': { prompt: 0.5, completion: 1.75 },
|
||||
qwen: { prompt: 0.08, completion: 0.33 },
|
||||
'qwen2.5': { prompt: 0.08, completion: 0.33 },
|
||||
'qwen-turbo': { prompt: 0.05, completion: 0.2 },
|
||||
'qwen-plus': { prompt: 0.4, completion: 1.2 },
|
||||
'qwen-max': { prompt: 1.6, completion: 6.4 },
|
||||
'qwq-32b': { prompt: 0.15, completion: 0.4 },
|
||||
qwen3: { prompt: 0.035, completion: 0.138 },
|
||||
'qwen3-8b': { prompt: 0.035, completion: 0.138 },
|
||||
'qwen3-14b': { prompt: 0.05, completion: 0.22 },
|
||||
'qwen3-30b-a3b': { prompt: 0.06, completion: 0.22 },
|
||||
'qwen3-32b': { prompt: 0.05, completion: 0.2 },
|
||||
'qwen3-235b-a22b': { prompt: 0.08, completion: 0.55 },
|
||||
'qwen3-vl-8b-thinking': { prompt: 0.18, completion: 2.1 },
|
||||
'qwen3-vl-8b-instruct': { prompt: 0.18, completion: 0.69 },
|
||||
'qwen3-vl-30b-a3b': { prompt: 0.29, completion: 1.0 },
|
||||
'qwen3-vl-235b-a22b': { prompt: 0.3, completion: 1.2 },
|
||||
'qwen3-max': { prompt: 1.2, completion: 6 },
|
||||
'qwen3-coder': { prompt: 0.22, completion: 0.95 },
|
||||
'qwen3-coder-30b-a3b': { prompt: 0.06, completion: 0.25 },
|
||||
'qwen3-coder-plus': { prompt: 1, completion: 5 },
|
||||
'qwen3-coder-flash': { prompt: 0.3, completion: 1.5 },
|
||||
'qwen3-next-80b-a3b': { prompt: 0.1, completion: 0.8 },
|
||||
},
|
||||
bedrockValues,
|
||||
);
|
||||
|
||||
/**
|
||||
* Mapping of model token sizes to their respective multipliers for cached input, read and write.
|
||||
* The rates are 1 USD per 1M tokens.
|
||||
*/
|
||||
export const cacheTokenValues: Record<string, { write: number; read: number }> = {
|
||||
'claude-3.7-sonnet': { write: 3.75, read: 0.3 },
|
||||
'claude-3-7-sonnet': { write: 3.75, read: 0.3 },
|
||||
'claude-3.5-sonnet': { write: 3.75, read: 0.3 },
|
||||
'claude-3-5-sonnet': { write: 3.75, read: 0.3 },
|
||||
'claude-3.5-haiku': { write: 1, read: 0.08 },
|
||||
'claude-3-5-haiku': { write: 1, read: 0.08 },
|
||||
'claude-3-haiku': { write: 0.3, read: 0.03 },
|
||||
'claude-haiku-4-5': { write: 1.25, read: 0.1 },
|
||||
'claude-sonnet-4': { write: 3.75, read: 0.3 },
|
||||
'claude-sonnet-4-6': { write: 3.75, read: 0.3 },
|
||||
'claude-opus-4': { write: 18.75, read: 1.5 },
|
||||
'claude-opus-4-5': { write: 6.25, read: 0.5 },
|
||||
'claude-opus-4-6': { write: 6.25, read: 0.5 },
|
||||
deepseek: { write: 0.28, read: 0.028 },
|
||||
'deepseek-chat': { write: 0.28, read: 0.028 },
|
||||
'deepseek-reasoner': { write: 0.28, read: 0.028 },
|
||||
kimi: { write: 0.6, read: 0.15 },
|
||||
'kimi-k2': { write: 0.6, read: 0.15 },
|
||||
'kimi-k2.5': { write: 0.6, read: 0.1 },
|
||||
'kimi-k2-turbo': { write: 1.15, read: 0.15 },
|
||||
'kimi-k2-turbo-preview': { write: 1.15, read: 0.15 },
|
||||
'kimi-k2-0905': { write: 0.6, read: 0.15 },
|
||||
'kimi-k2-0905-preview': { write: 0.6, read: 0.15 },
|
||||
'kimi-k2-0711': { write: 0.6, read: 0.15 },
|
||||
'kimi-k2-0711-preview': { write: 0.6, read: 0.15 },
|
||||
'kimi-k2-thinking': { write: 0.6, read: 0.15 },
|
||||
'kimi-k2-thinking-turbo': { write: 1.15, read: 0.15 },
|
||||
};
|
||||
|
||||
/**
|
||||
* Premium (tiered) pricing for models whose rates change based on prompt size.
|
||||
*/
|
||||
export const premiumTokenValues: Record<
|
||||
string,
|
||||
{ threshold: number; prompt: number; completion: number }
|
||||
> = {
|
||||
'claude-opus-4-6': { threshold: 200000, prompt: 10, completion: 37.5 },
|
||||
'claude-sonnet-4-6': { threshold: 200000, prompt: 6, completion: 22.5 },
|
||||
};
|
||||
|
||||
export function createTxMethods(_mongoose: typeof import('mongoose'), txDeps: TxDeps) {
|
||||
const { matchModelName, findMatchingPattern } = txDeps;
|
||||
|
||||
/**
|
||||
* Retrieves the key associated with a given model name.
|
||||
*/
|
||||
function getValueKey(model: string, endpoint?: string): string | undefined {
|
||||
if (!model || typeof model !== 'string') {
|
||||
return undefined;
|
||||
}
|
||||
|
||||
if (!endpoint || (typeof endpoint === 'string' && !tokenValues[endpoint])) {
|
||||
const matchedKey = findMatchingPattern(model, tokenValues);
|
||||
if (matchedKey) {
|
||||
return matchedKey;
|
||||
}
|
||||
}
|
||||
|
||||
const modelName = matchModelName(model, endpoint);
|
||||
if (!modelName) {
|
||||
return undefined;
|
||||
}
|
||||
|
||||
if (modelName.includes('gpt-3.5-turbo-16k')) {
|
||||
return '16k';
|
||||
} else if (modelName.includes('gpt-3.5')) {
|
||||
return '4k';
|
||||
} else if (modelName.includes('gpt-4-vision')) {
|
||||
return 'gpt-4-1106';
|
||||
} else if (modelName.includes('gpt-4-0125')) {
|
||||
return 'gpt-4-1106';
|
||||
} else if (modelName.includes('gpt-4-turbo')) {
|
||||
return 'gpt-4-1106';
|
||||
} else if (modelName.includes('gpt-4-32k')) {
|
||||
return '32k';
|
||||
} else if (modelName.includes('gpt-4')) {
|
||||
return '8k';
|
||||
}
|
||||
|
||||
return undefined;
|
||||
}
|
||||
|
||||
/**
|
||||
* Checks if premium (tiered) pricing applies and returns the premium rate.
|
||||
*/
|
||||
function getPremiumRate(
|
||||
valueKey: string,
|
||||
tokenType: string,
|
||||
inputTokenCount?: number,
|
||||
): number | null {
|
||||
if (inputTokenCount == null) {
|
||||
return null;
|
||||
}
|
||||
const premiumEntry = premiumTokenValues[valueKey];
|
||||
if (!premiumEntry || inputTokenCount <= premiumEntry.threshold) {
|
||||
return null;
|
||||
}
|
||||
return premiumEntry[tokenType as 'prompt' | 'completion'] ?? null;
|
||||
}
|
||||
|
||||
/**
|
||||
* Retrieves the multiplier for a given value key and token type.
|
||||
*/
|
||||
function getMultiplier({
|
||||
model,
|
||||
valueKey,
|
||||
endpoint,
|
||||
tokenType,
|
||||
inputTokenCount,
|
||||
endpointTokenConfig,
|
||||
}: {
|
||||
model?: string;
|
||||
valueKey?: string;
|
||||
endpoint?: string;
|
||||
tokenType?: 'prompt' | 'completion';
|
||||
inputTokenCount?: number;
|
||||
endpointTokenConfig?: Record<string, Record<string, number>>;
|
||||
}): number {
|
||||
if (endpointTokenConfig && model) {
|
||||
return endpointTokenConfig?.[model]?.[tokenType as string] ?? defaultRate;
|
||||
}
|
||||
|
||||
if (valueKey && tokenType) {
|
||||
const premiumRate = getPremiumRate(valueKey, tokenType, inputTokenCount);
|
||||
if (premiumRate != null) {
|
||||
return premiumRate;
|
||||
}
|
||||
return tokenValues[valueKey]?.[tokenType] ?? defaultRate;
|
||||
}
|
||||
|
||||
if (!tokenType || !model) {
|
||||
return 1;
|
||||
}
|
||||
|
||||
valueKey = getValueKey(model, endpoint);
|
||||
if (!valueKey) {
|
||||
return defaultRate;
|
||||
}
|
||||
|
||||
const premiumRate = getPremiumRate(valueKey, tokenType, inputTokenCount);
|
||||
if (premiumRate != null) {
|
||||
return premiumRate;
|
||||
}
|
||||
|
||||
return tokenValues[valueKey]?.[tokenType] ?? defaultRate;
|
||||
}
|
||||
|
||||
/**
|
||||
* Retrieves the cache multiplier for a given value key and token type.
|
||||
*/
|
||||
function getCacheMultiplier({
|
||||
valueKey,
|
||||
cacheType,
|
||||
model,
|
||||
endpoint,
|
||||
endpointTokenConfig,
|
||||
}: {
|
||||
valueKey?: string;
|
||||
cacheType?: 'write' | 'read';
|
||||
model?: string;
|
||||
endpoint?: string;
|
||||
endpointTokenConfig?: Record<string, Record<string, number>>;
|
||||
}): number | null {
|
||||
if (endpointTokenConfig && model) {
|
||||
return endpointTokenConfig?.[model]?.[cacheType as string] ?? null;
|
||||
}
|
||||
|
||||
if (valueKey && cacheType) {
|
||||
return cacheTokenValues[valueKey]?.[cacheType] ?? null;
|
||||
}
|
||||
|
||||
if (!cacheType || !model) {
|
||||
return null;
|
||||
}
|
||||
|
||||
valueKey = getValueKey(model, endpoint);
|
||||
if (!valueKey) {
|
||||
return null;
|
||||
}
|
||||
|
||||
return cacheTokenValues[valueKey]?.[cacheType] ?? null;
|
||||
}
|
||||
|
||||
return {
|
||||
tokenValues,
|
||||
premiumTokenValues,
|
||||
getValueKey,
|
||||
getMultiplier,
|
||||
getPremiumRate,
|
||||
getCacheMultiplier,
|
||||
defaultRate,
|
||||
cacheTokenValues,
|
||||
};
|
||||
}
|
||||
|
||||
export type TxMethods = ReturnType<typeof createTxMethods>;
|
||||
|
|
@ -589,6 +589,61 @@ export function createUserGroupMethods(mongoose: typeof import('mongoose')) {
|
|||
return combined;
|
||||
}
|
||||
|
||||
/**
|
||||
* Removes a user from all groups they belong to.
|
||||
* @param userId - The user ID (or ObjectId) of the member to remove
|
||||
*/
|
||||
async function removeUserFromAllGroups(userId: string | Types.ObjectId): Promise<void> {
|
||||
const Group = mongoose.models.Group as Model<IGroup>;
|
||||
await Group.updateMany({ memberIds: userId }, { $pullAll: { memberIds: [userId] } });
|
||||
}
|
||||
|
||||
/**
|
||||
* Finds a single group matching the given filter.
|
||||
* @param filter - MongoDB filter query
|
||||
*/
|
||||
async function findGroupByQuery(
|
||||
filter: Record<string, unknown>,
|
||||
session?: ClientSession,
|
||||
): Promise<IGroup | null> {
|
||||
const Group = mongoose.models.Group as Model<IGroup>;
|
||||
const query = Group.findOne(filter);
|
||||
if (session) {
|
||||
query.session(session);
|
||||
}
|
||||
return query.lean();
|
||||
}
|
||||
|
||||
/**
|
||||
* Updates a group by its ID.
|
||||
* @param groupId - The group's ObjectId
|
||||
* @param data - Fields to set via $set
|
||||
*/
|
||||
async function updateGroupById(
|
||||
groupId: string | Types.ObjectId,
|
||||
data: Record<string, unknown>,
|
||||
session?: ClientSession,
|
||||
): Promise<IGroup | null> {
|
||||
const Group = mongoose.models.Group as Model<IGroup>;
|
||||
const options = { new: true, ...(session ? { session } : {}) };
|
||||
return Group.findByIdAndUpdate(groupId, { $set: data }, options).lean();
|
||||
}
|
||||
|
||||
/**
|
||||
* Bulk-updates groups matching a filter.
|
||||
* @param filter - MongoDB filter query
|
||||
* @param update - Update operations
|
||||
* @param options - Optional query options (e.g., { session })
|
||||
*/
|
||||
async function bulkUpdateGroups(
|
||||
filter: Record<string, unknown>,
|
||||
update: Record<string, unknown>,
|
||||
options?: { session?: ClientSession },
|
||||
) {
|
||||
const Group = mongoose.models.Group as Model<IGroup>;
|
||||
return Group.updateMany(filter, update, options || {});
|
||||
}
|
||||
|
||||
return {
|
||||
findGroupById,
|
||||
findGroupByExternalId,
|
||||
|
|
@ -598,6 +653,10 @@ export function createUserGroupMethods(mongoose: typeof import('mongoose')) {
|
|||
upsertGroupByExternalId,
|
||||
addUserToGroup,
|
||||
removeUserFromGroup,
|
||||
removeUserFromAllGroups,
|
||||
findGroupByQuery,
|
||||
updateGroupById,
|
||||
bulkUpdateGroups,
|
||||
getUserGroups,
|
||||
getUserPrincipals,
|
||||
syncUserEntraGroups,
|
||||
|
|
|
|||
|
|
@ -1,5 +1,5 @@
|
|||
import { Document, Types } from 'mongoose';
|
||||
import type { GraphEdge, AgentToolOptions } from 'librechat-data-provider';
|
||||
import type { GraphEdge, AgentToolOptions, AgentToolResources } from 'librechat-data-provider';
|
||||
|
||||
export interface ISupportContact {
|
||||
name?: string;
|
||||
|
|
@ -32,7 +32,7 @@ export interface IAgent extends Omit<Document, 'model'> {
|
|||
agent_ids?: string[];
|
||||
edges?: GraphEdge[];
|
||||
conversation_starters?: string[];
|
||||
tool_resources?: unknown;
|
||||
tool_resources?: AgentToolResources;
|
||||
versions?: Omit<IAgent, 'versions'>[];
|
||||
category: string;
|
||||
support_contact?: ISupportContact;
|
||||
|
|
|
|||
|
|
@ -10,3 +10,14 @@ export interface IBalance extends Document {
|
|||
lastRefill: Date;
|
||||
refillAmount: number;
|
||||
}
|
||||
|
||||
/** Plain data fields for creating or updating a balance record (no Mongoose Document methods) */
|
||||
export interface IBalanceUpdate {
|
||||
user?: string;
|
||||
tokenCredits?: number;
|
||||
autoRefillEnabled?: boolean;
|
||||
refillIntervalValue?: number;
|
||||
refillIntervalUnit?: string;
|
||||
refillAmount?: number;
|
||||
lastRefill?: Date;
|
||||
}
|
||||
|
|
|
|||
|
|
@ -11,7 +11,7 @@ export interface IMessage extends Document {
|
|||
conversationSignature?: string;
|
||||
clientId?: string;
|
||||
invocationId?: number;
|
||||
parentMessageId?: string;
|
||||
parentMessageId?: string | null;
|
||||
tokenCount?: number;
|
||||
summaryTokenCount?: number;
|
||||
sender?: string;
|
||||
|
|
@ -40,7 +40,7 @@ export interface IMessage extends Document {
|
|||
addedConvo?: boolean;
|
||||
metadata?: Record<string, unknown>;
|
||||
attachments?: unknown[];
|
||||
expiredAt?: Date;
|
||||
expiredAt?: Date | null;
|
||||
createdAt?: Date;
|
||||
updatedAt?: Date;
|
||||
}
|
||||
|
|
|
|||
|
|
@ -1 +1,3 @@
|
|||
export * from './string';
|
||||
export * from './tempChatRetention';
|
||||
export * from './transactions';
|
||||
|
|
|
|||
6
packages/data-schemas/src/utils/string.ts
Normal file
6
packages/data-schemas/src/utils/string.ts
Normal file
|
|
@ -0,0 +1,6 @@
|
|||
/**
|
||||
* Escapes special regex characters in a string.
|
||||
*/
|
||||
export function escapeRegExp(str: string): string {
|
||||
return str.replace(/[.*+?^${}()|[\]\\]/g, '\\$&');
|
||||
}
|
||||
|
|
@ -1,4 +1,4 @@
|
|||
import type { AppConfig } from '@librechat/data-schemas';
|
||||
import type { AppConfig } from '~/types';
|
||||
import {
|
||||
createTempChatExpirationDate,
|
||||
getTempChatRetentionHours,
|
||||
|
|
@ -1,5 +1,5 @@
|
|||
import { logger } from '@librechat/data-schemas';
|
||||
import type { AppConfig } from '@librechat/data-schemas';
|
||||
import logger from '~/config/winston';
|
||||
import type { AppConfig } from '~/types';
|
||||
|
||||
/**
|
||||
* Default retention period for temporary chats in hours
|
||||
10
packages/data-schemas/tsconfig.build.json
Normal file
10
packages/data-schemas/tsconfig.build.json
Normal file
|
|
@ -0,0 +1,10 @@
|
|||
{
|
||||
"extends": "./tsconfig.json",
|
||||
"compilerOptions": {
|
||||
"noEmit": false,
|
||||
"declaration": true,
|
||||
"declarationDir": "dist/types",
|
||||
"outDir": "dist"
|
||||
},
|
||||
"exclude": ["node_modules", "dist", "**/*.spec.ts"]
|
||||
}
|
||||
|
|
@ -3,9 +3,8 @@
|
|||
"target": "ES2019",
|
||||
"module": "ESNext",
|
||||
"moduleResolution": "node",
|
||||
"declaration": true,
|
||||
"declarationDir": "dist/types",
|
||||
"outDir": "dist",
|
||||
"declaration": false,
|
||||
"noEmit": true,
|
||||
"strict": true,
|
||||
"esModuleInterop": true,
|
||||
"allowSyntheticDefaultImports": true,
|
||||
|
|
@ -19,5 +18,5 @@
|
|||
}
|
||||
},
|
||||
"include": ["src/**/*"],
|
||||
"exclude": ["node_modules", "dist", "tests"]
|
||||
"exclude": ["node_modules", "dist"]
|
||||
}
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue