mirror of
https://github.com/danny-avila/LibreChat.git
synced 2026-04-07 00:15:23 +02:00
Merge remote-tracking branch 'upstream/main' into fix/mcp-parser
# Conflicts: # packages/api/src/mcp/parsers.ts
This commit is contained in:
commit
e321792e8b
834 changed files with 62517 additions and 24112 deletions
29
.env.example
29
.env.example
|
|
@ -64,6 +64,8 @@ CONSOLE_JSON=false
|
|||
|
||||
DEBUG_LOGGING=true
|
||||
DEBUG_CONSOLE=false
|
||||
# Set to true to enable agent debug logging
|
||||
AGENT_DEBUG_LOGGING=false
|
||||
|
||||
# Enable memory diagnostics (logs heap/RSS snapshots every 60s, auto-enabled with --inspect)
|
||||
# MEM_DIAG=true
|
||||
|
|
@ -540,6 +542,8 @@ OPENID_ON_BEHALF_FLOW_USERINFO_SCOPE="user.read" # example for Scope Needed for
|
|||
OPENID_USE_END_SESSION_ENDPOINT=
|
||||
# URL to redirect to after OpenID logout (defaults to ${DOMAIN_CLIENT}/login)
|
||||
OPENID_POST_LOGOUT_REDIRECT_URI=
|
||||
# Maximum logout URL length before using logout_hint instead of id_token_hint (default: 2000)
|
||||
OPENID_MAX_LOGOUT_URL_LENGTH=
|
||||
|
||||
#========================#
|
||||
# SharePoint Integration #
|
||||
|
|
@ -623,6 +627,7 @@ EMAIL_PORT=25
|
|||
EMAIL_ENCRYPTION=
|
||||
EMAIL_ENCRYPTION_HOSTNAME=
|
||||
EMAIL_ALLOW_SELFSIGNED=
|
||||
# Leave both empty for SMTP servers that do not require authentication
|
||||
EMAIL_USERNAME=
|
||||
EMAIL_PASSWORD=
|
||||
EMAIL_FROM_NAME=
|
||||
|
|
@ -677,7 +682,8 @@ AZURE_CONTAINER_NAME=files
|
|||
#========================#
|
||||
|
||||
ALLOW_SHARED_LINKS=true
|
||||
ALLOW_SHARED_LINKS_PUBLIC=true
|
||||
# Allows unauthenticated access to shared links. Defaults to false (auth required) if not set.
|
||||
ALLOW_SHARED_LINKS_PUBLIC=false
|
||||
|
||||
#==============================#
|
||||
# Static File Cache Control #
|
||||
|
|
@ -849,3 +855,24 @@ OPENWEATHER_API_KEY=
|
|||
# Skip code challenge method validation (e.g., for AWS Cognito that supports S256 but doesn't advertise it)
|
||||
# When set to true, forces S256 code challenge even if not advertised in .well-known/openid-configuration
|
||||
# MCP_SKIP_CODE_CHALLENGE_CHECK=false
|
||||
|
||||
# Circuit breaker: max connect/disconnect cycles before tripping (per server)
|
||||
# MCP_CB_MAX_CYCLES=7
|
||||
|
||||
# Circuit breaker: sliding window (ms) for counting cycles
|
||||
# MCP_CB_CYCLE_WINDOW_MS=45000
|
||||
|
||||
# Circuit breaker: cooldown (ms) after the cycle breaker trips
|
||||
# MCP_CB_CYCLE_COOLDOWN_MS=15000
|
||||
|
||||
# Circuit breaker: max consecutive failed connection rounds before backoff
|
||||
# MCP_CB_MAX_FAILED_ROUNDS=3
|
||||
|
||||
# Circuit breaker: sliding window (ms) for counting failed rounds
|
||||
# MCP_CB_FAILED_WINDOW_MS=120000
|
||||
|
||||
# Circuit breaker: base backoff (ms) after failed round threshold is reached
|
||||
# MCP_CB_BASE_BACKOFF_MS=30000
|
||||
|
||||
# Circuit breaker: max backoff cap (ms) for exponential backoff
|
||||
# MCP_CB_MAX_BACKOFF_MS=300000
|
||||
|
|
|
|||
22
.gitignore
vendored
22
.gitignore
vendored
|
|
@ -154,16 +154,16 @@ claude-flow.config.json
|
|||
.swarm/
|
||||
.hive-mind/
|
||||
.claude-flow/
|
||||
memory/
|
||||
coordination/
|
||||
memory/claude-flow-data.json
|
||||
memory/sessions/*
|
||||
!memory/sessions/README.md
|
||||
memory/agents/*
|
||||
!memory/agents/README.md
|
||||
coordination/memory_bank/*
|
||||
coordination/subtasks/*
|
||||
coordination/orchestration/*
|
||||
/memory/
|
||||
/coordination/
|
||||
/memory/claude-flow-data.json
|
||||
/memory/sessions/*
|
||||
!/memory/sessions/README.md
|
||||
/memory/agents/*
|
||||
!/memory/agents/README.md
|
||||
/coordination/memory_bank/*
|
||||
/coordination/subtasks/*
|
||||
/coordination/orchestration/*
|
||||
*.db
|
||||
*.db-journal
|
||||
*.db-wal
|
||||
|
|
@ -171,5 +171,7 @@ coordination/orchestration/*
|
|||
*.sqlite-journal
|
||||
*.sqlite-wal
|
||||
claude-flow
|
||||
.playwright-mcp/*
|
||||
# Removed Windows wrapper files per user request
|
||||
hive-mind-prompt-*.txt
|
||||
CLAUDE.md
|
||||
|
|
|
|||
|
|
@ -1,2 +1,3 @@
|
|||
#!/bin/sh
|
||||
[ -n "$CI" ] && exit 0
|
||||
npx lint-staged --config ./.husky/lint-staged.config.js
|
||||
|
|
|
|||
10
AGENTS.md
10
AGENTS.md
|
|
@ -149,7 +149,15 @@ Multi-line imports count total character length across all lines. Consolidate va
|
|||
- Run tests from their workspace directory: `cd api && npx jest <pattern>`, `cd packages/api && npx jest <pattern>`, etc.
|
||||
- Frontend tests: `__tests__` directories alongside components; use `test/layout-test-utils` for rendering.
|
||||
- Cover loading, success, and error states for UI/data flows.
|
||||
- Mock data-provider hooks and external dependencies.
|
||||
|
||||
### Philosophy
|
||||
|
||||
- **Real logic over mocks.** Exercise actual code paths with real dependencies. Mocking is a last resort.
|
||||
- **Spies over mocks.** Assert that real functions are called with expected arguments and frequency without replacing underlying logic.
|
||||
- **MongoDB**: use `mongodb-memory-server` for a real in-memory MongoDB instance. Test actual queries and schema validation, not mocked DB calls.
|
||||
- **MCP**: use real `@modelcontextprotocol/sdk` exports for servers, transports, and tool definitions. Mirror real scenarios, don't stub SDK internals.
|
||||
- Only mock what you cannot control: external HTTP APIs, rate-limited services, non-deterministic system calls.
|
||||
- Heavy mocking is a code smell, not a testing strategy.
|
||||
|
||||
---
|
||||
|
||||
|
|
|
|||
|
|
@ -1,9 +1,9 @@
|
|||
# v0.8.3-rc2
|
||||
# v0.8.4
|
||||
|
||||
# Base node image
|
||||
FROM node:20-alpine AS node
|
||||
|
||||
# Install jemalloc
|
||||
RUN apk upgrade --no-cache
|
||||
RUN apk add --no-cache jemalloc
|
||||
RUN apk add --no-cache python3 py3-pip uv
|
||||
|
||||
|
|
|
|||
|
|
@ -1,12 +1,12 @@
|
|||
# Dockerfile.multi
|
||||
# v0.8.3-rc2
|
||||
# v0.8.4
|
||||
|
||||
# Set configurable max-old-space-size with default
|
||||
ARG NODE_MAX_OLD_SPACE_SIZE=6144
|
||||
|
||||
# Base for all builds
|
||||
FROM node:20-alpine AS base-min
|
||||
# Install jemalloc
|
||||
RUN apk upgrade --no-cache
|
||||
RUN apk add --no-cache jemalloc
|
||||
# Set environment variable to use jemalloc
|
||||
ENV LD_PRELOAD=/usr/lib/libjemalloc.so.2
|
||||
|
|
|
|||
|
|
@ -7,6 +7,11 @@
|
|||
</h1>
|
||||
</p>
|
||||
|
||||
<p align="center">
|
||||
<strong>English</strong> ·
|
||||
<a href="README.zh.md">中文</a>
|
||||
</p>
|
||||
|
||||
<p align="center">
|
||||
<a href="https://discord.librechat.ai">
|
||||
<img
|
||||
|
|
|
|||
227
README.zh.md
Normal file
227
README.zh.md
Normal file
|
|
@ -0,0 +1,227 @@
|
|||
<!-- Last synced with README.md: 2026-03-20 (e442984364db02163f3cc3ecb7b2ee5efba66fb9) -->
|
||||
|
||||
<p align="center">
|
||||
<a href="https://librechat.ai">
|
||||
<img src="client/public/assets/logo.svg" height="256">
|
||||
</a>
|
||||
<h1 align="center">
|
||||
<a href="https://librechat.ai">LibreChat</a>
|
||||
</h1>
|
||||
</p>
|
||||
|
||||
<p align="center">
|
||||
<a href="README.md">English</a> ·
|
||||
<strong>中文</strong>
|
||||
</p>
|
||||
|
||||
<p align="center">
|
||||
<a href="https://discord.librechat.ai">
|
||||
<img
|
||||
src="https://img.shields.io/discord/1086345563026489514?label=&logo=discord&style=for-the-badge&logoWidth=20&logoColor=white&labelColor=000000&color=blueviolet">
|
||||
</a>
|
||||
<a href="https://www.youtube.com/@LibreChat">
|
||||
<img
|
||||
src="https://img.shields.io/badge/YOUTUBE-red.svg?style=for-the-badge&logo=youtube&logoColor=white&labelColor=000000&logoWidth=20">
|
||||
</a>
|
||||
<a href="https://docs.librechat.ai">
|
||||
<img
|
||||
src="https://img.shields.io/badge/DOCS-blue.svg?style=for-the-badge&logo=read-the-docs&logoColor=white&labelColor=000000&logoWidth=20">
|
||||
</a>
|
||||
<a aria-label="Sponsors" href="https://github.com/sponsors/danny-avila">
|
||||
<img
|
||||
src="https://img.shields.io/badge/SPONSORS-brightgreen.svg?style=for-the-badge&logo=github-sponsors&logoColor=white&labelColor=000000&logoWidth=20">
|
||||
</a>
|
||||
</p>
|
||||
|
||||
<p align="center">
|
||||
<a href="https://railway.com/deploy/b5k2mn?referralCode=HI9hWz">
|
||||
<img src="https://railway.com/button.svg" alt="Deploy on Railway" height="30">
|
||||
</a>
|
||||
<a href="https://zeabur.com/templates/0X2ZY8">
|
||||
<img src="https://zeabur.com/button.svg" alt="Deploy on Zeabur" height="30"/>
|
||||
</a>
|
||||
<a href="https://template.cloud.sealos.io/deploy?templateName=librechat">
|
||||
<img src="https://raw.githubusercontent.com/labring-actions/templates/main/Deploy-on-Sealos.svg" alt="Deploy on Sealos" height="30">
|
||||
</a>
|
||||
</p>
|
||||
|
||||
<p align="center">
|
||||
<a href="https://www.librechat.ai/docs/translation">
|
||||
<img
|
||||
src="https://img.shields.io/badge/dynamic/json.svg?style=for-the-badge&color=2096F3&label=locize&query=%24.translatedPercentage&url=https://api.locize.app/badgedata/4cb2598b-ed4d-469c-9b04-2ed531a8cb45&suffix=%+translated"
|
||||
alt="翻译进度">
|
||||
</a>
|
||||
</p>
|
||||
|
||||
|
||||
# ✨ 功能
|
||||
|
||||
- 🖥️ **UI 与体验**:受 ChatGPT 启发,并具备更强的设计与功能。
|
||||
|
||||
- 🤖 **AI 模型选择**:
|
||||
- Anthropic (Claude), AWS Bedrock, OpenAI, Azure OpenAI, Google, Vertex AI, OpenAI Responses API (包含 Azure)
|
||||
- [自定义端点 (Custom Endpoints)](https://www.librechat.ai/docs/quick_start/custom_endpoints):LibreChat 支持任何兼容 OpenAI 规范的 API,无需代理。
|
||||
- 兼容[本地与远程 AI 服务商](https://www.librechat.ai/docs/configuration/librechat_yaml/ai_endpoints):
|
||||
- Ollama, groq, Cohere, Mistral AI, Apple MLX, koboldcpp, together.ai,
|
||||
- OpenRouter, Helicone, Perplexity, ShuttleAI, Deepseek, Qwen 等。
|
||||
|
||||
- 🔧 **[代码解释器 (Code Interpreter) API](https://www.librechat.ai/docs/features/code_interpreter)**:
|
||||
- 安全的沙箱执行环境,支持 Python, Node.js (JS/TS), Go, C/C++, Java, PHP, Rust 和 Fortran。
|
||||
- 无缝文件处理:直接上传、处理并下载文件。
|
||||
- 隐私无忧:完全隔离且安全的执行环境。
|
||||
|
||||
- 🔦 **智能体与工具集成**:
|
||||
- **[LibreChat 智能体 (Agents)](https://www.librechat.ai/docs/features/agents)**:
|
||||
- 无代码定制助手:无需编程即可构建专业化的 AI 驱动助手。
|
||||
- 智能体市场:发现并部署社区构建的智能体。
|
||||
- 协作共享:与特定用户和群组共享智能体。
|
||||
- 灵活且可扩展:支持 MCP 服务器、工具、文件搜索、代码执行等。
|
||||
- 兼容自定义端点、OpenAI, Azure, Anthropic, AWS Bedrock, Google, Vertex AI, Responses API 等。
|
||||
- [支持模型上下文协议 (MCP)](https://modelcontextprotocol.io/clients#librechat) 用于工具调用。
|
||||
|
||||
- 🔍 **网页搜索**:
|
||||
- 搜索互联网并检索相关信息以增强 AI 上下文。
|
||||
- 结合搜索提供商、内容爬虫和结果重排序,确保最佳检索效果。
|
||||
- **可定制 Jina 重排序**:配置自定义 Jina API URL 用于重排序服务。
|
||||
- **[了解更多 →](https://www.librechat.ai/docs/features/web_search)**
|
||||
|
||||
- 🪄 **支持代码 Artifacts 的生成式 UI**:
|
||||
- [代码 Artifacts](https://youtu.be/GfTj7O4gmd0?si=WJbdnemZpJzBrJo3) 允许在对话中直接创建 React 组件、HTML 页面和 Mermaid 图表。
|
||||
|
||||
- 🎨 **图像生成与编辑**:
|
||||
- 使用 [GPT-Image-1](https://www.librechat.ai/docs/features/image_gen#1--openai-image-tools-recommended) 进行文生图与图生图。
|
||||
- 支持 [DALL-E (3/2)](https://www.librechat.ai/docs/features/image_gen#2--dalle-legacy), [Stable Diffusion](https://www.librechat.ai/docs/features/image_gen#3--stable-diffusion-local), [Flux](https://www.librechat.ai/docs/features/image_gen#4--flux) 或任何 [MCP 服务器](https://www.librechat.ai/docs/features/image_gen#5--model-context-protocol-mcp)。
|
||||
- 根据提示词生成惊艳的视觉效果,或通过指令精修现有图像。
|
||||
|
||||
- 💾 **预设与上下文管理**:
|
||||
- 创建、保存并分享自定义预设。
|
||||
- 在对话中随时切换 AI 端点和预设。
|
||||
- 编辑、重新提交并通过对话分支继续消息。
|
||||
- 创建并与特定用户和群组共享提示词。
|
||||
- [消息与对话分叉 (Fork)](https://www.librechat.ai/docs/features/fork) 以实现高级上下文控制。
|
||||
|
||||
- 💬 **多模态与文件交互**:
|
||||
- 使用 Claude 3, GPT-4.5, GPT-4o, o1, Llama-Vision 和 Gemini 上传并分析图像 📸。
|
||||
- 支持通过自定义端点、OpenAI, Azure, Anthropic, AWS Bedrock 和 Google 进行文件对话 🗃️。
|
||||
|
||||
- 🌎 **多语言 UI**:
|
||||
- English, 中文 (简体), 中文 (繁體), العربية, Deutsch, Español, Français, Italiano
|
||||
- Polski, Português (PT), Português (BR), Русский, 日本語, Svenska, 한국어, Tiếng Việt
|
||||
- Türkçe, Nederlands, עברית, Català, Čeština, Dansk, Eesti, فارسی
|
||||
- Suomi, Magyar, Հայերեն, Bahasa Indonesia, ქართული, Latviešu, ไทย, ئۇيغۇرچە
|
||||
|
||||
- 🧠 **推理 UI**:
|
||||
- 针对 DeepSeek-R1 等思维链/推理 AI 模型的动态推理 UI。
|
||||
|
||||
- 🎨 **可定制界面**:
|
||||
- 可定制的下拉菜单和界面,同时适配高级用户和初学者。
|
||||
|
||||
- 🌊 **[可恢复流 (Resumable Streams)](https://www.librechat.ai/docs/features/resumable_streams)**:
|
||||
- 永不丢失响应:AI 响应在连接中断后自动重连并继续。
|
||||
- 多标签页与多设备同步:在多个标签页打开同一对话,或在另一设备上继续。
|
||||
- 生产级可靠性:支持从单机部署到基于 Redis 的水平扩展。
|
||||
|
||||
- 🗣️ **语音与音频**:
|
||||
- 通过语音转文字和文字转语音实现免提对话。
|
||||
- 自动发送并播放音频。
|
||||
- 支持 OpenAI, Azure OpenAI 和 Elevenlabs。
|
||||
|
||||
- 📥 **导入与导出对话**:
|
||||
- 从 LibreChat, ChatGPT, Chatbot UI 导入对话。
|
||||
- 将对话导出为截图、Markdown、文本、JSON。
|
||||
|
||||
- 🔍 **搜索与发现**:
|
||||
- 搜索所有消息和对话。
|
||||
|
||||
- 👥 **多用户与安全访问**:
|
||||
- 支持 OAuth2, LDAP 和电子邮件登录的多用户安全认证。
|
||||
- 内置审核系统和 Token 消耗管理工具。
|
||||
|
||||
- ⚙️ **配置与部署**:
|
||||
- 支持代理、反向代理、Docker 及多种部署选项。
|
||||
- 可完全本地运行或部署在云端。
|
||||
|
||||
- 📖 **开源与社区**:
|
||||
- 完全开源且在公众监督下开发。
|
||||
- 社区驱动的开发、支持与反馈。
|
||||
|
||||
[查看我们的文档了解更多功能详情](https://docs.librechat.ai/) 📚
|
||||
|
||||
## 🪶 LibreChat:全方位的 AI 对话平台
|
||||
|
||||
LibreChat 是一个自托管的 AI 对话平台,在一个注重隐私的统一界面中整合了所有主流 AI 服务商。
|
||||
|
||||
除了对话功能外,LibreChat 还提供 AI 智能体、模型上下文协议 (MCP) 支持、Artifacts、代码解释器、自定义操作、对话搜索,以及企业级多用户认证。
|
||||
|
||||
开源、活跃开发中,专为重视 AI 基础设施自主可控的用户而构建。
|
||||
|
||||
---
|
||||
|
||||
## 🌐 资源
|
||||
|
||||
**GitHub 仓库:**
|
||||
- **RAG API:** [github.com/danny-avila/rag_api](https://github.com/danny-avila/rag_api)
|
||||
- **网站:** [github.com/LibreChat-AI/librechat.ai](https://github.com/LibreChat-AI/librechat.ai)
|
||||
|
||||
**其他:**
|
||||
- **官方网站:** [librechat.ai](https://librechat.ai)
|
||||
- **帮助文档:** [librechat.ai/docs](https://librechat.ai/docs)
|
||||
- **博客:** [librechat.ai/blog](https://librechat.ai/blog)
|
||||
|
||||
---
|
||||
|
||||
## 📝 更新日志
|
||||
|
||||
访问发布页面和更新日志以了解最新动态:
|
||||
- [发布页面 (Releases)](https://github.com/danny-avila/LibreChat/releases)
|
||||
- [更新日志 (Changelog)](https://www.librechat.ai/changelog)
|
||||
|
||||
**⚠️ 在更新前请务必查看[更新日志](https://www.librechat.ai/changelog)以了解破坏性更改。**
|
||||
|
||||
---
|
||||
|
||||
## ⭐ Star 历史
|
||||
|
||||
<p align="center">
|
||||
<a href="https://star-history.com/#danny-avila/LibreChat&Date">
|
||||
<img alt="Star History Chart" src="https://api.star-history.com/svg?repos=danny-avila/LibreChat&type=Date&theme=dark" onerror="this.src='https://api.star-history.com/svg?repos=danny-avila/LibreChat&type=Date'" />
|
||||
</a>
|
||||
</p>
|
||||
<p align="center">
|
||||
<a href="https://trendshift.io/repositories/4685" target="_blank" style="padding: 10px;">
|
||||
<img src="https://trendshift.io/api/badge/repositories/4685" alt="danny-avila%2FLibreChat | Trendshift" style="width: 250px; height: 55px;" width="250" height="55"/>
|
||||
</a>
|
||||
<a href="https://runacap.com/ross-index/q1-24/" target="_blank" rel="noopener" style="margin-left: 20px;">
|
||||
<img style="width: 260px; height: 56px" src="https://runacap.com/wp-content/uploads/2024/04/ROSS_badge_white_Q1_2024.svg" alt="ROSS Index - 2024年第一季度增长最快的开源初创公司 | Runa Capital" width="260" height="56"/>
|
||||
</a>
|
||||
</p>
|
||||
|
||||
---
|
||||
|
||||
## ✨ 贡献
|
||||
|
||||
欢迎任何形式的贡献、建议、错误报告和修复!
|
||||
|
||||
对于新功能、组件或扩展,请在发送 PR 前开启 issue 进行讨论。
|
||||
|
||||
如果您想帮助我们将 LibreChat 翻译成您的母语,我们非常欢迎!改进翻译不仅能让全球用户更轻松地使用 LibreChat,还能提升整体用户体验。请查看我们的[翻译指南](https://www.librechat.ai/docs/translation)。
|
||||
|
||||
---
|
||||
|
||||
## 💖 感谢所有贡献者
|
||||
|
||||
<a href="https://github.com/danny-avila/LibreChat/graphs/contributors">
|
||||
<img src="https://contrib.rocks/image?repo=danny-avila/LibreChat" />
|
||||
</a>
|
||||
|
||||
---
|
||||
|
||||
## 🎉 特别鸣谢
|
||||
|
||||
感谢 [Locize](https://locize.com) 提供的翻译管理工具,支持 LibreChat 的多语言功能。
|
||||
|
||||
<p align="center">
|
||||
<a href="https://locize.com" target="_blank" rel="noopener noreferrer">
|
||||
<img src="https://github.com/user-attachments/assets/d6b70894-6064-475e-bb65-92a9e23e0077" alt="Locize Logo" height="50">
|
||||
</a>
|
||||
</p>
|
||||
|
|
@ -3,6 +3,7 @@ const fetch = require('node-fetch');
|
|||
const { logger } = require('@librechat/data-schemas');
|
||||
const {
|
||||
countTokens,
|
||||
checkBalance,
|
||||
getBalanceConfig,
|
||||
buildMessageFiles,
|
||||
extractFileContext,
|
||||
|
|
@ -12,7 +13,6 @@ const {
|
|||
} = require('@librechat/api');
|
||||
const {
|
||||
Constants,
|
||||
ErrorTypes,
|
||||
FileSources,
|
||||
ContentTypes,
|
||||
excludedKeys,
|
||||
|
|
@ -23,18 +23,10 @@ const {
|
|||
supportsBalanceCheck,
|
||||
isBedrockDocumentType,
|
||||
} = require('librechat-data-provider');
|
||||
const {
|
||||
updateMessage,
|
||||
getMessages,
|
||||
saveMessage,
|
||||
saveConvo,
|
||||
getConvo,
|
||||
getFiles,
|
||||
} = require('~/models');
|
||||
const { getStrategyFunctions } = require('~/server/services/Files/strategies');
|
||||
const { checkBalance } = require('~/models/balanceMethods');
|
||||
const { truncateToolCallOutputs } = require('./prompts');
|
||||
const { logViolation } = require('~/cache');
|
||||
const TextStream = require('./TextStream');
|
||||
const db = require('~/models');
|
||||
|
||||
class BaseClient {
|
||||
constructor(apiKey, options = {}) {
|
||||
|
|
@ -339,45 +331,6 @@ class BaseClient {
|
|||
return payload;
|
||||
}
|
||||
|
||||
async handleTokenCountMap(tokenCountMap) {
|
||||
if (this.clientName === EModelEndpoint.agents) {
|
||||
return;
|
||||
}
|
||||
if (this.currentMessages.length === 0) {
|
||||
return;
|
||||
}
|
||||
|
||||
for (let i = 0; i < this.currentMessages.length; i++) {
|
||||
// Skip the last message, which is the user message.
|
||||
if (i === this.currentMessages.length - 1) {
|
||||
break;
|
||||
}
|
||||
|
||||
const message = this.currentMessages[i];
|
||||
const { messageId } = message;
|
||||
const update = {};
|
||||
|
||||
if (messageId === tokenCountMap.summaryMessage?.messageId) {
|
||||
logger.debug(`[BaseClient] Adding summary props to ${messageId}.`);
|
||||
|
||||
update.summary = tokenCountMap.summaryMessage.content;
|
||||
update.summaryTokenCount = tokenCountMap.summaryMessage.tokenCount;
|
||||
}
|
||||
|
||||
if (message.tokenCount && !update.summaryTokenCount) {
|
||||
logger.debug(`[BaseClient] Skipping ${messageId}: already had a token count.`);
|
||||
continue;
|
||||
}
|
||||
|
||||
const tokenCount = tokenCountMap[messageId];
|
||||
if (tokenCount) {
|
||||
message.tokenCount = tokenCount;
|
||||
update.tokenCount = tokenCount;
|
||||
await this.updateMessageInDatabase({ messageId, ...update });
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
concatenateMessages(messages) {
|
||||
return messages.reduce((acc, message) => {
|
||||
const nameOrRole = message.name ?? message.role;
|
||||
|
|
@ -448,154 +401,6 @@ class BaseClient {
|
|||
};
|
||||
}
|
||||
|
||||
async handleContextStrategy({
|
||||
instructions,
|
||||
orderedMessages,
|
||||
formattedMessages,
|
||||
buildTokenMap = true,
|
||||
}) {
|
||||
let _instructions;
|
||||
let tokenCount;
|
||||
|
||||
if (instructions) {
|
||||
({ tokenCount, ..._instructions } = instructions);
|
||||
}
|
||||
|
||||
_instructions && logger.debug('[BaseClient] instructions tokenCount: ' + tokenCount);
|
||||
if (tokenCount && tokenCount > this.maxContextTokens) {
|
||||
const info = `${tokenCount} / ${this.maxContextTokens}`;
|
||||
const errorMessage = `{ "type": "${ErrorTypes.INPUT_LENGTH}", "info": "${info}" }`;
|
||||
logger.warn(`Instructions token count exceeds max token count (${info}).`);
|
||||
throw new Error(errorMessage);
|
||||
}
|
||||
|
||||
if (this.clientName === EModelEndpoint.agents) {
|
||||
const { dbMessages, editedIndices } = truncateToolCallOutputs(
|
||||
orderedMessages,
|
||||
this.maxContextTokens,
|
||||
this.getTokenCountForMessage.bind(this),
|
||||
);
|
||||
|
||||
if (editedIndices.length > 0) {
|
||||
logger.debug('[BaseClient] Truncated tool call outputs:', editedIndices);
|
||||
for (const index of editedIndices) {
|
||||
formattedMessages[index].content = dbMessages[index].content;
|
||||
}
|
||||
orderedMessages = dbMessages;
|
||||
}
|
||||
}
|
||||
|
||||
let orderedWithInstructions = this.addInstructions(orderedMessages, instructions);
|
||||
|
||||
let { context, remainingContextTokens, messagesToRefine } =
|
||||
await this.getMessagesWithinTokenLimit({
|
||||
messages: orderedWithInstructions,
|
||||
instructions,
|
||||
});
|
||||
|
||||
logger.debug('[BaseClient] Context Count (1/2)', {
|
||||
remainingContextTokens,
|
||||
maxContextTokens: this.maxContextTokens,
|
||||
});
|
||||
|
||||
let summaryMessage;
|
||||
let summaryTokenCount;
|
||||
let { shouldSummarize } = this;
|
||||
|
||||
// Calculate the difference in length to determine how many messages were discarded if any
|
||||
let payload;
|
||||
let { length } = formattedMessages;
|
||||
length += instructions != null ? 1 : 0;
|
||||
const diff = length - context.length;
|
||||
const firstMessage = orderedWithInstructions[0];
|
||||
const usePrevSummary =
|
||||
shouldSummarize &&
|
||||
diff === 1 &&
|
||||
firstMessage?.summary &&
|
||||
this.previous_summary.messageId === firstMessage.messageId;
|
||||
|
||||
if (diff > 0) {
|
||||
payload = formattedMessages.slice(diff);
|
||||
logger.debug(
|
||||
`[BaseClient] Difference between original payload (${length}) and context (${context.length}): ${diff}`,
|
||||
);
|
||||
}
|
||||
|
||||
payload = this.addInstructions(payload ?? formattedMessages, _instructions);
|
||||
|
||||
const latestMessage = orderedWithInstructions[orderedWithInstructions.length - 1];
|
||||
if (payload.length === 0 && !shouldSummarize && latestMessage) {
|
||||
const info = `${latestMessage.tokenCount} / ${this.maxContextTokens}`;
|
||||
const errorMessage = `{ "type": "${ErrorTypes.INPUT_LENGTH}", "info": "${info}" }`;
|
||||
logger.warn(`Prompt token count exceeds max token count (${info}).`);
|
||||
throw new Error(errorMessage);
|
||||
} else if (
|
||||
_instructions &&
|
||||
payload.length === 1 &&
|
||||
payload[0].content === _instructions.content
|
||||
) {
|
||||
const info = `${tokenCount + 3} / ${this.maxContextTokens}`;
|
||||
const errorMessage = `{ "type": "${ErrorTypes.INPUT_LENGTH}", "info": "${info}" }`;
|
||||
logger.warn(
|
||||
`Including instructions, the prompt token count exceeds remaining max token count (${info}).`,
|
||||
);
|
||||
throw new Error(errorMessage);
|
||||
}
|
||||
|
||||
if (usePrevSummary) {
|
||||
summaryMessage = { role: 'system', content: firstMessage.summary };
|
||||
summaryTokenCount = firstMessage.summaryTokenCount;
|
||||
payload.unshift(summaryMessage);
|
||||
remainingContextTokens -= summaryTokenCount;
|
||||
} else if (shouldSummarize && messagesToRefine.length > 0) {
|
||||
({ summaryMessage, summaryTokenCount } = await this.summarizeMessages({
|
||||
messagesToRefine,
|
||||
remainingContextTokens,
|
||||
}));
|
||||
summaryMessage && payload.unshift(summaryMessage);
|
||||
remainingContextTokens -= summaryTokenCount;
|
||||
}
|
||||
|
||||
// Make sure to only continue summarization logic if the summary message was generated
|
||||
shouldSummarize = summaryMessage != null && shouldSummarize === true;
|
||||
|
||||
logger.debug('[BaseClient] Context Count (2/2)', {
|
||||
remainingContextTokens,
|
||||
maxContextTokens: this.maxContextTokens,
|
||||
});
|
||||
|
||||
/** @type {Record<string, number> | undefined} */
|
||||
let tokenCountMap;
|
||||
if (buildTokenMap) {
|
||||
const currentPayload = shouldSummarize ? orderedWithInstructions : context;
|
||||
tokenCountMap = currentPayload.reduce((map, message, index) => {
|
||||
const { messageId } = message;
|
||||
if (!messageId) {
|
||||
return map;
|
||||
}
|
||||
|
||||
if (shouldSummarize && index === messagesToRefine.length - 1 && !usePrevSummary) {
|
||||
map.summaryMessage = { ...summaryMessage, messageId, tokenCount: summaryTokenCount };
|
||||
}
|
||||
|
||||
map[messageId] = currentPayload[index].tokenCount;
|
||||
return map;
|
||||
}, {});
|
||||
}
|
||||
|
||||
const promptTokens = this.maxContextTokens - remainingContextTokens;
|
||||
|
||||
logger.debug('[BaseClient] tokenCountMap:', tokenCountMap);
|
||||
logger.debug('[BaseClient]', {
|
||||
promptTokens,
|
||||
remainingContextTokens,
|
||||
payloadSize: payload.length,
|
||||
maxContextTokens: this.maxContextTokens,
|
||||
});
|
||||
|
||||
return { payload, tokenCountMap, promptTokens, messages: orderedWithInstructions };
|
||||
}
|
||||
|
||||
async sendMessage(message, opts = {}) {
|
||||
const appConfig = this.options.req?.config;
|
||||
/** @type {Promise<TMessage>} */
|
||||
|
|
@ -664,17 +469,13 @@ class BaseClient {
|
|||
opts,
|
||||
);
|
||||
|
||||
if (tokenCountMap) {
|
||||
if (tokenCountMap[userMessage.messageId]) {
|
||||
userMessage.tokenCount = tokenCountMap[userMessage.messageId];
|
||||
logger.debug('[BaseClient] userMessage', {
|
||||
messageId: userMessage.messageId,
|
||||
tokenCount: userMessage.tokenCount,
|
||||
conversationId: userMessage.conversationId,
|
||||
});
|
||||
}
|
||||
|
||||
this.handleTokenCountMap(tokenCountMap);
|
||||
if (tokenCountMap && tokenCountMap[userMessage.messageId]) {
|
||||
userMessage.tokenCount = tokenCountMap[userMessage.messageId];
|
||||
logger.debug('[BaseClient] userMessage', {
|
||||
messageId: userMessage.messageId,
|
||||
tokenCount: userMessage.tokenCount,
|
||||
conversationId: userMessage.conversationId,
|
||||
});
|
||||
}
|
||||
|
||||
if (!isEdited && !this.skipSaveUserMessage) {
|
||||
|
|
@ -700,18 +501,26 @@ class BaseClient {
|
|||
balanceConfig?.enabled &&
|
||||
supportsBalanceCheck[this.options.endpointType ?? this.options.endpoint]
|
||||
) {
|
||||
await checkBalance({
|
||||
req: this.options.req,
|
||||
res: this.options.res,
|
||||
txData: {
|
||||
user: this.user,
|
||||
tokenType: 'prompt',
|
||||
amount: promptTokens,
|
||||
endpoint: this.options.endpoint,
|
||||
model: this.modelOptions?.model ?? this.model,
|
||||
endpointTokenConfig: this.options.endpointTokenConfig,
|
||||
await checkBalance(
|
||||
{
|
||||
req: this.options.req,
|
||||
res: this.options.res,
|
||||
txData: {
|
||||
user: this.user,
|
||||
tokenType: 'prompt',
|
||||
amount: promptTokens,
|
||||
endpoint: this.options.endpoint,
|
||||
model: this.modelOptions?.model ?? this.model,
|
||||
endpointTokenConfig: this.options.endpointTokenConfig,
|
||||
},
|
||||
},
|
||||
});
|
||||
{
|
||||
logViolation,
|
||||
getMultiplier: db.getMultiplier,
|
||||
findBalanceByUser: db.findBalanceByUser,
|
||||
createAutoRefillTransaction: db.createAutoRefillTransaction,
|
||||
},
|
||||
);
|
||||
}
|
||||
|
||||
const { completion, metadata } = await this.sendCompletion(payload, opts);
|
||||
|
|
@ -764,12 +573,7 @@ class BaseClient {
|
|||
responseMessage.text = completion.join('');
|
||||
}
|
||||
|
||||
if (
|
||||
tokenCountMap &&
|
||||
this.recordTokenUsage &&
|
||||
this.getTokenCountForResponse &&
|
||||
this.getTokenCount
|
||||
) {
|
||||
if (tokenCountMap && this.recordTokenUsage && this.getTokenCountForResponse) {
|
||||
let completionTokens;
|
||||
|
||||
/**
|
||||
|
|
@ -782,13 +586,6 @@ class BaseClient {
|
|||
if (usage != null && Number(usage[this.outputTokensKey]) > 0) {
|
||||
responseMessage.tokenCount = usage[this.outputTokensKey];
|
||||
completionTokens = responseMessage.tokenCount;
|
||||
await this.updateUserMessageTokenCount({
|
||||
usage,
|
||||
tokenCountMap,
|
||||
userMessage,
|
||||
userMessagePromise,
|
||||
opts,
|
||||
});
|
||||
} else {
|
||||
responseMessage.tokenCount = this.getTokenCountForResponse(responseMessage);
|
||||
completionTokens = responseMessage.tokenCount;
|
||||
|
|
@ -815,6 +612,27 @@ class BaseClient {
|
|||
await userMessagePromise;
|
||||
}
|
||||
|
||||
if (
|
||||
this.contextMeta?.calibrationRatio > 0 &&
|
||||
this.contextMeta.calibrationRatio !== 1 &&
|
||||
userMessage.tokenCount > 0
|
||||
) {
|
||||
const calibrated = Math.round(userMessage.tokenCount * this.contextMeta.calibrationRatio);
|
||||
if (calibrated !== userMessage.tokenCount) {
|
||||
logger.debug('[BaseClient] Calibrated user message tokenCount', {
|
||||
messageId: userMessage.messageId,
|
||||
raw: userMessage.tokenCount,
|
||||
calibrated,
|
||||
ratio: this.contextMeta.calibrationRatio,
|
||||
});
|
||||
userMessage.tokenCount = calibrated;
|
||||
await this.updateMessageInDatabase({
|
||||
messageId: userMessage.messageId,
|
||||
tokenCount: calibrated,
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
if (this.artifactPromises) {
|
||||
responseMessage.attachments = (await Promise.all(this.artifactPromises)).filter((a) => a);
|
||||
}
|
||||
|
|
@ -827,6 +645,10 @@ class BaseClient {
|
|||
}
|
||||
}
|
||||
|
||||
if (this.contextMeta) {
|
||||
responseMessage.contextMeta = this.contextMeta;
|
||||
}
|
||||
|
||||
responseMessage.databasePromise = this.saveMessageToDatabase(
|
||||
responseMessage,
|
||||
saveOptions,
|
||||
|
|
@ -837,79 +659,10 @@ class BaseClient {
|
|||
return responseMessage;
|
||||
}
|
||||
|
||||
/**
|
||||
* Stream usage should only be used for user message token count re-calculation if:
|
||||
* - The stream usage is available, with input tokens greater than 0,
|
||||
* - the client provides a function to calculate the current token count,
|
||||
* - files are being resent with every message (default behavior; or if `false`, with no attachments),
|
||||
* - the `promptPrefix` (custom instructions) is not set.
|
||||
*
|
||||
* In these cases, the legacy token estimations would be more accurate.
|
||||
*
|
||||
* TODO: included system messages in the `orderedMessages` accounting, potentially as a
|
||||
* separate message in the UI. ChatGPT does this through "hidden" system messages.
|
||||
* @param {object} params
|
||||
* @param {StreamUsage} params.usage
|
||||
* @param {Record<string, number>} params.tokenCountMap
|
||||
* @param {TMessage} params.userMessage
|
||||
* @param {Promise<TMessage>} params.userMessagePromise
|
||||
* @param {object} params.opts
|
||||
*/
|
||||
async updateUserMessageTokenCount({
|
||||
usage,
|
||||
tokenCountMap,
|
||||
userMessage,
|
||||
userMessagePromise,
|
||||
opts,
|
||||
}) {
|
||||
/** @type {boolean} */
|
||||
const shouldUpdateCount =
|
||||
this.calculateCurrentTokenCount != null &&
|
||||
Number(usage[this.inputTokensKey]) > 0 &&
|
||||
(this.options.resendFiles ||
|
||||
(!this.options.resendFiles && !this.options.attachments?.length)) &&
|
||||
!this.options.promptPrefix;
|
||||
|
||||
if (!shouldUpdateCount) {
|
||||
return;
|
||||
}
|
||||
|
||||
const userMessageTokenCount = this.calculateCurrentTokenCount({
|
||||
currentMessageId: userMessage.messageId,
|
||||
tokenCountMap,
|
||||
usage,
|
||||
});
|
||||
|
||||
if (userMessageTokenCount === userMessage.tokenCount) {
|
||||
return;
|
||||
}
|
||||
|
||||
userMessage.tokenCount = userMessageTokenCount;
|
||||
/*
|
||||
Note: `AgentController` saves the user message if not saved here
|
||||
(noted by `savedMessageIds`), so we update the count of its `userMessage` reference
|
||||
*/
|
||||
if (typeof opts?.getReqData === 'function') {
|
||||
opts.getReqData({
|
||||
userMessage,
|
||||
});
|
||||
}
|
||||
/*
|
||||
Note: we update the user message to be sure it gets the calculated token count;
|
||||
though `AgentController` saves the user message if not saved here
|
||||
(noted by `savedMessageIds`), EditController does not
|
||||
*/
|
||||
await userMessagePromise;
|
||||
await this.updateMessageInDatabase({
|
||||
messageId: userMessage.messageId,
|
||||
tokenCount: userMessageTokenCount,
|
||||
});
|
||||
}
|
||||
|
||||
async loadHistory(conversationId, parentMessageId = null) {
|
||||
logger.debug('[BaseClient] Loading history:', { conversationId, parentMessageId });
|
||||
|
||||
const messages = (await getMessages({ conversationId })) ?? [];
|
||||
const messages = (await db.getMessages({ conversationId })) ?? [];
|
||||
|
||||
if (messages.length === 0) {
|
||||
return [];
|
||||
|
|
@ -932,10 +685,24 @@ class BaseClient {
|
|||
return _messages;
|
||||
}
|
||||
|
||||
// Find the latest message with a 'summary' property
|
||||
for (let i = _messages.length - 1; i >= 0; i--) {
|
||||
if (_messages[i]?.summary) {
|
||||
this.previous_summary = _messages[i];
|
||||
const msg = _messages[i];
|
||||
if (!msg) {
|
||||
continue;
|
||||
}
|
||||
|
||||
const summaryBlock = BaseClient.findSummaryContentBlock(msg);
|
||||
if (summaryBlock) {
|
||||
this.previous_summary = {
|
||||
...msg,
|
||||
summary: BaseClient.getSummaryText(summaryBlock),
|
||||
summaryTokenCount: summaryBlock.tokenCount,
|
||||
};
|
||||
break;
|
||||
}
|
||||
|
||||
if (msg.summary) {
|
||||
this.previous_summary = msg;
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
|
@ -965,8 +732,13 @@ class BaseClient {
|
|||
}
|
||||
|
||||
const hasAddedConvo = this.options?.req?.body?.addedConvo != null;
|
||||
const savedMessage = await saveMessage(
|
||||
this.options?.req,
|
||||
const reqCtx = {
|
||||
userId: this.options?.req?.user?.id,
|
||||
isTemporary: this.options?.req?.body?.isTemporary,
|
||||
interfaceConfig: this.options?.req?.config?.interfaceConfig,
|
||||
};
|
||||
const savedMessage = await db.saveMessage(
|
||||
reqCtx,
|
||||
{
|
||||
...message,
|
||||
endpoint: this.options.endpoint,
|
||||
|
|
@ -991,7 +763,7 @@ class BaseClient {
|
|||
const existingConvo =
|
||||
this.fetchedConvo === true
|
||||
? null
|
||||
: await getConvo(this.options?.req?.user?.id, message.conversationId);
|
||||
: await db.getConvo(this.options?.req?.user?.id, message.conversationId);
|
||||
|
||||
const unsetFields = {};
|
||||
const exceptions = new Set(['spec', 'iconURL']);
|
||||
|
|
@ -1018,7 +790,7 @@ class BaseClient {
|
|||
}
|
||||
}
|
||||
|
||||
const conversation = await saveConvo(this.options?.req, fieldsToKeep, {
|
||||
const conversation = await db.saveConvo(reqCtx, fieldsToKeep, {
|
||||
context: 'api/app/clients/BaseClient.js - saveMessageToDatabase #saveConvo',
|
||||
unsetFields,
|
||||
});
|
||||
|
|
@ -1031,7 +803,35 @@ class BaseClient {
|
|||
* @param {Partial<TMessage>} message
|
||||
*/
|
||||
async updateMessageInDatabase(message) {
|
||||
await updateMessage(this.options.req, message);
|
||||
await db.updateMessage(this.options?.req?.user?.id, message);
|
||||
}
|
||||
|
||||
/** Extracts text from a summary block (handles both legacy `text` field and new `content` array format). */
|
||||
static getSummaryText(summaryBlock) {
|
||||
if (Array.isArray(summaryBlock.content)) {
|
||||
return summaryBlock.content.map((b) => b.text ?? '').join('');
|
||||
}
|
||||
if (typeof summaryBlock.content === 'string') {
|
||||
return summaryBlock.content;
|
||||
}
|
||||
return summaryBlock.text ?? '';
|
||||
}
|
||||
|
||||
/** Finds the last summary content block in a message's content array (last-summary-wins). */
|
||||
static findSummaryContentBlock(message) {
|
||||
if (!Array.isArray(message?.content)) {
|
||||
return null;
|
||||
}
|
||||
let lastSummary = null;
|
||||
for (const part of message.content) {
|
||||
if (
|
||||
part?.type === ContentTypes.SUMMARY &&
|
||||
BaseClient.getSummaryText(part).trim().length > 0
|
||||
) {
|
||||
lastSummary = part;
|
||||
}
|
||||
}
|
||||
return lastSummary;
|
||||
}
|
||||
|
||||
/**
|
||||
|
|
@ -1088,20 +888,35 @@ class BaseClient {
|
|||
break;
|
||||
}
|
||||
|
||||
if (summary && message.summary) {
|
||||
message.role = 'system';
|
||||
message.text = message.summary;
|
||||
let resolved = message;
|
||||
let hasSummary = false;
|
||||
if (summary) {
|
||||
const summaryBlock = BaseClient.findSummaryContentBlock(message);
|
||||
if (summaryBlock) {
|
||||
const summaryText = BaseClient.getSummaryText(summaryBlock);
|
||||
resolved = {
|
||||
...message,
|
||||
role: 'system',
|
||||
content: [{ type: ContentTypes.TEXT, text: summaryText }],
|
||||
tokenCount: summaryBlock.tokenCount,
|
||||
};
|
||||
hasSummary = true;
|
||||
} else if (message.summary) {
|
||||
resolved = {
|
||||
...message,
|
||||
role: 'system',
|
||||
content: [{ type: ContentTypes.TEXT, text: message.summary }],
|
||||
tokenCount: message.summaryTokenCount ?? message.tokenCount,
|
||||
};
|
||||
hasSummary = true;
|
||||
}
|
||||
}
|
||||
|
||||
if (summary && message.summaryTokenCount) {
|
||||
message.tokenCount = message.summaryTokenCount;
|
||||
}
|
||||
|
||||
const shouldMap = mapMethod != null && (mapCondition != null ? mapCondition(message) : true);
|
||||
const processedMessage = shouldMap ? mapMethod(message) : message;
|
||||
const shouldMap = mapMethod != null && (mapCondition != null ? mapCondition(resolved) : true);
|
||||
const processedMessage = shouldMap ? mapMethod(resolved) : resolved;
|
||||
orderedMessages.push(processedMessage);
|
||||
|
||||
if (summary && message.summary) {
|
||||
if (hasSummary) {
|
||||
break;
|
||||
}
|
||||
|
||||
|
|
@ -1431,7 +1246,7 @@ class BaseClient {
|
|||
return message;
|
||||
}
|
||||
|
||||
const files = await getFiles(
|
||||
const files = await db.getFiles(
|
||||
{
|
||||
file_id: { $in: fileIds },
|
||||
},
|
||||
|
|
|
|||
|
|
@ -37,79 +37,4 @@ function smartTruncateText(text, maxLength = MAX_CHAR) {
|
|||
return text;
|
||||
}
|
||||
|
||||
/**
|
||||
* @param {TMessage[]} _messages
|
||||
* @param {number} maxContextTokens
|
||||
* @param {function({role: string, content: TMessageContent[]}): number} getTokenCountForMessage
|
||||
*
|
||||
* @returns {{
|
||||
* dbMessages: TMessage[],
|
||||
* editedIndices: number[]
|
||||
* }}
|
||||
*/
|
||||
function truncateToolCallOutputs(_messages, maxContextTokens, getTokenCountForMessage) {
|
||||
const THRESHOLD_PERCENTAGE = 0.5;
|
||||
const targetTokenLimit = maxContextTokens * THRESHOLD_PERCENTAGE;
|
||||
|
||||
let currentTokenCount = 3;
|
||||
const messages = [..._messages];
|
||||
const processedMessages = [];
|
||||
let currentIndex = messages.length;
|
||||
const editedIndices = new Set();
|
||||
while (messages.length > 0) {
|
||||
currentIndex--;
|
||||
const message = messages.pop();
|
||||
currentTokenCount += message.tokenCount;
|
||||
if (currentTokenCount < targetTokenLimit) {
|
||||
processedMessages.push(message);
|
||||
continue;
|
||||
}
|
||||
|
||||
if (!message.content || !Array.isArray(message.content)) {
|
||||
processedMessages.push(message);
|
||||
continue;
|
||||
}
|
||||
|
||||
const toolCallIndices = message.content
|
||||
.map((item, index) => (item.type === 'tool_call' ? index : -1))
|
||||
.filter((index) => index !== -1)
|
||||
.reverse();
|
||||
|
||||
if (toolCallIndices.length === 0) {
|
||||
processedMessages.push(message);
|
||||
continue;
|
||||
}
|
||||
|
||||
const newContent = [...message.content];
|
||||
|
||||
// Truncate all tool outputs since we're over threshold
|
||||
for (const index of toolCallIndices) {
|
||||
const toolCall = newContent[index].tool_call;
|
||||
if (!toolCall || !toolCall.output) {
|
||||
continue;
|
||||
}
|
||||
|
||||
editedIndices.add(currentIndex);
|
||||
|
||||
newContent[index] = {
|
||||
...newContent[index],
|
||||
tool_call: {
|
||||
...toolCall,
|
||||
output: '[OUTPUT_OMITTED_FOR_BREVITY]',
|
||||
},
|
||||
};
|
||||
}
|
||||
|
||||
const truncatedMessage = {
|
||||
...message,
|
||||
content: newContent,
|
||||
tokenCount: getTokenCountForMessage({ role: 'assistant', content: newContent }),
|
||||
};
|
||||
|
||||
processedMessages.push(truncatedMessage);
|
||||
}
|
||||
|
||||
return { dbMessages: processedMessages.reverse(), editedIndices: Array.from(editedIndices) };
|
||||
}
|
||||
|
||||
module.exports = { truncateText, smartTruncateText, truncateToolCallOutputs };
|
||||
module.exports = { truncateText, smartTruncateText };
|
||||
|
|
|
|||
|
|
@ -355,7 +355,8 @@ describe('BaseClient', () => {
|
|||
id: '3',
|
||||
parentMessageId: '2',
|
||||
role: 'system',
|
||||
text: 'Summary for Message 3',
|
||||
text: 'Message 3',
|
||||
content: [{ type: 'text', text: 'Summary for Message 3' }],
|
||||
summary: 'Summary for Message 3',
|
||||
},
|
||||
{ id: '4', parentMessageId: '3', text: 'Message 4' },
|
||||
|
|
@ -380,7 +381,8 @@ describe('BaseClient', () => {
|
|||
id: '4',
|
||||
parentMessageId: '3',
|
||||
role: 'system',
|
||||
text: 'Summary for Message 4',
|
||||
text: 'Message 4',
|
||||
content: [{ type: 'text', text: 'Summary for Message 4' }],
|
||||
summary: 'Summary for Message 4',
|
||||
},
|
||||
{ id: '5', parentMessageId: '4', text: 'Message 5' },
|
||||
|
|
@ -405,12 +407,123 @@ describe('BaseClient', () => {
|
|||
id: '4',
|
||||
parentMessageId: '3',
|
||||
role: 'system',
|
||||
text: 'Summary for Message 4',
|
||||
text: 'Message 4',
|
||||
content: [{ type: 'text', text: 'Summary for Message 4' }],
|
||||
summary: 'Summary for Message 4',
|
||||
},
|
||||
{ id: '5', parentMessageId: '4', text: 'Message 5' },
|
||||
]);
|
||||
});
|
||||
|
||||
it('should detect summary content block and use it over legacy fields (summary mode)', () => {
|
||||
const messagesWithContentBlock = [
|
||||
{ id: '3', parentMessageId: '2', text: 'Message 3' },
|
||||
{
|
||||
id: '2',
|
||||
parentMessageId: '1',
|
||||
text: 'Message 2',
|
||||
content: [
|
||||
{ type: 'text', text: 'Original text' },
|
||||
{ type: 'summary', text: 'Content block summary', tokenCount: 42 },
|
||||
],
|
||||
},
|
||||
{ id: '1', parentMessageId: null, text: 'Message 1' },
|
||||
];
|
||||
const result = TestClient.constructor.getMessagesForConversation({
|
||||
messages: messagesWithContentBlock,
|
||||
parentMessageId: '3',
|
||||
summary: true,
|
||||
});
|
||||
expect(result).toHaveLength(2);
|
||||
expect(result[0].role).toBe('system');
|
||||
expect(result[0].content).toEqual([{ type: 'text', text: 'Content block summary' }]);
|
||||
expect(result[0].tokenCount).toBe(42);
|
||||
});
|
||||
|
||||
it('should prefer content block summary over legacy summary field', () => {
|
||||
const messagesWithBoth = [
|
||||
{ id: '2', parentMessageId: '1', text: 'Message 2' },
|
||||
{
|
||||
id: '1',
|
||||
parentMessageId: null,
|
||||
text: 'Message 1',
|
||||
summary: 'Legacy summary',
|
||||
summaryTokenCount: 10,
|
||||
content: [{ type: 'summary', text: 'Content block summary', tokenCount: 20 }],
|
||||
},
|
||||
];
|
||||
const result = TestClient.constructor.getMessagesForConversation({
|
||||
messages: messagesWithBoth,
|
||||
parentMessageId: '2',
|
||||
summary: true,
|
||||
});
|
||||
expect(result).toHaveLength(2);
|
||||
expect(result[0].content).toEqual([{ type: 'text', text: 'Content block summary' }]);
|
||||
expect(result[0].tokenCount).toBe(20);
|
||||
});
|
||||
|
||||
it('should fallback to legacy summary when no content block exists', () => {
|
||||
const messagesWithLegacy = [
|
||||
{ id: '2', parentMessageId: '1', text: 'Message 2' },
|
||||
{
|
||||
id: '1',
|
||||
parentMessageId: null,
|
||||
text: 'Message 1',
|
||||
summary: 'Legacy summary only',
|
||||
summaryTokenCount: 15,
|
||||
},
|
||||
];
|
||||
const result = TestClient.constructor.getMessagesForConversation({
|
||||
messages: messagesWithLegacy,
|
||||
parentMessageId: '2',
|
||||
summary: true,
|
||||
});
|
||||
expect(result).toHaveLength(2);
|
||||
expect(result[0].content).toEqual([{ type: 'text', text: 'Legacy summary only' }]);
|
||||
expect(result[0].tokenCount).toBe(15);
|
||||
});
|
||||
});
|
||||
|
||||
describe('findSummaryContentBlock', () => {
|
||||
it('should find a summary block in the content array', () => {
|
||||
const message = {
|
||||
content: [
|
||||
{ type: 'text', text: 'some text' },
|
||||
{ type: 'summary', text: 'Summary of conversation', tokenCount: 50 },
|
||||
],
|
||||
};
|
||||
const result = TestClient.constructor.findSummaryContentBlock(message);
|
||||
expect(result).toBeTruthy();
|
||||
expect(result.text).toBe('Summary of conversation');
|
||||
expect(result.tokenCount).toBe(50);
|
||||
});
|
||||
|
||||
it('should return null when no summary block exists', () => {
|
||||
const message = {
|
||||
content: [
|
||||
{ type: 'text', text: 'some text' },
|
||||
{ type: 'tool_call', tool_call: {} },
|
||||
],
|
||||
};
|
||||
expect(TestClient.constructor.findSummaryContentBlock(message)).toBeNull();
|
||||
});
|
||||
|
||||
it('should return null for string content', () => {
|
||||
const message = { content: 'just a string' };
|
||||
expect(TestClient.constructor.findSummaryContentBlock(message)).toBeNull();
|
||||
});
|
||||
|
||||
it('should return null for missing content', () => {
|
||||
expect(TestClient.constructor.findSummaryContentBlock({})).toBeNull();
|
||||
expect(TestClient.constructor.findSummaryContentBlock(null)).toBeNull();
|
||||
});
|
||||
|
||||
it('should skip summary blocks with no text', () => {
|
||||
const message = {
|
||||
content: [{ type: 'summary', tokenCount: 10 }],
|
||||
};
|
||||
expect(TestClient.constructor.findSummaryContentBlock(message)).toBeNull();
|
||||
});
|
||||
});
|
||||
|
||||
describe('sendMessage', () => {
|
||||
|
|
|
|||
|
|
@ -51,6 +51,10 @@ class DALLE3 extends Tool {
|
|||
this.fileStrategy = fields.fileStrategy;
|
||||
/** @type {boolean} */
|
||||
this.isAgent = fields.isAgent;
|
||||
if (this.isAgent) {
|
||||
/** Ensures LangChain maps [content, artifact] tuple to ToolMessage fields instead of serializing it into content. */
|
||||
this.responseFormat = 'content_and_artifact';
|
||||
}
|
||||
if (fields.processFileURL) {
|
||||
/** @type {processFileURL} Necessary for output to contain all image metadata. */
|
||||
this.processFileURL = fields.processFileURL.bind(this);
|
||||
|
|
|
|||
|
|
@ -113,6 +113,10 @@ class FluxAPI extends Tool {
|
|||
|
||||
/** @type {boolean} **/
|
||||
this.isAgent = fields.isAgent;
|
||||
if (this.isAgent) {
|
||||
/** Ensures LangChain maps [content, artifact] tuple to ToolMessage fields instead of serializing it into content. */
|
||||
this.responseFormat = 'content_and_artifact';
|
||||
}
|
||||
this.returnMetadata = fields.returnMetadata ?? false;
|
||||
|
||||
if (fields.processFileURL) {
|
||||
|
|
@ -524,10 +528,40 @@ class FluxAPI extends Tool {
|
|||
return this.returnValue('No image data received from Flux API.');
|
||||
}
|
||||
|
||||
// Try saving the image locally
|
||||
const imageUrl = resultData.sample;
|
||||
const imageName = `img-${uuidv4()}.png`;
|
||||
|
||||
if (this.isAgent) {
|
||||
try {
|
||||
const fetchOptions = {};
|
||||
if (process.env.PROXY) {
|
||||
fetchOptions.agent = new HttpsProxyAgent(process.env.PROXY);
|
||||
}
|
||||
const imageResponse = await fetch(imageUrl, fetchOptions);
|
||||
const arrayBuffer = await imageResponse.arrayBuffer();
|
||||
const base64 = Buffer.from(arrayBuffer).toString('base64');
|
||||
const content = [
|
||||
{
|
||||
type: ContentTypes.IMAGE_URL,
|
||||
image_url: {
|
||||
url: `data:image/png;base64,${base64}`,
|
||||
},
|
||||
},
|
||||
];
|
||||
|
||||
const response = [
|
||||
{
|
||||
type: ContentTypes.TEXT,
|
||||
text: displayMessage,
|
||||
},
|
||||
];
|
||||
return [response, { content }];
|
||||
} catch (error) {
|
||||
logger.error('[FluxAPI] Error processing finetuned image for agent:', error);
|
||||
return this.returnValue(`Failed to process the finetuned image. ${error.message}`);
|
||||
}
|
||||
}
|
||||
|
||||
try {
|
||||
logger.debug('[FluxAPI] Saving finetuned image:', imageUrl);
|
||||
const result = await this.processFileURL({
|
||||
|
|
@ -541,12 +575,6 @@ class FluxAPI extends Tool {
|
|||
|
||||
logger.debug('[FluxAPI] Finetuned image saved to path:', result.filepath);
|
||||
|
||||
// Calculate cost based on endpoint
|
||||
const endpointKey = endpoint.includes('ultra')
|
||||
? 'FLUX_PRO_1_1_ULTRA_FINETUNED'
|
||||
: 'FLUX_PRO_FINETUNED';
|
||||
const cost = FluxAPI.PRICING[endpointKey] || 0;
|
||||
// Return the result based on returnMetadata flag
|
||||
this.result = this.returnMetadata ? result : this.wrapInMarkdown(result.filepath);
|
||||
return this.returnValue(this.result);
|
||||
} catch (error) {
|
||||
|
|
|
|||
|
|
@ -13,8 +13,7 @@ const {
|
|||
getTransactionsConfig,
|
||||
} = require('@librechat/api');
|
||||
const { getStrategyFunctions } = require('~/server/services/Files/strategies');
|
||||
const { spendTokens } = require('~/models/spendTokens');
|
||||
const { getFiles } = require('~/models/File');
|
||||
const { spendTokens, getFiles } = require('~/models');
|
||||
|
||||
/**
|
||||
* Configure proxy support for Google APIs
|
||||
|
|
|
|||
|
|
@ -43,6 +43,10 @@ class StableDiffusionAPI extends Tool {
|
|||
this.returnMetadata = fields.returnMetadata ?? false;
|
||||
/** @type {boolean} */
|
||||
this.isAgent = fields.isAgent;
|
||||
if (this.isAgent) {
|
||||
/** Ensures LangChain maps [content, artifact] tuple to ToolMessage fields instead of serializing it into content. */
|
||||
this.responseFormat = 'content_and_artifact';
|
||||
}
|
||||
if (fields.uploadImageBuffer) {
|
||||
/** @type {uploadImageBuffer} Necessary for output to contain all image metadata. */
|
||||
this.uploadImageBuffer = fields.uploadImageBuffer.bind(this);
|
||||
|
|
@ -115,7 +119,7 @@ class StableDiffusionAPI extends Tool {
|
|||
generationResponse = await axios.post(`${url}/sdapi/v1/txt2img`, payload);
|
||||
} catch (error) {
|
||||
logger.error('[StableDiffusion] Error while generating image:', error);
|
||||
return 'Error making API request.';
|
||||
return this.returnValue('Error making API request.');
|
||||
}
|
||||
const image = generationResponse.data.images[0];
|
||||
|
||||
|
|
|
|||
|
|
@ -1,7 +1,6 @@
|
|||
const DALLE3 = require('../DALLE3');
|
||||
const { ProxyAgent } = require('undici');
|
||||
|
||||
jest.mock('tiktoken');
|
||||
const processFileURL = jest.fn();
|
||||
|
||||
describe('DALLE3 Proxy Configuration', () => {
|
||||
|
|
|
|||
|
|
@ -14,15 +14,6 @@ jest.mock('@librechat/data-schemas', () => {
|
|||
};
|
||||
});
|
||||
|
||||
jest.mock('tiktoken', () => {
|
||||
return {
|
||||
encoding_for_model: jest.fn().mockReturnValue({
|
||||
encode: jest.fn(),
|
||||
decode: jest.fn(),
|
||||
}),
|
||||
};
|
||||
});
|
||||
|
||||
const processFileURL = jest.fn();
|
||||
|
||||
const generate = jest.fn();
|
||||
|
|
|
|||
294
api/app/clients/tools/structured/specs/imageTools-agent.spec.js
Normal file
294
api/app/clients/tools/structured/specs/imageTools-agent.spec.js
Normal file
|
|
@ -0,0 +1,294 @@
|
|||
/**
|
||||
* Regression tests for image tool agent mode — verifies that invoke() returns
|
||||
* a ToolMessage with base64 in artifact.content rather than serialized into content.
|
||||
*
|
||||
* Root cause: DALLE3/FluxAPI/StableDiffusion extend LangChain's Tool but did not
|
||||
* set responseFormat = 'content_and_artifact'. LangChain's invoke() would then
|
||||
* JSON.stringify the entire [content, artifact] tuple into ToolMessage.content,
|
||||
* dumping base64 into token counting and causing context exhaustion.
|
||||
*/
|
||||
|
||||
const axios = require('axios');
|
||||
const OpenAI = require('openai');
|
||||
const undici = require('undici');
|
||||
const fetch = require('node-fetch');
|
||||
const { ToolMessage } = require('@langchain/core/messages');
|
||||
const { ContentTypes } = require('librechat-data-provider');
|
||||
const StableDiffusionAPI = require('../StableDiffusion');
|
||||
const FluxAPI = require('../FluxAPI');
|
||||
const DALLE3 = require('../DALLE3');
|
||||
|
||||
jest.mock('axios');
|
||||
jest.mock('openai');
|
||||
jest.mock('node-fetch');
|
||||
jest.mock('undici', () => ({
|
||||
ProxyAgent: jest.fn(),
|
||||
fetch: jest.fn(),
|
||||
}));
|
||||
jest.mock('@librechat/data-schemas', () => ({
|
||||
logger: { info: jest.fn(), warn: jest.fn(), debug: jest.fn(), error: jest.fn() },
|
||||
}));
|
||||
jest.mock('path', () => ({
|
||||
resolve: jest.fn(),
|
||||
join: jest.fn().mockReturnValue('/mock/path'),
|
||||
relative: jest.fn().mockReturnValue('relative/path'),
|
||||
extname: jest.fn().mockReturnValue('.png'),
|
||||
}));
|
||||
jest.mock('fs', () => ({
|
||||
existsSync: jest.fn().mockReturnValue(true),
|
||||
mkdirSync: jest.fn(),
|
||||
promises: { writeFile: jest.fn(), readFile: jest.fn(), unlink: jest.fn() },
|
||||
}));
|
||||
|
||||
const FAKE_BASE64 = 'aGVsbG8=';
|
||||
|
||||
const makeToolCall = (name, args) => ({
|
||||
id: 'call_test_123',
|
||||
name,
|
||||
args,
|
||||
type: 'tool_call',
|
||||
});
|
||||
|
||||
describe('image tools - agent mode ToolMessage format', () => {
|
||||
const ENV_KEYS = ['DALLE_API_KEY', 'FLUX_API_KEY', 'SD_WEBUI_URL', 'PROXY'];
|
||||
let savedEnv = {};
|
||||
|
||||
beforeEach(() => {
|
||||
jest.clearAllMocks();
|
||||
for (const key of ENV_KEYS) {
|
||||
savedEnv[key] = process.env[key];
|
||||
}
|
||||
process.env.DALLE_API_KEY = 'test-dalle-key';
|
||||
process.env.FLUX_API_KEY = 'test-flux-key';
|
||||
process.env.SD_WEBUI_URL = 'http://localhost:7860';
|
||||
delete process.env.PROXY;
|
||||
});
|
||||
|
||||
afterEach(() => {
|
||||
for (const key of ENV_KEYS) {
|
||||
if (savedEnv[key] === undefined) {
|
||||
delete process.env[key];
|
||||
} else {
|
||||
process.env[key] = savedEnv[key];
|
||||
}
|
||||
}
|
||||
savedEnv = {};
|
||||
});
|
||||
|
||||
describe('DALLE3', () => {
|
||||
beforeEach(() => {
|
||||
OpenAI.mockImplementation(() => ({
|
||||
images: {
|
||||
generate: jest.fn().mockResolvedValue({
|
||||
data: [{ url: 'https://example.com/image.png' }],
|
||||
}),
|
||||
},
|
||||
}));
|
||||
undici.fetch.mockResolvedValue({
|
||||
arrayBuffer: () => Promise.resolve(Buffer.from(FAKE_BASE64, 'base64')),
|
||||
});
|
||||
});
|
||||
|
||||
it('sets responseFormat to content_and_artifact when isAgent is true', () => {
|
||||
const dalle = new DALLE3({ isAgent: true });
|
||||
expect(dalle.responseFormat).toBe('content_and_artifact');
|
||||
});
|
||||
|
||||
it('does not set responseFormat when isAgent is false', () => {
|
||||
const dalle = new DALLE3({ isAgent: false, processFileURL: jest.fn() });
|
||||
expect(dalle.responseFormat).not.toBe('content_and_artifact');
|
||||
});
|
||||
|
||||
it('invoke() returns ToolMessage with base64 in artifact, not serialized in content', async () => {
|
||||
const dalle = new DALLE3({ isAgent: true });
|
||||
const result = await dalle.invoke(
|
||||
makeToolCall('dalle', {
|
||||
prompt: 'a box',
|
||||
quality: 'standard',
|
||||
size: '1024x1024',
|
||||
style: 'vivid',
|
||||
}),
|
||||
);
|
||||
|
||||
expect(result).toBeInstanceOf(ToolMessage);
|
||||
|
||||
const contentStr =
|
||||
typeof result.content === 'string' ? result.content : JSON.stringify(result.content);
|
||||
expect(contentStr).not.toContain(FAKE_BASE64);
|
||||
|
||||
expect(result.artifact).toBeDefined();
|
||||
const artifactContent = result.artifact?.content;
|
||||
expect(Array.isArray(artifactContent)).toBe(true);
|
||||
expect(artifactContent[0].type).toBe(ContentTypes.IMAGE_URL);
|
||||
expect(artifactContent[0].image_url.url).toContain('base64');
|
||||
});
|
||||
|
||||
it('invoke() returns ToolMessage with error string in content when API fails', async () => {
|
||||
OpenAI.mockImplementation(() => ({
|
||||
images: { generate: jest.fn().mockRejectedValue(new Error('API error')) },
|
||||
}));
|
||||
|
||||
const dalle = new DALLE3({ isAgent: true });
|
||||
const result = await dalle.invoke(
|
||||
makeToolCall('dalle', {
|
||||
prompt: 'a box',
|
||||
quality: 'standard',
|
||||
size: '1024x1024',
|
||||
style: 'vivid',
|
||||
}),
|
||||
);
|
||||
|
||||
expect(result).toBeInstanceOf(ToolMessage);
|
||||
const contentStr =
|
||||
typeof result.content === 'string' ? result.content : JSON.stringify(result.content);
|
||||
expect(contentStr).toContain('Something went wrong');
|
||||
expect(result.artifact).toBeDefined();
|
||||
});
|
||||
});
|
||||
|
||||
describe('FluxAPI', () => {
|
||||
beforeEach(() => {
|
||||
jest.useFakeTimers();
|
||||
axios.post.mockResolvedValue({ data: { id: 'task-123' } });
|
||||
axios.get.mockResolvedValue({
|
||||
data: { status: 'Ready', result: { sample: 'https://example.com/image.png' } },
|
||||
});
|
||||
fetch.mockResolvedValue({
|
||||
arrayBuffer: () => Promise.resolve(Buffer.from(FAKE_BASE64, 'base64')),
|
||||
});
|
||||
});
|
||||
|
||||
afterEach(() => {
|
||||
jest.useRealTimers();
|
||||
});
|
||||
|
||||
it('sets responseFormat to content_and_artifact when isAgent is true', () => {
|
||||
const flux = new FluxAPI({ isAgent: true });
|
||||
expect(flux.responseFormat).toBe('content_and_artifact');
|
||||
});
|
||||
|
||||
it('does not set responseFormat when isAgent is false', () => {
|
||||
const flux = new FluxAPI({ isAgent: false, processFileURL: jest.fn() });
|
||||
expect(flux.responseFormat).not.toBe('content_and_artifact');
|
||||
});
|
||||
|
||||
it('invoke() returns ToolMessage with base64 in artifact, not serialized in content', async () => {
|
||||
const flux = new FluxAPI({ isAgent: true });
|
||||
const invokePromise = flux.invoke(
|
||||
makeToolCall('flux', { prompt: 'a box', endpoint: '/v1/flux-dev' }),
|
||||
);
|
||||
await jest.runAllTimersAsync();
|
||||
const result = await invokePromise;
|
||||
|
||||
expect(result).toBeInstanceOf(ToolMessage);
|
||||
const contentStr =
|
||||
typeof result.content === 'string' ? result.content : JSON.stringify(result.content);
|
||||
expect(contentStr).not.toContain(FAKE_BASE64);
|
||||
|
||||
expect(result.artifact).toBeDefined();
|
||||
const artifactContent = result.artifact?.content;
|
||||
expect(Array.isArray(artifactContent)).toBe(true);
|
||||
expect(artifactContent[0].type).toBe(ContentTypes.IMAGE_URL);
|
||||
expect(artifactContent[0].image_url.url).toContain('base64');
|
||||
});
|
||||
|
||||
it('invoke() returns ToolMessage with base64 in artifact for generate_finetuned action', async () => {
|
||||
const flux = new FluxAPI({ isAgent: true });
|
||||
const invokePromise = flux.invoke(
|
||||
makeToolCall('flux', {
|
||||
action: 'generate_finetuned',
|
||||
prompt: 'a box',
|
||||
finetune_id: 'ft-abc123',
|
||||
endpoint: '/v1/flux-pro-finetuned',
|
||||
}),
|
||||
);
|
||||
await jest.runAllTimersAsync();
|
||||
const result = await invokePromise;
|
||||
|
||||
expect(result).toBeInstanceOf(ToolMessage);
|
||||
const contentStr =
|
||||
typeof result.content === 'string' ? result.content : JSON.stringify(result.content);
|
||||
expect(contentStr).not.toContain(FAKE_BASE64);
|
||||
|
||||
expect(result.artifact).toBeDefined();
|
||||
const artifactContent = result.artifact?.content;
|
||||
expect(Array.isArray(artifactContent)).toBe(true);
|
||||
expect(artifactContent[0].type).toBe(ContentTypes.IMAGE_URL);
|
||||
expect(artifactContent[0].image_url.url).toContain('base64');
|
||||
});
|
||||
|
||||
it('invoke() returns ToolMessage with error string in content when task submission fails', async () => {
|
||||
axios.post.mockRejectedValue(new Error('Network error'));
|
||||
|
||||
const flux = new FluxAPI({ isAgent: true });
|
||||
const invokePromise = flux.invoke(
|
||||
makeToolCall('flux', { prompt: 'a box', endpoint: '/v1/flux-dev' }),
|
||||
);
|
||||
await jest.runAllTimersAsync();
|
||||
const result = await invokePromise;
|
||||
|
||||
expect(result).toBeInstanceOf(ToolMessage);
|
||||
const contentStr =
|
||||
typeof result.content === 'string' ? result.content : JSON.stringify(result.content);
|
||||
expect(contentStr).toContain('Something went wrong');
|
||||
expect(result.artifact).toBeDefined();
|
||||
});
|
||||
});
|
||||
|
||||
describe('StableDiffusion', () => {
|
||||
beforeEach(() => {
|
||||
axios.post.mockResolvedValue({
|
||||
data: {
|
||||
images: [FAKE_BASE64],
|
||||
info: JSON.stringify({ height: 1024, width: 1024, seed: 42, infotexts: [] }),
|
||||
},
|
||||
});
|
||||
});
|
||||
|
||||
it('sets responseFormat to content_and_artifact when isAgent is true', () => {
|
||||
const sd = new StableDiffusionAPI({ isAgent: true, override: true });
|
||||
expect(sd.responseFormat).toBe('content_and_artifact');
|
||||
});
|
||||
|
||||
it('does not set responseFormat when isAgent is false', () => {
|
||||
const sd = new StableDiffusionAPI({
|
||||
isAgent: false,
|
||||
override: true,
|
||||
uploadImageBuffer: jest.fn(),
|
||||
});
|
||||
expect(sd.responseFormat).not.toBe('content_and_artifact');
|
||||
});
|
||||
|
||||
it('invoke() returns ToolMessage with base64 in artifact, not serialized in content', async () => {
|
||||
const sd = new StableDiffusionAPI({ isAgent: true, override: true, userId: 'user-1' });
|
||||
const result = await sd.invoke(
|
||||
makeToolCall('stable-diffusion', { prompt: 'a box', negative_prompt: '' }),
|
||||
);
|
||||
|
||||
expect(result).toBeInstanceOf(ToolMessage);
|
||||
const contentStr =
|
||||
typeof result.content === 'string' ? result.content : JSON.stringify(result.content);
|
||||
expect(contentStr).not.toContain(FAKE_BASE64);
|
||||
|
||||
expect(result.artifact).toBeDefined();
|
||||
const artifactContent = result.artifact?.content;
|
||||
expect(Array.isArray(artifactContent)).toBe(true);
|
||||
expect(artifactContent[0].type).toBe(ContentTypes.IMAGE_URL);
|
||||
expect(artifactContent[0].image_url.url).toContain('base64');
|
||||
});
|
||||
|
||||
it('invoke() returns ToolMessage with error string in content when API fails', async () => {
|
||||
axios.post.mockRejectedValue(new Error('Connection refused'));
|
||||
|
||||
const sd = new StableDiffusionAPI({ isAgent: true, override: true, userId: 'user-1' });
|
||||
const result = await sd.invoke(
|
||||
makeToolCall('stable-diffusion', { prompt: 'a box', negative_prompt: '' }),
|
||||
);
|
||||
|
||||
expect(result).toBeInstanceOf(ToolMessage);
|
||||
const contentStr =
|
||||
typeof result.content === 'string' ? result.content : JSON.stringify(result.content);
|
||||
expect(contentStr).toContain('Error making API request');
|
||||
});
|
||||
});
|
||||
});
|
||||
|
|
@ -45,7 +45,7 @@ const { getUserPluginAuthValue } = require('~/server/services/PluginService');
|
|||
const { createMCPTool, createMCPTools } = require('~/server/services/MCP');
|
||||
const { loadAuthValues } = require('~/server/services/Tools/credentials');
|
||||
const { getMCPServerTools } = require('~/server/services/Config');
|
||||
const { getRoleByName } = require('~/models/Role');
|
||||
const { getRoleByName } = require('~/models');
|
||||
|
||||
/**
|
||||
* Validates the availability and authentication of tools for a user based on environment variables or user-specific plugin authentication values.
|
||||
|
|
|
|||
3
api/cache/banViolation.js
vendored
3
api/cache/banViolation.js
vendored
|
|
@ -1,8 +1,7 @@
|
|||
const { logger } = require('@librechat/data-schemas');
|
||||
const { isEnabled, math } = require('@librechat/api');
|
||||
const { ViolationTypes } = require('librechat-data-provider');
|
||||
const { isEnabled, math, removePorts } = require('@librechat/api');
|
||||
const { deleteAllUserSessions } = require('~/models');
|
||||
const { removePorts } = require('~/server/utils');
|
||||
const getLogStores = require('./getLogStores');
|
||||
|
||||
const { BAN_VIOLATIONS, BAN_INTERVAL } = process.env ?? {};
|
||||
|
|
|
|||
|
|
@ -236,8 +236,12 @@ async function performSync(flowManager, flowId, flowType) {
|
|||
const messageCount = messageProgress.totalDocuments;
|
||||
const messagesIndexed = messageProgress.totalProcessed;
|
||||
const unindexedMessages = messageCount - messagesIndexed;
|
||||
const noneIndexed = messagesIndexed === 0 && unindexedMessages > 0;
|
||||
|
||||
if (settingsUpdated || unindexedMessages > syncThreshold) {
|
||||
if (settingsUpdated || noneIndexed || unindexedMessages > syncThreshold) {
|
||||
if (noneIndexed && !settingsUpdated) {
|
||||
logger.info('[indexSync] No messages marked as indexed, forcing full sync');
|
||||
}
|
||||
logger.info(`[indexSync] Starting message sync (${unindexedMessages} unindexed)`);
|
||||
await Message.syncWithMeili();
|
||||
messagesSync = true;
|
||||
|
|
@ -261,9 +265,13 @@ async function performSync(flowManager, flowId, flowType) {
|
|||
|
||||
const convoCount = convoProgress.totalDocuments;
|
||||
const convosIndexed = convoProgress.totalProcessed;
|
||||
|
||||
const unindexedConvos = convoCount - convosIndexed;
|
||||
if (settingsUpdated || unindexedConvos > syncThreshold) {
|
||||
const noneConvosIndexed = convosIndexed === 0 && unindexedConvos > 0;
|
||||
|
||||
if (settingsUpdated || noneConvosIndexed || unindexedConvos > syncThreshold) {
|
||||
if (noneConvosIndexed && !settingsUpdated) {
|
||||
logger.info('[indexSync] No conversations marked as indexed, forcing full sync');
|
||||
}
|
||||
logger.info(`[indexSync] Starting convos sync (${unindexedConvos} unindexed)`);
|
||||
await Conversation.syncWithMeili();
|
||||
convosSync = true;
|
||||
|
|
|
|||
|
|
@ -462,4 +462,69 @@ describe('performSync() - syncThreshold logic', () => {
|
|||
);
|
||||
expect(mockLogger.info).toHaveBeenCalledWith('[indexSync] Starting convos sync (50 unindexed)');
|
||||
});
|
||||
|
||||
test('forces sync when zero documents indexed (reset scenario) even if below threshold', async () => {
|
||||
Message.getSyncProgress.mockResolvedValue({
|
||||
totalProcessed: 0,
|
||||
totalDocuments: 680,
|
||||
isComplete: false,
|
||||
});
|
||||
|
||||
Conversation.getSyncProgress.mockResolvedValue({
|
||||
totalProcessed: 0,
|
||||
totalDocuments: 76,
|
||||
isComplete: false,
|
||||
});
|
||||
|
||||
Message.syncWithMeili.mockResolvedValue(undefined);
|
||||
Conversation.syncWithMeili.mockResolvedValue(undefined);
|
||||
|
||||
const indexSync = require('./indexSync');
|
||||
await indexSync();
|
||||
|
||||
expect(Message.syncWithMeili).toHaveBeenCalledTimes(1);
|
||||
expect(Conversation.syncWithMeili).toHaveBeenCalledTimes(1);
|
||||
expect(mockLogger.info).toHaveBeenCalledWith(
|
||||
'[indexSync] No messages marked as indexed, forcing full sync',
|
||||
);
|
||||
expect(mockLogger.info).toHaveBeenCalledWith(
|
||||
'[indexSync] Starting message sync (680 unindexed)',
|
||||
);
|
||||
expect(mockLogger.info).toHaveBeenCalledWith(
|
||||
'[indexSync] No conversations marked as indexed, forcing full sync',
|
||||
);
|
||||
expect(mockLogger.info).toHaveBeenCalledWith('[indexSync] Starting convos sync (76 unindexed)');
|
||||
});
|
||||
|
||||
test('does NOT force sync when some documents already indexed and below threshold', async () => {
|
||||
Message.getSyncProgress.mockResolvedValue({
|
||||
totalProcessed: 630,
|
||||
totalDocuments: 680,
|
||||
isComplete: false,
|
||||
});
|
||||
|
||||
Conversation.getSyncProgress.mockResolvedValue({
|
||||
totalProcessed: 70,
|
||||
totalDocuments: 76,
|
||||
isComplete: false,
|
||||
});
|
||||
|
||||
const indexSync = require('./indexSync');
|
||||
await indexSync();
|
||||
|
||||
expect(Message.syncWithMeili).not.toHaveBeenCalled();
|
||||
expect(Conversation.syncWithMeili).not.toHaveBeenCalled();
|
||||
expect(mockLogger.info).not.toHaveBeenCalledWith(
|
||||
'[indexSync] No messages marked as indexed, forcing full sync',
|
||||
);
|
||||
expect(mockLogger.info).not.toHaveBeenCalledWith(
|
||||
'[indexSync] No conversations marked as indexed, forcing full sync',
|
||||
);
|
||||
expect(mockLogger.info).toHaveBeenCalledWith(
|
||||
'[indexSync] 50 messages unindexed (below threshold: 1000, skipping)',
|
||||
);
|
||||
expect(mockLogger.info).toHaveBeenCalledWith(
|
||||
'[indexSync] 6 convos unindexed (below threshold: 1000, skipping)',
|
||||
);
|
||||
});
|
||||
});
|
||||
|
|
|
|||
|
|
@ -9,7 +9,7 @@ module.exports = {
|
|||
moduleNameMapper: {
|
||||
'~/(.*)': '<rootDir>/$1',
|
||||
'~/data/auth.json': '<rootDir>/__mocks__/auth.mock.json',
|
||||
'^openid-client/passport$': '<rootDir>/test/__mocks__/openid-client-passport.js', // Mock for the passport strategy part
|
||||
'^openid-client/passport$': '<rootDir>/test/__mocks__/openid-client-passport.js',
|
||||
'^openid-client$': '<rootDir>/test/__mocks__/openid-client.js',
|
||||
},
|
||||
transformIgnorePatterns: ['/node_modules/(?!(openid-client|oauth4webapi|jose)/).*/'],
|
||||
|
|
|
|||
|
|
@ -1,77 +0,0 @@
|
|||
const { Action } = require('~/db/models');
|
||||
|
||||
/**
|
||||
* Update an action with new data without overwriting existing properties,
|
||||
* or create a new action if it doesn't exist.
|
||||
*
|
||||
* @param {Object} searchParams - The search parameters to find the action to update.
|
||||
* @param {string} searchParams.action_id - The ID of the action to update.
|
||||
* @param {string} searchParams.user - The user ID of the action's author.
|
||||
* @param {Object} updateData - An object containing the properties to update.
|
||||
* @returns {Promise<Action>} The updated or newly created action document as a plain object.
|
||||
*/
|
||||
const updateAction = async (searchParams, updateData) => {
|
||||
const options = { new: true, upsert: true };
|
||||
return await Action.findOneAndUpdate(searchParams, updateData, options).lean();
|
||||
};
|
||||
|
||||
/**
|
||||
* Retrieves all actions that match the given search parameters.
|
||||
*
|
||||
* @param {Object} searchParams - The search parameters to find matching actions.
|
||||
* @param {boolean} includeSensitive - Flag to include sensitive data in the metadata.
|
||||
* @returns {Promise<Array<Action>>} A promise that resolves to an array of action documents as plain objects.
|
||||
*/
|
||||
const getActions = async (searchParams, includeSensitive = false) => {
|
||||
const actions = await Action.find(searchParams).lean();
|
||||
|
||||
if (!includeSensitive) {
|
||||
for (let i = 0; i < actions.length; i++) {
|
||||
const metadata = actions[i].metadata;
|
||||
if (!metadata) {
|
||||
continue;
|
||||
}
|
||||
|
||||
const sensitiveFields = ['api_key', 'oauth_client_id', 'oauth_client_secret'];
|
||||
for (let field of sensitiveFields) {
|
||||
if (metadata[field]) {
|
||||
delete metadata[field];
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return actions;
|
||||
};
|
||||
|
||||
/**
|
||||
* Deletes an action by params.
|
||||
*
|
||||
* @param {Object} searchParams - The search parameters to find the action to delete.
|
||||
* @param {string} searchParams.action_id - The ID of the action to delete.
|
||||
* @param {string} searchParams.user - The user ID of the action's author.
|
||||
* @returns {Promise<Action>} A promise that resolves to the deleted action document as a plain object, or null if no document was found.
|
||||
*/
|
||||
const deleteAction = async (searchParams) => {
|
||||
return await Action.findOneAndDelete(searchParams).lean();
|
||||
};
|
||||
|
||||
/**
|
||||
* Deletes actions by params.
|
||||
*
|
||||
* @param {Object} searchParams - The search parameters to find the actions to delete.
|
||||
* @param {string} searchParams.action_id - The ID of the action(s) to delete.
|
||||
* @param {string} searchParams.user - The user ID of the action's author.
|
||||
* @returns {Promise<Number>} A promise that resolves to the number of deleted action documents.
|
||||
*/
|
||||
const deleteActions = async (searchParams) => {
|
||||
const result = await Action.deleteMany(searchParams);
|
||||
return result.deletedCount;
|
||||
};
|
||||
|
||||
module.exports = {
|
||||
getActions,
|
||||
updateAction,
|
||||
deleteAction,
|
||||
deleteActions,
|
||||
};
|
||||
|
|
@ -1,931 +0,0 @@
|
|||
const mongoose = require('mongoose');
|
||||
const crypto = require('node:crypto');
|
||||
const { logger } = require('@librechat/data-schemas');
|
||||
const { getCustomEndpointConfig } = require('@librechat/api');
|
||||
const {
|
||||
Tools,
|
||||
SystemRoles,
|
||||
ResourceType,
|
||||
actionDelimiter,
|
||||
isAgentsEndpoint,
|
||||
isEphemeralAgentId,
|
||||
encodeEphemeralAgentId,
|
||||
} = require('librechat-data-provider');
|
||||
const { mcp_all, mcp_delimiter } = require('librechat-data-provider').Constants;
|
||||
const {
|
||||
removeAgentFromAllProjects,
|
||||
removeAgentIdsFromProject,
|
||||
addAgentIdsToProject,
|
||||
} = require('./Project');
|
||||
const { removeAllPermissions } = require('~/server/services/PermissionService');
|
||||
const { getMCPServerTools } = require('~/server/services/Config');
|
||||
const { Agent, AclEntry, User } = require('~/db/models');
|
||||
const { getActions } = require('./Action');
|
||||
|
||||
/**
|
||||
* Extracts unique MCP server names from tools array
|
||||
* Tools format: "toolName_mcp_serverName" or "sys__server__sys_mcp_serverName"
|
||||
* @param {string[]} tools - Array of tool identifiers
|
||||
* @returns {string[]} Array of unique MCP server names
|
||||
*/
|
||||
const extractMCPServerNames = (tools) => {
|
||||
if (!tools || !Array.isArray(tools)) {
|
||||
return [];
|
||||
}
|
||||
const serverNames = new Set();
|
||||
for (const tool of tools) {
|
||||
if (!tool || !tool.includes(mcp_delimiter)) {
|
||||
continue;
|
||||
}
|
||||
const parts = tool.split(mcp_delimiter);
|
||||
if (parts.length >= 2) {
|
||||
serverNames.add(parts[parts.length - 1]);
|
||||
}
|
||||
}
|
||||
return Array.from(serverNames);
|
||||
};
|
||||
|
||||
/**
|
||||
* Create an agent with the provided data.
|
||||
* @param {Object} agentData - The agent data to create.
|
||||
* @returns {Promise<Agent>} The created agent document as a plain object.
|
||||
* @throws {Error} If the agent creation fails.
|
||||
*/
|
||||
const createAgent = async (agentData) => {
|
||||
const { author: _author, ...versionData } = agentData;
|
||||
const timestamp = new Date();
|
||||
const initialAgentData = {
|
||||
...agentData,
|
||||
versions: [
|
||||
{
|
||||
...versionData,
|
||||
createdAt: timestamp,
|
||||
updatedAt: timestamp,
|
||||
},
|
||||
],
|
||||
category: agentData.category || 'general',
|
||||
mcpServerNames: extractMCPServerNames(agentData.tools),
|
||||
};
|
||||
|
||||
return (await Agent.create(initialAgentData)).toObject();
|
||||
};
|
||||
|
||||
/**
|
||||
* Get an agent document based on the provided ID.
|
||||
*
|
||||
* @param {Object} searchParameter - The search parameters to find the agent to update.
|
||||
* @param {string} searchParameter.id - The ID of the agent to update.
|
||||
* @param {string} searchParameter.author - The user ID of the agent's author.
|
||||
* @returns {Promise<Agent|null>} The agent document as a plain object, or null if not found.
|
||||
*/
|
||||
const getAgent = async (searchParameter) => await Agent.findOne(searchParameter).lean();
|
||||
|
||||
/**
|
||||
* Get multiple agent documents based on the provided search parameters.
|
||||
*
|
||||
* @param {Object} searchParameter - The search parameters to find agents.
|
||||
* @returns {Promise<Agent[]>} Array of agent documents as plain objects.
|
||||
*/
|
||||
const getAgents = async (searchParameter) => await Agent.find(searchParameter).lean();
|
||||
|
||||
/**
|
||||
* Load an agent based on the provided ID
|
||||
*
|
||||
* @param {Object} params
|
||||
* @param {ServerRequest} params.req
|
||||
* @param {string} params.spec
|
||||
* @param {string} params.agent_id
|
||||
* @param {string} params.endpoint
|
||||
* @param {import('@librechat/agents').ClientOptions} [params.model_parameters]
|
||||
* @returns {Promise<Agent|null>} The agent document as a plain object, or null if not found.
|
||||
*/
|
||||
const loadEphemeralAgent = async ({ req, spec, endpoint, model_parameters: _m }) => {
|
||||
const { model, ...model_parameters } = _m;
|
||||
const modelSpecs = req.config?.modelSpecs?.list;
|
||||
/** @type {TModelSpec | null} */
|
||||
let modelSpec = null;
|
||||
if (spec != null && spec !== '') {
|
||||
modelSpec = modelSpecs?.find((s) => s.name === spec) || null;
|
||||
}
|
||||
/** @type {TEphemeralAgent | null} */
|
||||
const ephemeralAgent = req.body.ephemeralAgent;
|
||||
const mcpServers = new Set(ephemeralAgent?.mcp);
|
||||
const userId = req.user?.id; // note: userId cannot be undefined at runtime
|
||||
if (modelSpec?.mcpServers) {
|
||||
for (const mcpServer of modelSpec.mcpServers) {
|
||||
mcpServers.add(mcpServer);
|
||||
}
|
||||
}
|
||||
/** @type {string[]} */
|
||||
const tools = [];
|
||||
if (ephemeralAgent?.execute_code === true || modelSpec?.executeCode === true) {
|
||||
tools.push(Tools.execute_code);
|
||||
}
|
||||
if (ephemeralAgent?.file_search === true || modelSpec?.fileSearch === true) {
|
||||
tools.push(Tools.file_search);
|
||||
}
|
||||
if (ephemeralAgent?.web_search === true || modelSpec?.webSearch === true) {
|
||||
tools.push(Tools.web_search);
|
||||
}
|
||||
|
||||
const addedServers = new Set();
|
||||
if (mcpServers.size > 0) {
|
||||
for (const mcpServer of mcpServers) {
|
||||
if (addedServers.has(mcpServer)) {
|
||||
continue;
|
||||
}
|
||||
const serverTools = await getMCPServerTools(userId, mcpServer);
|
||||
if (!serverTools) {
|
||||
tools.push(`${mcp_all}${mcp_delimiter}${mcpServer}`);
|
||||
addedServers.add(mcpServer);
|
||||
continue;
|
||||
}
|
||||
tools.push(...Object.keys(serverTools));
|
||||
addedServers.add(mcpServer);
|
||||
}
|
||||
}
|
||||
|
||||
const instructions = req.body.promptPrefix;
|
||||
|
||||
// Get endpoint config for modelDisplayLabel fallback
|
||||
const appConfig = req.config;
|
||||
let endpointConfig = appConfig?.endpoints?.[endpoint];
|
||||
if (!isAgentsEndpoint(endpoint) && !endpointConfig) {
|
||||
try {
|
||||
endpointConfig = getCustomEndpointConfig({ endpoint, appConfig });
|
||||
} catch (err) {
|
||||
logger.error('[loadEphemeralAgent] Error getting custom endpoint config', err);
|
||||
}
|
||||
}
|
||||
|
||||
// For ephemeral agents, use modelLabel if provided, then model spec's label,
|
||||
// then modelDisplayLabel from endpoint config, otherwise empty string to show model name
|
||||
const sender =
|
||||
model_parameters?.modelLabel ?? modelSpec?.label ?? endpointConfig?.modelDisplayLabel ?? '';
|
||||
|
||||
// Encode ephemeral agent ID with endpoint, model, and computed sender for display
|
||||
const ephemeralId = encodeEphemeralAgentId({ endpoint, model, sender });
|
||||
|
||||
const result = {
|
||||
id: ephemeralId,
|
||||
instructions,
|
||||
provider: endpoint,
|
||||
model_parameters,
|
||||
model,
|
||||
tools,
|
||||
};
|
||||
|
||||
if (ephemeralAgent?.artifacts != null && ephemeralAgent.artifacts) {
|
||||
result.artifacts = ephemeralAgent.artifacts;
|
||||
}
|
||||
return result;
|
||||
};
|
||||
|
||||
/**
|
||||
* Load an agent based on the provided ID
|
||||
*
|
||||
* @param {Object} params
|
||||
* @param {ServerRequest} params.req
|
||||
* @param {string} params.spec
|
||||
* @param {string} params.agent_id
|
||||
* @param {string} params.endpoint
|
||||
* @param {import('@librechat/agents').ClientOptions} [params.model_parameters]
|
||||
* @returns {Promise<Agent|null>} The agent document as a plain object, or null if not found.
|
||||
*/
|
||||
const loadAgent = async ({ req, spec, agent_id, endpoint, model_parameters }) => {
|
||||
if (!agent_id) {
|
||||
return null;
|
||||
}
|
||||
if (isEphemeralAgentId(agent_id)) {
|
||||
return await loadEphemeralAgent({ req, spec, endpoint, model_parameters });
|
||||
}
|
||||
const agent = await getAgent({
|
||||
id: agent_id,
|
||||
});
|
||||
|
||||
if (!agent) {
|
||||
return null;
|
||||
}
|
||||
|
||||
agent.version = agent.versions ? agent.versions.length : 0;
|
||||
return agent;
|
||||
};
|
||||
|
||||
/**
|
||||
* Check if a version already exists in the versions array, excluding timestamp and author fields
|
||||
* @param {Object} updateData - The update data to compare
|
||||
* @param {Object} currentData - The current agent data
|
||||
* @param {Array} versions - The existing versions array
|
||||
* @param {string} [actionsHash] - Hash of current action metadata
|
||||
* @returns {Object|null} - The matching version if found, null otherwise
|
||||
*/
|
||||
const isDuplicateVersion = (updateData, currentData, versions, actionsHash = null) => {
|
||||
if (!versions || versions.length === 0) {
|
||||
return null;
|
||||
}
|
||||
|
||||
const excludeFields = [
|
||||
'_id',
|
||||
'id',
|
||||
'createdAt',
|
||||
'updatedAt',
|
||||
'author',
|
||||
'updatedBy',
|
||||
'created_at',
|
||||
'updated_at',
|
||||
'__v',
|
||||
'versions',
|
||||
'actionsHash', // Exclude actionsHash from direct comparison
|
||||
];
|
||||
|
||||
const { $push: _$push, $pull: _$pull, $addToSet: _$addToSet, ...directUpdates } = updateData;
|
||||
|
||||
if (Object.keys(directUpdates).length === 0 && !actionsHash) {
|
||||
return null;
|
||||
}
|
||||
|
||||
const wouldBeVersion = { ...currentData, ...directUpdates };
|
||||
const lastVersion = versions[versions.length - 1];
|
||||
|
||||
if (actionsHash && lastVersion.actionsHash !== actionsHash) {
|
||||
return null;
|
||||
}
|
||||
|
||||
const allFields = new Set([...Object.keys(wouldBeVersion), ...Object.keys(lastVersion)]);
|
||||
|
||||
const importantFields = Array.from(allFields).filter((field) => !excludeFields.includes(field));
|
||||
|
||||
let isMatch = true;
|
||||
for (const field of importantFields) {
|
||||
const wouldBeValue = wouldBeVersion[field];
|
||||
const lastVersionValue = lastVersion[field];
|
||||
|
||||
// Skip if both are undefined/null
|
||||
if (!wouldBeValue && !lastVersionValue) {
|
||||
continue;
|
||||
}
|
||||
|
||||
// Handle arrays
|
||||
if (Array.isArray(wouldBeValue) || Array.isArray(lastVersionValue)) {
|
||||
// Normalize: treat undefined/null as empty array for comparison
|
||||
let wouldBeArr;
|
||||
if (Array.isArray(wouldBeValue)) {
|
||||
wouldBeArr = wouldBeValue;
|
||||
} else if (wouldBeValue == null) {
|
||||
wouldBeArr = [];
|
||||
} else {
|
||||
wouldBeArr = [wouldBeValue];
|
||||
}
|
||||
|
||||
let lastVersionArr;
|
||||
if (Array.isArray(lastVersionValue)) {
|
||||
lastVersionArr = lastVersionValue;
|
||||
} else if (lastVersionValue == null) {
|
||||
lastVersionArr = [];
|
||||
} else {
|
||||
lastVersionArr = [lastVersionValue];
|
||||
}
|
||||
|
||||
if (wouldBeArr.length !== lastVersionArr.length) {
|
||||
isMatch = false;
|
||||
break;
|
||||
}
|
||||
|
||||
// Special handling for projectIds (MongoDB ObjectIds)
|
||||
if (field === 'projectIds') {
|
||||
const wouldBeIds = wouldBeArr.map((id) => id.toString()).sort();
|
||||
const versionIds = lastVersionArr.map((id) => id.toString()).sort();
|
||||
|
||||
if (!wouldBeIds.every((id, i) => id === versionIds[i])) {
|
||||
isMatch = false;
|
||||
break;
|
||||
}
|
||||
}
|
||||
// Handle arrays of objects
|
||||
else if (
|
||||
wouldBeArr.length > 0 &&
|
||||
typeof wouldBeArr[0] === 'object' &&
|
||||
wouldBeArr[0] !== null
|
||||
) {
|
||||
const sortedWouldBe = [...wouldBeArr].map((item) => JSON.stringify(item)).sort();
|
||||
const sortedVersion = [...lastVersionArr].map((item) => JSON.stringify(item)).sort();
|
||||
|
||||
if (!sortedWouldBe.every((item, i) => item === sortedVersion[i])) {
|
||||
isMatch = false;
|
||||
break;
|
||||
}
|
||||
} else {
|
||||
const sortedWouldBe = [...wouldBeArr].sort();
|
||||
const sortedVersion = [...lastVersionArr].sort();
|
||||
|
||||
if (!sortedWouldBe.every((item, i) => item === sortedVersion[i])) {
|
||||
isMatch = false;
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
// Handle objects
|
||||
else if (typeof wouldBeValue === 'object' && wouldBeValue !== null) {
|
||||
const lastVersionObj =
|
||||
typeof lastVersionValue === 'object' && lastVersionValue !== null ? lastVersionValue : {};
|
||||
|
||||
// For empty objects, normalize the comparison
|
||||
const wouldBeKeys = Object.keys(wouldBeValue);
|
||||
const lastVersionKeys = Object.keys(lastVersionObj);
|
||||
|
||||
// If both are empty objects, they're equal
|
||||
if (wouldBeKeys.length === 0 && lastVersionKeys.length === 0) {
|
||||
continue;
|
||||
}
|
||||
|
||||
// Otherwise do a deep comparison
|
||||
if (JSON.stringify(wouldBeValue) !== JSON.stringify(lastVersionObj)) {
|
||||
isMatch = false;
|
||||
break;
|
||||
}
|
||||
}
|
||||
// Handle primitive values
|
||||
else {
|
||||
// For primitives, handle the case where one is undefined and the other is a default value
|
||||
if (wouldBeValue !== lastVersionValue) {
|
||||
// Special handling for boolean false vs undefined
|
||||
if (
|
||||
typeof wouldBeValue === 'boolean' &&
|
||||
wouldBeValue === false &&
|
||||
lastVersionValue === undefined
|
||||
) {
|
||||
continue;
|
||||
}
|
||||
// Special handling for empty string vs undefined
|
||||
if (
|
||||
typeof wouldBeValue === 'string' &&
|
||||
wouldBeValue === '' &&
|
||||
lastVersionValue === undefined
|
||||
) {
|
||||
continue;
|
||||
}
|
||||
isMatch = false;
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return isMatch ? lastVersion : null;
|
||||
};
|
||||
|
||||
/**
|
||||
* Update an agent with new data without overwriting existing
|
||||
* properties, or create a new agent if it doesn't exist.
|
||||
* When an agent is updated, a copy of the current state will be saved to the versions array.
|
||||
*
|
||||
* @param {Object} searchParameter - The search parameters to find the agent to update.
|
||||
* @param {string} searchParameter.id - The ID of the agent to update.
|
||||
* @param {string} [searchParameter.author] - The user ID of the agent's author.
|
||||
* @param {Object} updateData - An object containing the properties to update.
|
||||
* @param {Object} [options] - Optional configuration object.
|
||||
* @param {string} [options.updatingUserId] - The ID of the user performing the update (used for tracking non-author updates).
|
||||
* @param {boolean} [options.forceVersion] - Force creation of a new version even if no fields changed.
|
||||
* @param {boolean} [options.skipVersioning] - Skip version creation entirely (useful for isolated operations like sharing).
|
||||
* @returns {Promise<Agent>} The updated or newly created agent document as a plain object.
|
||||
* @throws {Error} If the update would create a duplicate version
|
||||
*/
|
||||
const updateAgent = async (searchParameter, updateData, options = {}) => {
|
||||
const { updatingUserId = null, forceVersion = false, skipVersioning = false } = options;
|
||||
const mongoOptions = { new: true, upsert: false };
|
||||
|
||||
const currentAgent = await Agent.findOne(searchParameter);
|
||||
if (currentAgent) {
|
||||
const {
|
||||
__v,
|
||||
_id,
|
||||
id: __id,
|
||||
versions,
|
||||
author: _author,
|
||||
...versionData
|
||||
} = currentAgent.toObject();
|
||||
const { $push, $pull, $addToSet, ...directUpdates } = updateData;
|
||||
|
||||
// Sync mcpServerNames when tools are updated
|
||||
if (directUpdates.tools !== undefined) {
|
||||
const mcpServerNames = extractMCPServerNames(directUpdates.tools);
|
||||
directUpdates.mcpServerNames = mcpServerNames;
|
||||
updateData.mcpServerNames = mcpServerNames; // Also update the original updateData
|
||||
}
|
||||
|
||||
let actionsHash = null;
|
||||
|
||||
// Generate actions hash if agent has actions
|
||||
if (currentAgent.actions && currentAgent.actions.length > 0) {
|
||||
// Extract action IDs from the format "domain_action_id"
|
||||
const actionIds = currentAgent.actions
|
||||
.map((action) => {
|
||||
const parts = action.split(actionDelimiter);
|
||||
return parts[1]; // Get just the action ID part
|
||||
})
|
||||
.filter(Boolean);
|
||||
|
||||
if (actionIds.length > 0) {
|
||||
try {
|
||||
const actions = await getActions(
|
||||
{
|
||||
action_id: { $in: actionIds },
|
||||
},
|
||||
true,
|
||||
); // Include sensitive data for hash
|
||||
|
||||
actionsHash = await generateActionMetadataHash(currentAgent.actions, actions);
|
||||
} catch (error) {
|
||||
logger.error('Error fetching actions for hash generation:', error);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
const shouldCreateVersion =
|
||||
!skipVersioning &&
|
||||
(forceVersion || Object.keys(directUpdates).length > 0 || $push || $pull || $addToSet);
|
||||
|
||||
if (shouldCreateVersion) {
|
||||
const duplicateVersion = isDuplicateVersion(updateData, versionData, versions, actionsHash);
|
||||
if (duplicateVersion && !forceVersion) {
|
||||
// No changes detected, return the current agent without creating a new version
|
||||
const agentObj = currentAgent.toObject();
|
||||
agentObj.version = versions.length;
|
||||
return agentObj;
|
||||
}
|
||||
}
|
||||
|
||||
const versionEntry = {
|
||||
...versionData,
|
||||
...directUpdates,
|
||||
updatedAt: new Date(),
|
||||
};
|
||||
|
||||
// Include actions hash in version if available
|
||||
if (actionsHash) {
|
||||
versionEntry.actionsHash = actionsHash;
|
||||
}
|
||||
|
||||
// Always store updatedBy field to track who made the change
|
||||
if (updatingUserId) {
|
||||
versionEntry.updatedBy = new mongoose.Types.ObjectId(updatingUserId);
|
||||
}
|
||||
|
||||
if (shouldCreateVersion) {
|
||||
updateData.$push = {
|
||||
...($push || {}),
|
||||
versions: versionEntry,
|
||||
};
|
||||
}
|
||||
}
|
||||
|
||||
return Agent.findOneAndUpdate(searchParameter, updateData, mongoOptions).lean();
|
||||
};
|
||||
|
||||
/**
|
||||
* Modifies an agent with the resource file id.
|
||||
* @param {object} params
|
||||
* @param {ServerRequest} params.req
|
||||
* @param {string} params.agent_id
|
||||
* @param {string} params.tool_resource
|
||||
* @param {string} params.file_id
|
||||
* @returns {Promise<Agent>} The updated agent.
|
||||
*/
|
||||
const addAgentResourceFile = async ({ req, agent_id, tool_resource, file_id }) => {
|
||||
const searchParameter = { id: agent_id };
|
||||
let agent = await getAgent(searchParameter);
|
||||
if (!agent) {
|
||||
throw new Error('Agent not found for adding resource file');
|
||||
}
|
||||
const fileIdsPath = `tool_resources.${tool_resource}.file_ids`;
|
||||
await Agent.updateOne(
|
||||
{
|
||||
id: agent_id,
|
||||
[`${fileIdsPath}`]: { $exists: false },
|
||||
},
|
||||
{
|
||||
$set: {
|
||||
[`${fileIdsPath}`]: [],
|
||||
},
|
||||
},
|
||||
);
|
||||
|
||||
const updateData = {
|
||||
$addToSet: {
|
||||
tools: tool_resource,
|
||||
[fileIdsPath]: file_id,
|
||||
},
|
||||
};
|
||||
|
||||
const updatedAgent = await updateAgent(searchParameter, updateData, {
|
||||
updatingUserId: req?.user?.id,
|
||||
});
|
||||
if (updatedAgent) {
|
||||
return updatedAgent;
|
||||
} else {
|
||||
throw new Error('Agent not found for adding resource file');
|
||||
}
|
||||
};
|
||||
|
||||
/**
|
||||
* Removes multiple resource files from an agent using atomic operations.
|
||||
* @param {object} params
|
||||
* @param {string} params.agent_id
|
||||
* @param {Array<{tool_resource: string, file_id: string}>} params.files
|
||||
* @returns {Promise<Agent>} The updated agent.
|
||||
* @throws {Error} If the agent is not found or update fails.
|
||||
*/
|
||||
const removeAgentResourceFiles = async ({ agent_id, files }) => {
|
||||
const searchParameter = { id: agent_id };
|
||||
|
||||
// Group files to remove by resource
|
||||
const filesByResource = files.reduce((acc, { tool_resource, file_id }) => {
|
||||
if (!acc[tool_resource]) {
|
||||
acc[tool_resource] = [];
|
||||
}
|
||||
acc[tool_resource].push(file_id);
|
||||
return acc;
|
||||
}, {});
|
||||
|
||||
// Step 1: Atomically remove file IDs using $pull
|
||||
const pullOps = {};
|
||||
const resourcesToCheck = new Set();
|
||||
for (const [resource, fileIds] of Object.entries(filesByResource)) {
|
||||
const fileIdsPath = `tool_resources.${resource}.file_ids`;
|
||||
pullOps[fileIdsPath] = { $in: fileIds };
|
||||
resourcesToCheck.add(resource);
|
||||
}
|
||||
|
||||
const updatePullData = { $pull: pullOps };
|
||||
const agentAfterPull = await Agent.findOneAndUpdate(searchParameter, updatePullData, {
|
||||
new: true,
|
||||
}).lean();
|
||||
|
||||
if (!agentAfterPull) {
|
||||
// Agent might have been deleted concurrently, or never existed.
|
||||
// Check if it existed before trying to throw.
|
||||
const agentExists = await getAgent(searchParameter);
|
||||
if (!agentExists) {
|
||||
throw new Error('Agent not found for removing resource files');
|
||||
}
|
||||
// If it existed but findOneAndUpdate returned null, something else went wrong.
|
||||
throw new Error('Failed to update agent during file removal (pull step)');
|
||||
}
|
||||
|
||||
// Return the agent state directly after the $pull operation.
|
||||
// Skipping the $unset step for now to simplify and test core $pull atomicity.
|
||||
// Empty arrays might remain, but the removal itself should be correct.
|
||||
return agentAfterPull;
|
||||
};
|
||||
|
||||
/**
|
||||
* Deletes an agent based on the provided ID.
|
||||
*
|
||||
* @param {Object} searchParameter - The search parameters to find the agent to delete.
|
||||
* @param {string} searchParameter.id - The ID of the agent to delete.
|
||||
* @param {string} [searchParameter.author] - The user ID of the agent's author.
|
||||
* @returns {Promise<void>} Resolves when the agent has been successfully deleted.
|
||||
*/
|
||||
const deleteAgent = async (searchParameter) => {
|
||||
const agent = await Agent.findOneAndDelete(searchParameter);
|
||||
if (agent) {
|
||||
await removeAgentFromAllProjects(agent.id);
|
||||
await Promise.all([
|
||||
removeAllPermissions({
|
||||
resourceType: ResourceType.AGENT,
|
||||
resourceId: agent._id,
|
||||
}),
|
||||
removeAllPermissions({
|
||||
resourceType: ResourceType.REMOTE_AGENT,
|
||||
resourceId: agent._id,
|
||||
}),
|
||||
]);
|
||||
try {
|
||||
await Agent.updateMany({ 'edges.to': agent.id }, { $pull: { edges: { to: agent.id } } });
|
||||
} catch (error) {
|
||||
logger.error('[deleteAgent] Error removing agent from handoff edges', error);
|
||||
}
|
||||
try {
|
||||
await User.updateMany(
|
||||
{ 'favorites.agentId': agent.id },
|
||||
{ $pull: { favorites: { agentId: agent.id } } },
|
||||
);
|
||||
} catch (error) {
|
||||
logger.error('[deleteAgent] Error removing agent from user favorites', error);
|
||||
}
|
||||
}
|
||||
return agent;
|
||||
};
|
||||
|
||||
/**
|
||||
* Deletes all agents created by a specific user.
|
||||
* @param {string} userId - The ID of the user whose agents should be deleted.
|
||||
* @returns {Promise<void>} A promise that resolves when all user agents have been deleted.
|
||||
*/
|
||||
const deleteUserAgents = async (userId) => {
|
||||
try {
|
||||
const userAgents = await getAgents({ author: userId });
|
||||
|
||||
if (userAgents.length === 0) {
|
||||
return;
|
||||
}
|
||||
|
||||
const agentIds = userAgents.map((agent) => agent.id);
|
||||
const agentObjectIds = userAgents.map((agent) => agent._id);
|
||||
|
||||
for (const agentId of agentIds) {
|
||||
await removeAgentFromAllProjects(agentId);
|
||||
}
|
||||
|
||||
await AclEntry.deleteMany({
|
||||
resourceType: { $in: [ResourceType.AGENT, ResourceType.REMOTE_AGENT] },
|
||||
resourceId: { $in: agentObjectIds },
|
||||
});
|
||||
|
||||
try {
|
||||
await User.updateMany(
|
||||
{ 'favorites.agentId': { $in: agentIds } },
|
||||
{ $pull: { favorites: { agentId: { $in: agentIds } } } },
|
||||
);
|
||||
} catch (error) {
|
||||
logger.error('[deleteUserAgents] Error removing agents from user favorites', error);
|
||||
}
|
||||
|
||||
await Agent.deleteMany({ author: userId });
|
||||
} catch (error) {
|
||||
logger.error('[deleteUserAgents] General error:', error);
|
||||
}
|
||||
};
|
||||
|
||||
/**
 * Get agents by accessible IDs with optional cursor-based pagination.
 * Results are ordered by (updatedAt desc, _id asc); the pagination cursor is a
 * base64-encoded JSON object `{ updatedAt, _id }` identifying the last item of
 * the previous page.
 * @param {Object} params - The parameters for getting accessible agents.
 * @param {Array} [params.accessibleIds] - Array of agent ObjectIds the user has ACL access to.
 * @param {Object} [params.otherParams] - Additional query parameters (including author filter).
 * @param {number} [params.limit] - Number of agents to return (clamped to 1-100). If not provided, returns all agents.
 * @param {string} [params.after] - Cursor for pagination - get agents after this cursor.
 * @returns {Promise<Object>} Resolves to `{ object, data, first_id, last_id, has_more, after }`.
 */
const getListAgentsByAccess = async ({
  accessibleIds = [],
  otherParams = {},
  limit = null,
  after = null,
}) => {
  const isPaginated = limit !== null && limit !== undefined;
  // Clamp the page size to [1, 100]; non-numeric input falls back to 20.
  // Radix 10 is passed explicitly (the original bare parseInt relied on the default).
  const normalizedLimit = isPaginated
    ? Math.min(Math.max(1, Number.parseInt(limit, 10) || 20), 100)
    : null;

  // Build base query combining ACL accessible agents with other filters
  const baseQuery = { ...otherParams, _id: { $in: accessibleIds } };

  // Add cursor condition
  if (after) {
    try {
      const cursor = JSON.parse(Buffer.from(after, 'base64').toString('utf8'));
      const { updatedAt, _id } = cursor;

      // Matches documents strictly after the cursor in (updatedAt desc, _id asc) order.
      const cursorCondition = {
        $or: [
          { updatedAt: { $lt: new Date(updatedAt) } },
          { updatedAt: new Date(updatedAt), _id: { $gt: new mongoose.Types.ObjectId(_id) } },
        ],
      };

      // Merge cursor condition with base query
      if (Object.keys(baseQuery).length > 0) {
        baseQuery.$and = [{ ...baseQuery }, cursorCondition];
        // Remove the original conditions from baseQuery to avoid duplication
        Object.keys(baseQuery).forEach((key) => {
          if (key !== '$and') delete baseQuery[key];
        });
      } else {
        Object.assign(baseQuery, cursorCondition);
      }
    } catch (error) {
      // An unparsable cursor is ignored; the listing restarts from the beginning.
      logger.warn('Invalid cursor:', error.message);
    }
  }

  let query = Agent.find(baseQuery, {
    id: 1,
    _id: 1,
    name: 1,
    avatar: 1,
    author: 1,
    projectIds: 1,
    description: 1,
    updatedAt: 1,
    category: 1,
    support_contact: 1,
    is_promoted: 1,
  }).sort({ updatedAt: -1, _id: 1 });

  // Fetch one extra document beyond the page size to detect whether another page exists.
  if (isPaginated) {
    query = query.limit(normalizedLimit + 1);
  }

  const agents = await query.lean();

  const hasMore = isPaginated ? agents.length > normalizedLimit : false;
  const data = (isPaginated ? agents.slice(0, normalizedLimit) : agents).map((agent) => {
    if (agent.author) {
      agent.author = agent.author.toString();
    }
    return agent;
  });

  // Generate next cursor only if paginated; it encodes the last RETURNED agent.
  let nextCursor = null;
  if (isPaginated && hasMore && data.length > 0) {
    const lastAgent = agents[normalizedLimit - 1];
    nextCursor = Buffer.from(
      JSON.stringify({
        updatedAt: lastAgent.updatedAt.toISOString(),
        _id: lastAgent._id.toString(),
      }),
    ).toString('base64');
  }

  return {
    object: 'list',
    data,
    first_id: data.length > 0 ? data[0].id : null,
    last_id: data.length > 0 ? data[data.length - 1].id : null,
    has_more: hasMore,
    after: nextCursor,
  };
};
|
||||
|
||||
/**
 * Updates the projects associated with an agent, adding and removing project IDs as specified.
 * This function also updates the corresponding projects to include or exclude the agent ID.
 * If the agent update itself fails (agent not found / caller not authorized), the
 * project-side mutations are rolled back before returning the current agent.
 *
 * @param {Object} params - Parameters for updating the agent's projects.
 * @param {IUser} params.user - The requesting user; must be the author unless an admin.
 * @param {string} params.agentId - The ID of the agent to update.
 * @param {string[]} [params.projectIds] - Array of project IDs to add to the agent.
 * @param {string[]} [params.removeProjectIds] - Array of project IDs to remove from the agent.
 * @returns {Promise<MongoAgent>} The updated agent document.
 * @throws {Error} If there's an error updating the agent or projects.
 */
const updateAgentProjects = async ({ user, agentId, projectIds, removeProjectIds }) => {
  const updateOps = {};

  if (removeProjectIds && removeProjectIds.length > 0) {
    for (const projectId of removeProjectIds) {
      await removeAgentIdsFromProject(projectId, [agentId]);
    }
    updateOps.$pull = { projectIds: { $in: removeProjectIds } };
  }

  if (projectIds && projectIds.length > 0) {
    for (const projectId of projectIds) {
      await addAgentIdsToProject(projectId, [agentId]);
    }
    updateOps.$addToSet = { projectIds: { $each: projectIds } };
  }

  // Nothing to change; return the agent as-is.
  if (Object.keys(updateOps).length === 0) {
    return await getAgent({ id: agentId });
  }

  // Non-admins may only modify their own agents.
  const updateQuery = { id: agentId, author: user.id };
  if (user.role === SystemRoles.ADMIN) {
    delete updateQuery.author;
  }

  const updatedAgent = await updateAgent(updateQuery, updateOps, {
    updatingUserId: user.id,
    skipVersioning: true,
  });
  if (updatedAgent) {
    return updatedAgent;
  }

  // Agent update failed: undo BOTH project-side mutations. (Previously this was
  // an else-if, so when additions and removals were both requested only the
  // additions were rolled back.)
  if (updateOps.$addToSet) {
    for (const projectId of projectIds) {
      await removeAgentIdsFromProject(projectId, [agentId]);
    }
  }
  if (updateOps.$pull) {
    for (const projectId of removeProjectIds) {
      await addAgentIdsToProject(projectId, [agentId]);
    }
  }

  return await getAgent({ id: agentId });
};
|
||||
|
||||
/**
 * Reverts an agent to a specific version in its version history.
 * @param {Object} searchParameter - The search parameters to find the agent to revert.
 * @param {string} searchParameter.id - The ID of the agent to revert.
 * @param {string} [searchParameter.author] - The user ID of the agent's author.
 * @param {number} versionIndex - The index of the version to revert to in the versions array.
 * @returns {Promise<MongoAgent>} The updated agent document after reverting.
 * @throws {Error} If the agent is not found or the specified version does not exist.
 */
const revertAgentVersion = async (searchParameter, versionIndex) => {
  const agent = await Agent.findOne(searchParameter);
  if (!agent) {
    throw new Error('Agent not found');
  }

  const targetVersion = agent.versions?.[versionIndex];
  if (!targetVersion) {
    throw new Error(`Version ${versionIndex} not found`);
  }

  // Copy the stored snapshot, then strip fields that must never be overwritten
  // by a revert (identity, history, and authorship metadata).
  const updateData = { ...targetVersion };
  for (const field of ['_id', 'id', 'versions', 'author', 'updatedBy']) {
    delete updateData[field];
  }

  return Agent.findOneAndUpdate(searchParameter, updateData, { new: true }).lean();
};
|
||||
|
||||
/**
 * Generates a hash of action metadata for version comparison.
 * The hash is a SHA-256 hex digest over a deterministic serialization of each
 * action's metadata (action IDs and metadata keys are sorted first, so the
 * same set of actions always hashes identically).
 * @param {string[]} actionIds - Array of action IDs in format "domain_action_id"
 * @param {Action[]} actions - Array of action documents
 * @returns {Promise<string>} - SHA-256 hex digest of the action metadata, or '' when there are no action IDs
 */
const generateActionMetadataHash = async (actionIds, actions) => {
  if (!actionIds || actionIds.length === 0) {
    return '';
  }

  // Create a map of action_id to metadata for quick lookup
  const actionMap = new Map();
  actions.forEach((action) => {
    actionMap.set(action.action_id, action.metadata);
  });

  // Sort action IDs for consistent hashing
  const sortedActionIds = [...actionIds].sort();

  // Build a deterministic string representation of all action metadata
  const metadataString = sortedActionIds
    .map((actionFullId) => {
      // Extract just the action_id part (after the delimiter)
      const parts = actionFullId.split(actionDelimiter);
      const actionId = parts[1];

      const metadata = actionMap.get(actionId);
      if (!metadata) {
        return `${actionId}:null`;
      }

      // Sort metadata keys for deterministic output
      const sortedKeys = Object.keys(metadata).sort();
      const metadataStr = sortedKeys
        .map((key) => `${key}:${JSON.stringify(metadata[key])}`)
        .join(',');
      return `${actionId}:{${metadataStr}}`;
    })
    .join(';');

  // Node's synchronous hash API yields the exact same SHA-256 hex digest as the
  // previous `crypto.webcrypto.subtle.digest` implementation, without the
  // TextEncoder / ArrayBuffer / byte-array round-trip.
  return crypto.createHash('sha256').update(metadataString, 'utf8').digest('hex');
};
|
||||
/**
 * Counts the number of promoted agents.
 * @returns {Promise<number>} - The count of promoted agents
 */
const countPromotedAgents = async () => await Agent.countDocuments({ is_promoted: true });
|
||||
|
||||
/**
 * Load a default agent based on the endpoint
 * @param {string} endpoint
 * @returns {Agent | null}
 */
// NOTE(review): the JSDoc above has no accompanying function in this file —
// presumably its implementation was removed or lives elsewhere; confirm and
// delete the comment if it is truly orphaned.

// Public API of the Agent model module.
module.exports = {
  getAgent,
  getAgents,
  loadAgent,
  createAgent,
  updateAgent,
  deleteAgent,
  deleteUserAgents,
  revertAgentVersion,
  updateAgentProjects,
  countPromotedAgents,
  addAgentResourceFile,
  getListAgentsByAccess,
  removeAgentResourceFiles,
  generateActionMetadataHash,
};
|
||||
|
|
@ -1,62 +0,0 @@
|
|||
const { Assistant } = require('~/db/models');
|
||||
|
||||
/**
 * Update an assistant with new data without overwriting existing properties,
 * or create a new assistant if it doesn't exist.
 *
 * @param {Object} searchParams - The search parameters to find the assistant to update.
 * @param {string} searchParams.assistant_id - The ID of the assistant to update.
 * @param {string} searchParams.user - The user ID of the assistant's author.
 * @param {Object} updateData - An object containing the properties to update.
 * @returns {Promise<AssistantDocument>} The updated or newly created assistant document as a plain object.
 */
const updateAssistantDoc = async (searchParams, updateData) =>
  await Assistant.findOneAndUpdate(searchParams, updateData, {
    new: true,
    upsert: true,
  }).lean();
|
||||
|
||||
/**
 * Retrieves an assistant document based on the provided search parameters.
 *
 * @param {Object} searchParams - The search parameters to find the assistant.
 * @param {string} searchParams.assistant_id - The ID of the assistant.
 * @param {string} searchParams.user - The user ID of the assistant's author.
 * @returns {Promise<AssistantDocument|null>} The assistant document as a plain object, or null if not found.
 */
const getAssistant = async (searchParams) => {
  return await Assistant.findOne(searchParams).lean();
};
|
||||
|
||||
/**
 * Retrieves all assistants that match the given search parameters.
 *
 * @param {Object} searchParams - The search parameters to find matching assistants.
 * @param {Object} [select] - Optional. Specifies which document fields to include or exclude.
 * @returns {Promise<Array<AssistantDocument>>} A promise that resolves to an array of assistant documents as plain objects.
 */
const getAssistants = async (searchParams, select = null) => {
  const baseQuery = Assistant.find(searchParams);
  const projected = select ? baseQuery.select(select) : baseQuery;
  return await projected.lean();
};
|
||||
|
||||
/**
 * Deletes an assistant matching the provided search parameters.
 *
 * @param {Object} searchParams - The search parameters to find the assistant to delete.
 * @param {string} searchParams.assistant_id - The ID of the assistant to delete.
 * @param {string} searchParams.user - The user ID of the assistant's author.
 * @returns {Promise<void>} Resolves when the assistant has been successfully deleted.
 */
const deleteAssistant = async (searchParams) => await Assistant.findOneAndDelete(searchParams);
|
||||
|
||||
// Public API of the Assistant model module.
module.exports = {
  updateAssistantDoc,
  deleteAssistant,
  getAssistants,
  getAssistant,
};
|
||||
|
|
@ -1,28 +0,0 @@
|
|||
const { logger } = require('@librechat/data-schemas');
|
||||
const { Banner } = require('~/db/models');
|
||||
|
||||
/**
 * Retrieves the current active banner.
 * A banner is active when `displayFrom` is in the past and `displayTo` is
 * either unset or still in the future. Non-public banners are only returned
 * when a user is present.
 * @param {Object} [user] - The requesting user, if authenticated.
 * @returns {Promise<Object|null>} The active banner object or null if no active banner is visible to the caller.
 * @throws {Error} If the database lookup fails.
 */
const getBanner = async (user) => {
  try {
    const now = new Date();
    const banner = await Banner.findOne({
      displayFrom: { $lte: now },
      $or: [{ displayTo: { $gte: now } }, { displayTo: null }],
      type: 'banner',
    }).lean();

    // No banner found, a public banner, or an authenticated caller: return as-is.
    if (!banner || banner.isPublic || user) {
      return banner;
    }

    // Private banner and anonymous caller: hide it.
    return null;
  } catch (error) {
    // Log tag corrected: previously said '[getBanners]' though the function is getBanner.
    logger.error('[getBanner] Error getting banners', error);
    throw new Error('Error getting banners');
  }
};
|
||||
|
||||
// Public API of the Banner model module.
module.exports = { getBanner };
|
||||
|
|
@ -1,57 +0,0 @@
|
|||
const { logger } = require('@librechat/data-schemas');
|
||||
|
||||
// Static prompt-group category options. `label` values are i18n translation
// keys (com_ui_*); `value` is the stable identifier persisted with a prompt.
const options = [
  {
    label: 'com_ui_idea',
    value: 'idea',
  },
  {
    label: 'com_ui_travel',
    value: 'travel',
  },
  {
    label: 'com_ui_teach_or_explain',
    value: 'teach_or_explain',
  },
  {
    label: 'com_ui_write',
    value: 'write',
  },
  {
    label: 'com_ui_shop',
    value: 'shop',
  },
  {
    label: 'com_ui_code',
    value: 'code',
  },
  {
    label: 'com_ui_misc',
    value: 'misc',
  },
  {
    label: 'com_ui_roleplay',
    value: 'roleplay',
  },
  {
    label: 'com_ui_finance',
    value: 'finance',
  },
];
|
||||
|
||||
module.exports = {
|
||||
/**
|
||||
* Retrieves the categories asynchronously.
|
||||
* @returns {Promise<TGetCategoriesResponse>} An array of category objects.
|
||||
* @throws {Error} If there is an error retrieving the categories.
|
||||
*/
|
||||
getCategories: async () => {
|
||||
try {
|
||||
// const categories = await Categories.find();
|
||||
return options;
|
||||
} catch (error) {
|
||||
logger.error('Error getting categories', error);
|
||||
return [];
|
||||
}
|
||||
},
|
||||
};
|
||||
|
|
@ -1,372 +0,0 @@
|
|||
const { logger } = require('@librechat/data-schemas');
|
||||
const { createTempChatExpirationDate } = require('@librechat/api');
|
||||
const { getMessages, deleteMessages } = require('./Message');
|
||||
const { Conversation } = require('~/db/models');
|
||||
|
||||
/**
 * Searches for a conversation by conversationId and returns a lean document with only conversationId and user.
 * @param {string} conversationId - The conversation's ID.
 * @returns {Promise<{conversationId: string, user: string} | null>} The conversation object with selected fields or null if not found.
 * @throws {Error} If the lookup fails.
 */
const searchConversation = async (conversationId) => {
  try {
    const projection = 'conversationId user';
    return await Conversation.findOne({ conversationId }, projection).lean();
  } catch (error) {
    logger.error('[searchConversation] Error searching conversation', error);
    throw new Error('Error searching conversation');
  }
};
|
||||
|
||||
/**
 * Retrieves a single conversation for a given user and conversation ID.
 * @param {string} user - The user's ID.
 * @param {string} conversationId - The conversation's ID.
 * @returns {Promise<TConversation>} The conversation object.
 * @throws {Error} If the lookup fails.
 */
const getConvo = async (user, conversationId) => {
  try {
    const filter = { user, conversationId };
    return await Conversation.findOne(filter).lean();
  } catch (error) {
    logger.error('[getConvo] Error getting single conversation', error);
    throw new Error('Error getting single conversation');
  }
};
|
||||
|
||||
/**
 * Removes conversations whose conversationId is null, empty, or missing,
 * along with any messages attached to those conversations.
 * @returns {Promise<{conversations: Object, messages: Object}>} Mongo delete results for both collections.
 * @throws {Error} If either delete operation fails.
 */
const deleteNullOrEmptyConversations = async () => {
  try {
    const invalidIdFilter = {
      $or: [
        { conversationId: null },
        { conversationId: '' },
        { conversationId: { $exists: false } },
      ],
    };

    const convoResult = await Conversation.deleteMany(invalidIdFilter);

    // Delete associated messages
    const messageResult = await deleteMessages(invalidIdFilter);

    logger.info(
      `[deleteNullOrEmptyConversations] Deleted ${convoResult.deletedCount} conversations and ${messageResult.deletedCount} messages`,
    );

    return {
      conversations: convoResult,
      messages: messageResult,
    };
  } catch (error) {
    logger.error('[deleteNullOrEmptyConversations] Error deleting conversations', error);
    throw new Error('Error deleting conversations with null or empty conversationId');
  }
};
|
||||
|
||||
/**
 * Searches for a conversation by conversationId and returns associated file ids.
 * @param {string} conversationId - The conversation's ID.
 * @returns {Promise<string[] | null>} The conversation's file ids ([] when the conversation or field is absent).
 * @throws {Error} If the lookup fails.
 */
const getConvoFiles = async (conversationId) => {
  try {
    const convo = await Conversation.findOne({ conversationId }, 'files').lean();
    return convo?.files ?? [];
  } catch (error) {
    logger.error('[getConvoFiles] Error getting conversation files', error);
    throw new Error('Error getting conversation files');
  }
};
|
||||
|
||||
// Public API of the Conversation model module. The inline methods below are
// tightly coupled to Mongo query semantics and cursor encodings, so they are
// documented in place rather than restructured.
module.exports = {
  getConvoFiles,
  searchConversation,
  deleteNullOrEmptyConversations,
  /**
   * Saves a conversation to the database (upsert by default).
   * @param {Object} req - The request object.
   * @param {string} conversationId - The conversation's ID.
   * @param {Object} metadata - Additional metadata to log for operation.
   * @returns {Promise<TConversation>} The conversation object.
   */
  saveConvo: async (req, { conversationId, newConversationId, ...convo }, metadata) => {
    try {
      if (metadata?.context) {
        logger.debug(`[saveConvo] ${metadata.context}`);
      }

      // Re-sync the conversation's message id list from the Message collection.
      const messages = await getMessages({ conversationId }, '_id');
      const update = { ...convo, messages, user: req.user.id };

      if (newConversationId) {
        update.conversationId = newConversationId;
      }

      // Temporary chats get an expiration date; any failure computing it is
      // deliberately non-fatal and falls back to a non-expiring conversation.
      if (req?.body?.isTemporary) {
        try {
          const appConfig = req.config;
          update.expiredAt = createTempChatExpirationDate(appConfig?.interfaceConfig);
        } catch (err) {
          logger.error('Error creating temporary chat expiration date:', err);
          logger.info(`---\`saveConvo\` context: ${metadata?.context}`);
          update.expiredAt = null;
        }
      } else {
        update.expiredAt = null;
      }

      /** @type {{ $set: Partial<TConversation>; $unset?: Record<keyof TConversation, number> }} */
      const updateOperation = { $set: update };
      if (metadata && metadata.unsetFields && Object.keys(metadata.unsetFields).length > 0) {
        updateOperation.$unset = metadata.unsetFields;
      }

      /** Note: the resulting Model object is necessary for Meilisearch operations */
      const conversation = await Conversation.findOneAndUpdate(
        { conversationId, user: req.user.id },
        updateOperation,
        {
          new: true,
          upsert: metadata?.noUpsert !== true,
        },
      );

      if (!conversation) {
        logger.debug('[saveConvo] Conversation not found, skipping update');
        return null;
      }

      return conversation.toObject();
    } catch (error) {
      logger.error('[saveConvo] Error saving conversation', error);
      if (metadata && metadata?.context) {
        logger.info(`[saveConvo] ${metadata.context}`);
      }
      // NOTE(review): errors are swallowed here and a sentinel object is
      // returned instead of throwing — callers must check for `.message`.
      return { message: 'Error saving conversation' };
    }
  },
  // Bulk upserts conversations keyed by (conversationId, user); timestamps are
  // not touched so imported data keeps its original dates.
  bulkSaveConvos: async (conversations) => {
    try {
      const bulkOps = conversations.map((convo) => ({
        updateOne: {
          filter: { conversationId: convo.conversationId, user: convo.user },
          update: convo,
          upsert: true,
          timestamps: false,
        },
      }));

      const result = await Conversation.bulkWrite(bulkOps);
      return result;
    } catch (error) {
      logger.error('[bulkSaveConvos] Error saving conversations in bulk', error);
      throw new Error('Failed to save conversations in bulk.');
    }
  },
  /**
   * Cursor-paginated conversation listing, with optional archive/tag filters
   * and Meilisearch full-text filtering. The cursor is base64-encoded JSON
   * `{ primary, secondary }` where `primary` is the sort-field value and
   * `secondary` is the updatedAt tiebreaker of the last returned item.
   */
  getConvosByCursor: async (
    user,
    {
      cursor,
      limit = 25,
      isArchived = false,
      tags,
      search,
      sortBy = 'updatedAt',
      sortDirection = 'desc',
    } = {},
  ) => {
    const filters = [{ user }];
    if (isArchived) {
      filters.push({ isArchived: true });
    } else {
      filters.push({ $or: [{ isArchived: false }, { isArchived: { $exists: false } }] });
    }

    if (Array.isArray(tags) && tags.length > 0) {
      filters.push({ tags: { $in: tags } });
    }

    // Expired (temporary) conversations are always excluded.
    filters.push({ $or: [{ expiredAt: null }, { expiredAt: { $exists: false } }] });

    if (search) {
      try {
        const meiliResults = await Conversation.meiliSearch(search, { filter: `user = "${user}"` });
        const matchingIds = Array.isArray(meiliResults.hits)
          ? meiliResults.hits.map((result) => result.conversationId)
          : [];
        if (!matchingIds.length) {
          return { conversations: [], nextCursor: null };
        }
        filters.push({ conversationId: { $in: matchingIds } });
      } catch (error) {
        logger.error('[getConvosByCursor] Error during meiliSearch', error);
        throw new Error('Error during meiliSearch');
      }
    }

    const validSortFields = ['title', 'createdAt', 'updatedAt'];
    if (!validSortFields.includes(sortBy)) {
      throw new Error(
        `Invalid sortBy field: ${sortBy}. Must be one of ${validSortFields.join(', ')}`,
      );
    }
    const finalSortBy = sortBy;
    const finalSortDirection = sortDirection === 'asc' ? 'asc' : 'desc';

    let cursorFilter = null;
    if (cursor) {
      try {
        const decoded = JSON.parse(Buffer.from(cursor, 'base64').toString());
        const { primary, secondary } = decoded;
        // `title` sorts as a string; date fields are compared as Dates.
        const primaryValue = finalSortBy === 'title' ? primary : new Date(primary);
        const secondaryValue = new Date(secondary);
        const op = finalSortDirection === 'asc' ? '$gt' : '$lt';

        // Strictly-after-the-cursor condition with updatedAt as tiebreaker.
        cursorFilter = {
          $or: [
            { [finalSortBy]: { [op]: primaryValue } },
            {
              [finalSortBy]: primaryValue,
              updatedAt: { [op]: secondaryValue },
            },
          ],
        };
      } catch (err) {
        // A malformed cursor is non-fatal; listing restarts from the beginning.
        logger.warn('[getConvosByCursor] Invalid cursor format, starting from beginning');
      }
      if (cursorFilter) {
        filters.push(cursorFilter);
      }
    }

    const query = filters.length === 1 ? filters[0] : { $and: filters };

    try {
      const sortOrder = finalSortDirection === 'asc' ? 1 : -1;
      const sortObj = { [finalSortBy]: sortOrder };

      if (finalSortBy !== 'updatedAt') {
        sortObj.updatedAt = sortOrder;
      }

      const convos = await Conversation.find(query)
        .select(
          'conversationId endpoint title createdAt updatedAt user model agent_id assistant_id spec iconURL',
        )
        .sort(sortObj)
        .limit(limit + 1)
        .lean();

      let nextCursor = null;
      if (convos.length > limit) {
        convos.pop(); // Remove extra item used to detect next page
        // Create cursor from the last RETURNED item (not the popped one)
        const lastReturned = convos[convos.length - 1];
        const primaryValue = lastReturned[finalSortBy];
        const primaryStr = finalSortBy === 'title' ? primaryValue : primaryValue.toISOString();
        const secondaryStr = lastReturned.updatedAt.toISOString();
        const composite = { primary: primaryStr, secondary: secondaryStr };
        nextCursor = Buffer.from(JSON.stringify(composite)).toString('base64');
      }

      return { conversations: convos, nextCursor };
    } catch (error) {
      logger.error('[getConvosByCursor] Error getting conversations', error);
      throw new Error('Error getting conversations');
    }
  },
  /**
   * Fetches the given conversations for a user, sorted by updatedAt desc, with
   * simple ISO-timestamp cursor pagination applied in memory. Returns the page
   * plus a conversationId -> conversation map for quick lookup.
   */
  getConvosQueried: async (user, convoIds, cursor = null, limit = 25) => {
    try {
      if (!convoIds?.length) {
        return { conversations: [], nextCursor: null, convoMap: {} };
      }

      const conversationIds = convoIds.map((convo) => convo.conversationId);

      const results = await Conversation.find({
        user,
        conversationId: { $in: conversationIds },
        $or: [{ expiredAt: { $exists: false } }, { expiredAt: null }],
      }).lean();

      results.sort((a, b) => new Date(b.updatedAt) - new Date(a.updatedAt));

      let filtered = results;
      if (cursor && cursor !== 'start') {
        const cursorDate = new Date(cursor);
        filtered = results.filter((convo) => new Date(convo.updatedAt) < cursorDate);
      }

      const limited = filtered.slice(0, limit + 1);
      let nextCursor = null;
      if (limited.length > limit) {
        limited.pop(); // Remove extra item used to detect next page
        // Create cursor from the last RETURNED item (not the popped one)
        nextCursor = limited[limited.length - 1].updatedAt.toISOString();
      }

      const convoMap = {};
      limited.forEach((convo) => {
        convoMap[convo.conversationId] = convo;
      });

      return { conversations: limited, nextCursor, convoMap };
    } catch (error) {
      logger.error('[getConvosQueried] Error getting conversations', error);
      throw new Error('Error fetching conversations');
    }
  },
  getConvo,
  /* chore: this method is not properly error handled */
  // Returns null when the conversation exists but has no title yet, otherwise
  // the title (or 'New Chat' when the conversation itself is missing).
  getConvoTitle: async (user, conversationId) => {
    try {
      const convo = await getConvo(user, conversationId);
      /* ChatGPT Browser was triggering error here due to convo being saved later */
      if (convo && !convo.title) {
        return null;
      } else {
        // TypeError: Cannot read properties of null (reading 'title')
        return convo?.title || 'New Chat';
      }
    } catch (error) {
      logger.error('[getConvoTitle] Error getting conversation title', error);
      throw new Error('Error getting conversation title');
    }
  },
  /**
   * Asynchronously deletes conversations and associated messages for a given user and filter.
   *
   * @async
   * @function
   * @param {string|ObjectId} user - The user's ID.
   * @param {Object} filter - Additional filter criteria for the conversations to be deleted.
   * @returns {Promise<{ n: number, ok: number, deletedCount: number, messages: { n: number, ok: number, deletedCount: number } }>}
   *          An object containing the count of deleted conversations and associated messages.
   * @throws {Error} Throws an error if there's an issue with the database operations,
   *          or when no matching conversation exists.
   */
  deleteConvos: async (user, filter) => {
    try {
      const userFilter = { ...filter, user };
      // Resolve the matching conversation ids first, so the associated
      // messages can be deleted after the conversations themselves.
      const conversations = await Conversation.find(userFilter).select('conversationId');
      const conversationIds = conversations.map((c) => c.conversationId);

      if (!conversationIds.length) {
        throw new Error('Conversation not found or already deleted.');
      }

      const deleteConvoResult = await Conversation.deleteMany(userFilter);

      const deleteMessagesResult = await deleteMessages({
        conversationId: { $in: conversationIds },
      });

      return { ...deleteConvoResult, messages: deleteMessagesResult };
    } catch (error) {
      logger.error('[deleteConvos] Error deleting conversations and messages', error);
      throw error;
    }
  },
};
|
||||
|
|
@ -1,284 +0,0 @@
|
|||
const { logger } = require('@librechat/data-schemas');
|
||||
const { ConversationTag, Conversation } = require('~/db/models');
|
||||
|
||||
/**
 * Retrieves all conversation tags for a user, ordered by their position.
 * @param {string} user - The user ID.
 * @returns {Promise<Array>} An array of conversation tags.
 * @throws {Error} If the lookup fails.
 */
const getConversationTags = async (user) => {
  try {
    const tags = await ConversationTag.find({ user }).sort({ position: 1 }).lean();
    return tags;
  } catch (error) {
    logger.error('[getConversationTags] Error getting conversation tags', error);
    throw new Error('Error getting conversation tags');
  }
};
|
||||
|
||||
/**
 * Creates a new conversation tag, or returns the user's existing tag of the
 * same name. The new tag is placed after the user's current highest position.
 * @param {string} user - The user ID.
 * @param {Object} data - The tag data.
 * @param {string} data.tag - The tag name.
 * @param {string} [data.description] - The tag description.
 * @param {boolean} [data.addToConversation] - Whether to add the tag to a conversation.
 * @param {string} [data.conversationId] - The conversation ID to add the tag to.
 * @returns {Promise<Object>} The created (or pre-existing) tag.
 * @throws {Error} If creation fails.
 */
const createConversationTag = async (user, data) => {
  try {
    const { tag, description, addToConversation, conversationId } = data;

    const existingTag = await ConversationTag.findOne({ user, tag }).lean();
    if (existingTag) {
      return existingTag;
    }

    // Append after the user's current highest-positioned tag.
    const highest = await ConversationTag.findOne({ user }).sort('-position').lean();
    const position = (highest?.position || 0) + 1;

    // Upsert keyed on (tag, user); createdAt is only stamped on insert.
    const newTag = await ConversationTag.findOneAndUpdate(
      { tag, user },
      {
        tag,
        user,
        count: addToConversation ? 1 : 0,
        position,
        description,
        $setOnInsert: { createdAt: new Date() },
      },
      { new: true, upsert: true, lean: true },
    );

    if (addToConversation && conversationId) {
      await Conversation.findOneAndUpdate(
        { user, conversationId },
        { $addToSet: { tags: tag } },
        { new: true },
      );
    }

    return newTag;
  } catch (error) {
    logger.error('[createConversationTag] Error creating conversation tag', error);
    throw new Error('Error creating conversation tag');
  }
};
|
||||
|
||||
/**
 * Updates an existing conversation tag.
 * Renaming a tag also rewrites the tag on every conversation that carries it,
 * and repositioning shifts the positions of the tags in between.
 * @param {string} user - The user ID.
 * @param {string} oldTag - The current tag name.
 * @param {Object} data - The updated tag data.
 * @param {string} [data.tag] - The new tag name.
 * @param {string} [data.description] - The updated description.
 * @param {number} [data.position] - The new position.
 * @returns {Promise<Object|null>} The updated tag, or null when `oldTag` does not exist.
 * @throws {Error} If the new tag name is already taken, or the update fails.
 */
const updateConversationTag = async (user, oldTag, data) => {
  try {
    const { tag: newTag, description, position } = data;

    const existingTag = await ConversationTag.findOne({ user, tag: oldTag }).lean();
    if (!existingTag) {
      return null;
    }

    if (newTag && newTag !== oldTag) {
      const tagAlreadyExists = await ConversationTag.findOne({ user, tag: newTag }).lean();
      if (tagAlreadyExists) {
        // NOTE(review): this specific error is re-wrapped by the catch below
        // into the generic 'Error updating conversation tag' message.
        throw new Error('Tag already exists');
      }

      // `tags.$` rewrites the matched array element on each conversation.
      await Conversation.updateMany({ user, tags: oldTag }, { $set: { 'tags.$': newTag } });
    }

    const updateData = {};
    if (newTag) {
      updateData.tag = newTag;
    }
    if (description !== undefined) {
      updateData.description = description;
    }
    if (position !== undefined) {
      // Shift the tags between the old and new positions before moving this one.
      await adjustPositions(user, existingTag.position, position);
      updateData.position = position;
    }

    return await ConversationTag.findOneAndUpdate({ user, tag: oldTag }, updateData, {
      new: true,
      lean: true,
    });
  } catch (error) {
    logger.error('[updateConversationTag] Error updating conversation tag', error);
    throw new Error('Error updating conversation tag');
  }
};
|
||||
|
||||
/**
 * Shifts the positions of a user's tags to open up (or close) a slot when
 * one tag moves from `oldPosition` to `newPosition`.
 * @param {string} user - The user ID.
 * @param {number} oldPosition - The position the tag is moving from.
 * @param {number} newPosition - The position the tag is moving to.
 * @returns {Promise<void>}
 */
const adjustPositions = async (user, oldPosition, newPosition) => {
  if (oldPosition === newPosition) {
    return;
  }

  const movingDown = oldPosition < newPosition;
  const lower = Math.min(oldPosition, newPosition);
  const upper = Math.max(oldPosition, newPosition);

  // Moving down: tags in (old, new] slide up by one.
  // Moving up: tags in [new, old) slide down by one.
  const range = movingDown ? { $gt: lower, $lte: upper } : { $gte: lower, $lt: upper };
  const shift = { $inc: { position: movingDown ? -1 : 1 } };

  await ConversationTag.updateMany(
    {
      user,
      position: range,
    },
    shift,
  );
};
|
||||
|
||||
/**
 * Deletes a conversation tag, detaches it from conversations, and compacts
 * the remaining position ordering.
 * @param {string} user - The user ID.
 * @param {string} tag - The tag to delete.
 * @returns {Promise<Object|null>} The deleted tag document, or null if it did not exist.
 */
const deleteConversationTag = async (user, tag) => {
  try {
    const removed = await ConversationTag.findOneAndDelete({ user, tag }).lean();
    if (!removed) {
      return null;
    }

    // Detach the tag from every conversation that still references it
    await Conversation.updateMany({ user, tags: tag }, { $pull: { tags: tag } });

    // Close the gap left behind in the position ordering
    await ConversationTag.updateMany(
      { user, position: { $gt: removed.position } },
      { $inc: { position: -1 } },
    );

    return removed;
  } catch (error) {
    logger.error('[deleteConversationTag] Error deleting conversation tag', error);
    throw new Error('Error deleting conversation tag');
  }
};
|
||||
|
||||
/**
 * Replaces the tag set of a conversation, keeping per-tag usage counts in sync
 * (increments for newly added tags, decrements for removed ones).
 * @param {string} user - The user ID.
 * @param {string} conversationId - The conversation ID.
 * @param {string[]} tags - The new set of tags for the conversation.
 * @returns {Promise<string[]>} The updated list of tags for the conversation.
 * @throws {Error} When the conversation does not exist or the update fails.
 */
const updateTagsForConversation = async (user, conversationId, tags) => {
  try {
    const conversation = await Conversation.findOne({ user, conversationId }).lean();
    if (!conversation) {
      throw new Error('Conversation not found');
    }

    const oldTags = new Set(conversation.tags);
    const newTags = new Set(tags);

    const addedTags = [...newTags].filter((tag) => !oldTags.has(tag));
    const removedTags = [...oldTags].filter((tag) => !newTags.has(tag));

    const bulkOps = [];

    for (const tag of addedTags) {
      // Upsert so a brand-new tag name gets a document with count 1
      bulkOps.push({
        updateOne: {
          filter: { user, tag },
          update: { $inc: { count: 1 } },
          upsert: true,
        },
      });
    }

    for (const tag of removedTags) {
      bulkOps.push({
        updateOne: {
          filter: { user, tag },
          update: { $inc: { count: -1 } },
        },
      });
    }

    if (bulkOps.length > 0) {
      await ConversationTag.bulkWrite(bulkOps);
    }

    const updatedConversation = await Conversation.findOneAndUpdate(
      { user, conversationId },
      { $set: { tags: [...newTags] } },
      { new: true, lean: true },
    );

    if (!updatedConversation) {
      // Conversation vanished between the read and the write (concurrent delete);
      // previously this surfaced as an opaque TypeError from `.toObject()` on null.
      throw new Error('Conversation not found');
    }

    return updatedConversation.tags;
  } catch (error) {
    logger.error('[updateTagsForConversation] Error updating tags', error);
    throw new Error('Error updating tags for conversation');
  }
};
|
||||
|
||||
/**
 * Increments usage counts for the given tag names. Only existing tag documents
 * are modified (no upsert); failures are logged, never thrown.
 * @param {string} user - The user ID.
 * @param {string[]} tags - Array of tag names to increment
 * @returns {Promise<void>}
 */
const bulkIncrementTagCounts = async (user, tags) => {
  if (!tags || tags.length === 0) {
    return;
  }

  try {
    // Drop falsy entries and duplicates before building the batch
    const uniqueTags = [...new Set(tags.filter(Boolean))];
    if (uniqueTags.length === 0) {
      return;
    }

    const operations = uniqueTags.map((tag) => ({
      updateOne: {
        filter: { user, tag },
        update: { $inc: { count: 1 } },
      },
    }));

    const result = await ConversationTag.bulkWrite(operations);
    if (result && result.modifiedCount > 0) {
      logger.debug(
        `user: ${user} | Incremented tag counts - modified ${result.modifiedCount} tags`,
      );
    }
  } catch (error) {
    logger.error('[bulkIncrementTagCounts] Error incrementing tag counts', error);
  }
};
|
||||
|
||||
/** Public API for conversation tag management. */
module.exports = {
  getConversationTags,
  createConversationTag,
  updateConversationTag,
  deleteConversationTag,
  bulkIncrementTagCounts,
  updateTagsForConversation,
};
|
||||
|
|
@ -1,250 +0,0 @@
|
|||
const { logger } = require('@librechat/data-schemas');
|
||||
const { EToolResources, FileContext } = require('librechat-data-provider');
|
||||
const { File } = require('~/db/models');
|
||||
|
||||
/**
 * Looks up a single file document by its file_id.
 * @param {string} file_id - The unique identifier of the file.
 * @param {object} options - Additional query conditions merged into the filter.
 * @returns {Promise<MongoFile>} The matching file as a plain object, or null.
 */
const findFileById = async (file_id, options = {}) => {
  const query = { file_id, ...options };
  return await File.findOne(query).lean();
};
|
||||
|
||||
/**
 * Retrieves files matching a filter, most recently updated first by default.
 * @param {Object} filter - The filter criteria to apply.
 * @param {Object} [_sortOptions] - Optional sort parameters merged over the default.
 * @param {Object|String} [selectFields={ text: 0 }] - Fields to include/exclude in the
 *   query results. Default excludes the 'text' field.
 * @returns {Promise<Array<MongoFile>>} The matching file documents as plain objects.
 */
const getFiles = async (filter, _sortOptions, selectFields = { text: 0 }) => {
  const sort = { updatedAt: -1, ..._sortOptions };
  return await File.find(filter).select(selectFields).sort(sort).lean();
};
|
||||
|
||||
/**
 * Retrieves tool files (context files with text, or embedded search files)
 * from an array of file IDs. Code-generated files are excluded here; they are
 * handled separately by getCodeGeneratedFiles.
 * @param {string[]} fileIds - Array of file_id strings to search for
 * @param {Set<EToolResources>} toolResourceSet - Optional filter for tool resources
 * @returns {Promise<Array<MongoFile>>} Files that match the criteria
 * @throws {Error} When the query fails.
 */
const getToolFilesByIds = async (fileIds, toolResourceSet) => {
  if (!fileIds || !fileIds.length || !toolResourceSet?.size) {
    return [];
  }

  try {
    const conditions = [];

    if (toolResourceSet.has(EToolResources.context)) {
      conditions.push({ text: { $exists: true, $ne: null }, context: FileContext.agents });
    }
    if (toolResourceSet.has(EToolResources.file_search)) {
      conditions.push({ embedded: true });
    }

    if (conditions.length === 0) {
      return [];
    }

    return await getFiles(
      {
        file_id: { $in: fileIds },
        context: { $ne: FileContext.execute_code }, // Exclude code-generated files
        $or: conditions,
      },
      { updatedAt: -1 },
      { text: 0 },
    );
  } catch (error) {
    logger.error('[getToolFilesByIds] Error retrieving tool files:', error);
    throw new Error('Error retrieving tool files');
  }
};
|
||||
|
||||
/**
 * Fetches files produced by code execution within a conversation thread.
 * These files carry fileIdentifier metadata so they can be re-uploaded to the code env.
 * @param {string} conversationId - The conversation ID to search for
 * @param {string[]} [messageIds] - messageIds restricting results to the current thread
 * @returns {Promise<Array<MongoFile>>} Code-generated files, oldest first; empty on error.
 */
const getCodeGeneratedFiles = async (conversationId, messageIds) => {
  if (!conversationId) {
    return [];
  }

  /** messageIds are required for proper thread filtering of code-generated files */
  if (!messageIds?.length) {
    return [];
  }

  try {
    return await getFiles(
      {
        conversationId,
        context: FileContext.execute_code,
        messageId: { $exists: true, $in: messageIds },
        'metadata.fileIdentifier': { $exists: true },
      },
      { createdAt: 1 },
      { text: 0 },
    );
  } catch (error) {
    logger.error('[getCodeGeneratedFiles] Error retrieving code generated files:', error);
    return [];
  }
};
|
||||
|
||||
/**
 * Retrieves user-uploaded execute_code files (not code-generated) by their file IDs.
 * These carry fileIdentifier metadata but their context is NOT execute_code
 * (e.g., agents or message_attachment). File IDs should be collected from
 * message.files arrays in the current thread.
 * @param {string[]} fileIds - Array of file IDs to fetch (from message.files in the thread)
 * @returns {Promise<Array<MongoFile>>} User-uploaded execute_code files; empty on error.
 */
const getUserCodeFiles = async (fileIds) => {
  if (!fileIds?.length) {
    return [];
  }

  try {
    return await getFiles(
      {
        file_id: { $in: fileIds },
        context: { $ne: FileContext.execute_code },
        'metadata.fileIdentifier': { $exists: true },
      },
      { createdAt: 1 },
      { text: 0 },
    );
  } catch (error) {
    logger.error('[getUserCodeFiles] Error retrieving user code files:', error);
    return [];
  }
};
|
||||
|
||||
/**
 * Creates (or upserts) a file document, by default with a TTL of 1 hour.
 * @param {MongoFile} data - The file data to be created, must contain file_id.
 * @param {boolean} disableTTL - Whether to disable the TTL.
 * @returns {Promise<MongoFile>} The created/updated file as a plain object.
 */
const createFile = async (data, disableTTL) => {
  const fileData = { ...data };

  if (!disableTTL) {
    // Expire the record one hour from now unless explicitly persisted
    fileData.expiresAt = new Date(Date.now() + 3600 * 1000);
  }

  return await File.findOneAndUpdate({ file_id: data.file_id }, fileData, {
    new: true,
    upsert: true,
  }).lean();
};
|
||||
|
||||
/**
 * Persists new data for a file and removes its TTL so it no longer expires.
 * @param {MongoFile} data - The data to update, must contain file_id.
 * @returns {Promise<MongoFile>} The updated file as a plain object.
 */
const updateFile = async (data) => {
  const { file_id, ...fields } = data;
  return await File.findOneAndUpdate(
    { file_id },
    {
      $set: fields,
      $unset: { expiresAt: '' }, // Remove the expiresAt field to prevent TTL
    },
    { new: true },
  ).lean();
};
|
||||
|
||||
/**
 * Increments the usage counter of a file and clears its TTL / temporary ID.
 * @param {MongoFile} data - Must contain file_id; `inc` is the increment (default 1).
 * @returns {Promise<MongoFile>} The updated file as a plain object.
 */
const updateFileUsage = async (data) => {
  const { file_id, inc = 1 } = data;
  return await File.findOneAndUpdate(
    { file_id },
    {
      $inc: { usage: inc },
      $unset: { expiresAt: '', temp_file_id: '' },
    },
    { new: true },
  ).lean();
};
|
||||
|
||||
/**
 * Deletes a file identified by file_id.
 * @param {string} file_id - The unique identifier of the file to delete.
 * @returns {Promise<MongoFile>} The deleted file as a plain object, or null.
 */
const deleteFile = async (file_id) => {
  const removed = await File.findOneAndDelete({ file_id }).lean();
  return removed;
};
|
||||
|
||||
/**
 * Deletes a single file matching an arbitrary filter.
 * @param {object} filter - The filter criteria to apply.
 * @returns {Promise<MongoFile>} The deleted file as a plain object, or null.
 */
const deleteFileByFilter = async (filter) => {
  const removed = await File.findOneAndDelete(filter).lean();
  return removed;
};
|
||||
|
||||
/**
 * Deletes multiple files identified by an array of file_ids, or — when a user
 * is supplied — every file belonging to that user.
 * @param {Array<string>} file_ids - The unique identifiers of the files to delete.
 * @param {string} [user] - Optional user ID. When provided, the file_ids filter is
 *   REPLACED entirely and ALL of this user's files are deleted instead.
 * @returns {Promise<Object>} A promise that resolves to the result of the deletion operation.
 */
const deleteFiles = async (file_ids, user) => {
  let deleteQuery = { file_id: { $in: file_ids } };
  // NOTE(review): when `user` is set, `file_ids` is deliberately ignored and the
  // query targets every file owned by the user — confirm callers expect this
  // (looks intended for account-deletion flows; verify before changing).
  if (user) {
    deleteQuery = { user: user };
  }
  return await File.deleteMany(deleteQuery);
};
|
||||
|
||||
/**
 * Batch updates files with new signed URLs in MongoDB.
 * @param {MongoFile[]} updates - Array of updates in the format { file_id, filepath }
 * @returns {Promise<void>}
 */
async function batchUpdateFiles(updates) {
  if (!updates || updates.length === 0) {
    return;
  }

  const operations = updates.map(({ file_id, filepath }) => ({
    updateOne: {
      filter: { file_id },
      update: { $set: { filepath } },
    },
  }));

  const result = await File.bulkWrite(operations);
  logger.info(`Updated ${result.modifiedCount} files with new S3 URLs`);
}
|
||||
|
||||
/** Public API for file document CRUD and query helpers. */
module.exports = {
  findFileById,
  getFiles,
  getToolFilesByIds,
  getCodeGeneratedFiles,
  getUserCodeFiles,
  createFile,
  updateFile,
  updateFileUsage,
  deleteFile,
  deleteFiles,
  deleteFileByFilter,
  batchUpdateFiles,
};
|
||||
|
|
@ -1,629 +0,0 @@
|
|||
const mongoose = require('mongoose');
const { v4: uuidv4 } = require('uuid');
const { MongoMemoryServer } = require('mongodb-memory-server');
const { createModels, createMethods } = require('@librechat/data-schemas');
const {
  SystemRoles,
  ResourceType,
  AccessRoleIds,
  PrincipalType,
} = require('librechat-data-provider');
const { grantPermission } = require('~/server/services/PermissionService');
const { createAgent } = require('./Agent');

// Mongoose models, resolved in the suite's beforeAll once the in-memory MongoDB is up.
let File;
let Agent;
let AclEntry;
let User;
// Names of the models registered during setup, so teardown can remove only those.
let modelsToCleanup = [];
// Data-schema methods bound to the test mongoose instance in beforeAll.
let methods;
let getFiles;
let createFile;
let seedDefaultRoles;
|
||||
|
||||
describe('File Access Control', () => {
|
||||
let mongoServer;
|
||||
|
||||
beforeAll(async () => {
|
||||
mongoServer = await MongoMemoryServer.create();
|
||||
const mongoUri = mongoServer.getUri();
|
||||
await mongoose.connect(mongoUri);
|
||||
|
||||
// Initialize all models
|
||||
const models = createModels(mongoose);
|
||||
|
||||
// Track which models we're adding
|
||||
modelsToCleanup = Object.keys(models);
|
||||
|
||||
// Register models on mongoose.models so methods can access them
|
||||
const dbModels = require('~/db/models');
|
||||
Object.assign(mongoose.models, dbModels);
|
||||
|
||||
File = dbModels.File;
|
||||
Agent = dbModels.Agent;
|
||||
AclEntry = dbModels.AclEntry;
|
||||
User = dbModels.User;
|
||||
|
||||
// Create methods from data-schemas (includes file methods)
|
||||
methods = createMethods(mongoose);
|
||||
getFiles = methods.getFiles;
|
||||
createFile = methods.createFile;
|
||||
seedDefaultRoles = methods.seedDefaultRoles;
|
||||
|
||||
// Seed default roles
|
||||
await seedDefaultRoles();
|
||||
});
|
||||
|
||||
afterAll(async () => {
|
||||
// Clean up all collections before disconnecting
|
||||
const collections = mongoose.connection.collections;
|
||||
for (const key in collections) {
|
||||
await collections[key].deleteMany({});
|
||||
}
|
||||
|
||||
// Clear only the models we added
|
||||
for (const modelName of modelsToCleanup) {
|
||||
if (mongoose.models[modelName]) {
|
||||
delete mongoose.models[modelName];
|
||||
}
|
||||
}
|
||||
|
||||
await mongoose.disconnect();
|
||||
await mongoServer.stop();
|
||||
});
|
||||
|
||||
beforeEach(async () => {
|
||||
await File.deleteMany({});
|
||||
await Agent.deleteMany({});
|
||||
await AclEntry.deleteMany({});
|
||||
await User.deleteMany({});
|
||||
// Don't delete AccessRole as they are seeded defaults needed for tests
|
||||
});
|
||||
|
||||
describe('hasAccessToFilesViaAgent', () => {
|
||||
it('should efficiently check access for multiple files at once', async () => {
|
||||
const userId = new mongoose.Types.ObjectId();
|
||||
const authorId = new mongoose.Types.ObjectId();
|
||||
const agentId = uuidv4();
|
||||
const fileIds = [uuidv4(), uuidv4(), uuidv4(), uuidv4()];
|
||||
|
||||
// Create users
|
||||
await User.create({
|
||||
_id: userId,
|
||||
email: 'user@example.com',
|
||||
emailVerified: true,
|
||||
provider: 'local',
|
||||
});
|
||||
|
||||
await User.create({
|
||||
_id: authorId,
|
||||
email: 'author@example.com',
|
||||
emailVerified: true,
|
||||
provider: 'local',
|
||||
});
|
||||
|
||||
// Create files
|
||||
for (const fileId of fileIds) {
|
||||
await createFile({
|
||||
user: authorId,
|
||||
file_id: fileId,
|
||||
filename: `file-${fileId}.txt`,
|
||||
filepath: `/uploads/${fileId}`,
|
||||
});
|
||||
}
|
||||
|
||||
// Create agent with only first two files attached
|
||||
const agent = await createAgent({
|
||||
id: agentId,
|
||||
name: 'Test Agent',
|
||||
author: authorId,
|
||||
model: 'gpt-4',
|
||||
provider: 'openai',
|
||||
tool_resources: {
|
||||
file_search: {
|
||||
file_ids: [fileIds[0], fileIds[1]],
|
||||
},
|
||||
},
|
||||
});
|
||||
|
||||
// Grant EDIT permission to user on the agent
|
||||
await grantPermission({
|
||||
principalType: PrincipalType.USER,
|
||||
principalId: userId,
|
||||
resourceType: ResourceType.AGENT,
|
||||
resourceId: agent._id,
|
||||
accessRoleId: AccessRoleIds.AGENT_EDITOR,
|
||||
grantedBy: authorId,
|
||||
});
|
||||
|
||||
// Check access for all files
|
||||
const { hasAccessToFilesViaAgent } = require('~/server/services/Files/permissions');
|
||||
const accessMap = await hasAccessToFilesViaAgent({
|
||||
userId: userId,
|
||||
role: SystemRoles.USER,
|
||||
fileIds,
|
||||
agentId: agent.id, // Use agent.id which is the custom UUID
|
||||
});
|
||||
|
||||
// Should have access only to the first two files
|
||||
expect(accessMap.get(fileIds[0])).toBe(true);
|
||||
expect(accessMap.get(fileIds[1])).toBe(true);
|
||||
expect(accessMap.get(fileIds[2])).toBe(false);
|
||||
expect(accessMap.get(fileIds[3])).toBe(false);
|
||||
});
|
||||
|
||||
it('should grant access to all files when user is the agent author', async () => {
|
||||
const authorId = new mongoose.Types.ObjectId();
|
||||
const agentId = uuidv4();
|
||||
const fileIds = [uuidv4(), uuidv4(), uuidv4()];
|
||||
|
||||
// Create author user
|
||||
await User.create({
|
||||
_id: authorId,
|
||||
email: 'author@example.com',
|
||||
emailVerified: true,
|
||||
provider: 'local',
|
||||
});
|
||||
|
||||
// Create agent
|
||||
await createAgent({
|
||||
id: agentId,
|
||||
name: 'Test Agent',
|
||||
author: authorId,
|
||||
model: 'gpt-4',
|
||||
provider: 'openai',
|
||||
tool_resources: {
|
||||
file_search: {
|
||||
file_ids: [fileIds[0]], // Only one file attached
|
||||
},
|
||||
},
|
||||
});
|
||||
|
||||
// Check access as the author
|
||||
const { hasAccessToFilesViaAgent } = require('~/server/services/Files/permissions');
|
||||
const accessMap = await hasAccessToFilesViaAgent({
|
||||
userId: authorId,
|
||||
role: SystemRoles.USER,
|
||||
fileIds,
|
||||
agentId,
|
||||
});
|
||||
|
||||
// Author should have access to all files
|
||||
expect(accessMap.get(fileIds[0])).toBe(true);
|
||||
expect(accessMap.get(fileIds[1])).toBe(true);
|
||||
expect(accessMap.get(fileIds[2])).toBe(true);
|
||||
});
|
||||
|
||||
it('should handle non-existent agent gracefully', async () => {
|
||||
const userId = new mongoose.Types.ObjectId();
|
||||
const fileIds = [uuidv4(), uuidv4()];
|
||||
|
||||
// Create user
|
||||
await User.create({
|
||||
_id: userId,
|
||||
email: 'user@example.com',
|
||||
emailVerified: true,
|
||||
provider: 'local',
|
||||
});
|
||||
|
||||
const { hasAccessToFilesViaAgent } = require('~/server/services/Files/permissions');
|
||||
const accessMap = await hasAccessToFilesViaAgent({
|
||||
userId: userId,
|
||||
role: SystemRoles.USER,
|
||||
fileIds,
|
||||
agentId: 'non-existent-agent',
|
||||
});
|
||||
|
||||
// Should have no access to any files
|
||||
expect(accessMap.get(fileIds[0])).toBe(false);
|
||||
expect(accessMap.get(fileIds[1])).toBe(false);
|
||||
});
|
||||
|
||||
it('should deny access when user only has VIEW permission and needs access for deletion', async () => {
|
||||
const userId = new mongoose.Types.ObjectId();
|
||||
const authorId = new mongoose.Types.ObjectId();
|
||||
const agentId = uuidv4();
|
||||
const fileIds = [uuidv4(), uuidv4()];
|
||||
|
||||
// Create users
|
||||
await User.create({
|
||||
_id: userId,
|
||||
email: 'user@example.com',
|
||||
emailVerified: true,
|
||||
provider: 'local',
|
||||
});
|
||||
|
||||
await User.create({
|
||||
_id: authorId,
|
||||
email: 'author@example.com',
|
||||
emailVerified: true,
|
||||
provider: 'local',
|
||||
});
|
||||
|
||||
// Create agent with files
|
||||
const agent = await createAgent({
|
||||
id: agentId,
|
||||
name: 'View-Only Agent',
|
||||
author: authorId,
|
||||
model: 'gpt-4',
|
||||
provider: 'openai',
|
||||
tool_resources: {
|
||||
file_search: {
|
||||
file_ids: fileIds,
|
||||
},
|
||||
},
|
||||
});
|
||||
|
||||
// Grant only VIEW permission to user on the agent
|
||||
await grantPermission({
|
||||
principalType: PrincipalType.USER,
|
||||
principalId: userId,
|
||||
resourceType: ResourceType.AGENT,
|
||||
resourceId: agent._id,
|
||||
accessRoleId: AccessRoleIds.AGENT_VIEWER,
|
||||
grantedBy: authorId,
|
||||
});
|
||||
|
||||
// Check access for files
|
||||
const { hasAccessToFilesViaAgent } = require('~/server/services/Files/permissions');
|
||||
const accessMap = await hasAccessToFilesViaAgent({
|
||||
userId: userId,
|
||||
role: SystemRoles.USER,
|
||||
fileIds,
|
||||
agentId,
|
||||
isDelete: true,
|
||||
});
|
||||
|
||||
// Should have no access to any files when only VIEW permission
|
||||
expect(accessMap.get(fileIds[0])).toBe(false);
|
||||
expect(accessMap.get(fileIds[1])).toBe(false);
|
||||
});
|
||||
|
||||
it('should grant access when user has VIEW permission', async () => {
|
||||
const userId = new mongoose.Types.ObjectId();
|
||||
const authorId = new mongoose.Types.ObjectId();
|
||||
const agentId = uuidv4();
|
||||
const fileIds = [uuidv4(), uuidv4()];
|
||||
|
||||
// Create users
|
||||
await User.create({
|
||||
_id: userId,
|
||||
email: 'user@example.com',
|
||||
emailVerified: true,
|
||||
provider: 'local',
|
||||
});
|
||||
|
||||
await User.create({
|
||||
_id: authorId,
|
||||
email: 'author@example.com',
|
||||
emailVerified: true,
|
||||
provider: 'local',
|
||||
});
|
||||
|
||||
// Create agent with files
|
||||
const agent = await createAgent({
|
||||
id: agentId,
|
||||
name: 'View-Only Agent',
|
||||
author: authorId,
|
||||
model: 'gpt-4',
|
||||
provider: 'openai',
|
||||
tool_resources: {
|
||||
file_search: {
|
||||
file_ids: fileIds,
|
||||
},
|
||||
},
|
||||
});
|
||||
|
||||
// Grant only VIEW permission to user on the agent
|
||||
await grantPermission({
|
||||
principalType: PrincipalType.USER,
|
||||
principalId: userId,
|
||||
resourceType: ResourceType.AGENT,
|
||||
resourceId: agent._id,
|
||||
accessRoleId: AccessRoleIds.AGENT_VIEWER,
|
||||
grantedBy: authorId,
|
||||
});
|
||||
|
||||
// Check access for files
|
||||
const { hasAccessToFilesViaAgent } = require('~/server/services/Files/permissions');
|
||||
const accessMap = await hasAccessToFilesViaAgent({
|
||||
userId: userId,
|
||||
role: SystemRoles.USER,
|
||||
fileIds,
|
||||
agentId,
|
||||
});
|
||||
|
||||
expect(accessMap.get(fileIds[0])).toBe(true);
|
||||
expect(accessMap.get(fileIds[1])).toBe(true);
|
||||
});
|
||||
});
|
||||
|
||||
describe('getFiles with agent access control', () => {
|
||||
test('should return files owned by user and files accessible through agent', async () => {
|
||||
const authorId = new mongoose.Types.ObjectId();
|
||||
const userId = new mongoose.Types.ObjectId();
|
||||
const agentId = `agent_${uuidv4()}`;
|
||||
const ownedFileId = `file_${uuidv4()}`;
|
||||
const sharedFileId = `file_${uuidv4()}`;
|
||||
const inaccessibleFileId = `file_${uuidv4()}`;
|
||||
|
||||
// Create users
|
||||
await User.create({
|
||||
_id: userId,
|
||||
email: 'user@example.com',
|
||||
emailVerified: true,
|
||||
provider: 'local',
|
||||
});
|
||||
|
||||
await User.create({
|
||||
_id: authorId,
|
||||
email: 'author@example.com',
|
||||
emailVerified: true,
|
||||
provider: 'local',
|
||||
});
|
||||
|
||||
// Create agent with shared file
|
||||
const agent = await createAgent({
|
||||
id: agentId,
|
||||
name: 'Shared Agent',
|
||||
provider: 'test',
|
||||
model: 'test-model',
|
||||
author: authorId,
|
||||
tool_resources: {
|
||||
file_search: {
|
||||
file_ids: [sharedFileId],
|
||||
},
|
||||
},
|
||||
});
|
||||
|
||||
// Grant EDIT permission to user on the agent
|
||||
await grantPermission({
|
||||
principalType: PrincipalType.USER,
|
||||
principalId: userId,
|
||||
resourceType: ResourceType.AGENT,
|
||||
resourceId: agent._id,
|
||||
accessRoleId: AccessRoleIds.AGENT_EDITOR,
|
||||
grantedBy: authorId,
|
||||
});
|
||||
|
||||
// Create files
|
||||
await createFile({
|
||||
file_id: ownedFileId,
|
||||
user: userId,
|
||||
filename: 'owned.txt',
|
||||
filepath: '/uploads/owned.txt',
|
||||
type: 'text/plain',
|
||||
bytes: 100,
|
||||
});
|
||||
|
||||
await createFile({
|
||||
file_id: sharedFileId,
|
||||
user: authorId,
|
||||
filename: 'shared.txt',
|
||||
filepath: '/uploads/shared.txt',
|
||||
type: 'text/plain',
|
||||
bytes: 200,
|
||||
embedded: true,
|
||||
});
|
||||
|
||||
await createFile({
|
||||
file_id: inaccessibleFileId,
|
||||
user: authorId,
|
||||
filename: 'inaccessible.txt',
|
||||
filepath: '/uploads/inaccessible.txt',
|
||||
type: 'text/plain',
|
||||
bytes: 300,
|
||||
});
|
||||
|
||||
// Get all files first
|
||||
const allFiles = await getFiles(
|
||||
{ file_id: { $in: [ownedFileId, sharedFileId, inaccessibleFileId] } },
|
||||
null,
|
||||
{ text: 0 },
|
||||
);
|
||||
|
||||
// Then filter by access control
|
||||
const { filterFilesByAgentAccess } = require('~/server/services/Files/permissions');
|
||||
const files = await filterFilesByAgentAccess({
|
||||
files: allFiles,
|
||||
userId: userId,
|
||||
role: SystemRoles.USER,
|
||||
agentId,
|
||||
});
|
||||
|
||||
expect(files).toHaveLength(2);
|
||||
expect(files.map((f) => f.file_id)).toContain(ownedFileId);
|
||||
expect(files.map((f) => f.file_id)).toContain(sharedFileId);
|
||||
expect(files.map((f) => f.file_id)).not.toContain(inaccessibleFileId);
|
||||
});
|
||||
|
||||
test('should return all files when no userId/agentId provided', async () => {
|
||||
const userId = new mongoose.Types.ObjectId();
|
||||
const fileId1 = `file_${uuidv4()}`;
|
||||
const fileId2 = `file_${uuidv4()}`;
|
||||
|
||||
await createFile({
|
||||
file_id: fileId1,
|
||||
user: userId,
|
||||
filename: 'file1.txt',
|
||||
filepath: '/uploads/file1.txt',
|
||||
type: 'text/plain',
|
||||
bytes: 100,
|
||||
});
|
||||
|
||||
await createFile({
|
||||
file_id: fileId2,
|
||||
user: new mongoose.Types.ObjectId(),
|
||||
filename: 'file2.txt',
|
||||
filepath: '/uploads/file2.txt',
|
||||
type: 'text/plain',
|
||||
bytes: 200,
|
||||
});
|
||||
|
||||
const files = await getFiles({ file_id: { $in: [fileId1, fileId2] } });
|
||||
expect(files).toHaveLength(2);
|
||||
});
|
||||
});
|
||||
|
||||
describe('Role-based file permissions', () => {
|
||||
it('should optimize permission checks when role is provided', async () => {
|
||||
const userId = new mongoose.Types.ObjectId();
|
||||
const authorId = new mongoose.Types.ObjectId();
|
||||
const agentId = uuidv4();
|
||||
const fileIds = [uuidv4(), uuidv4()];
|
||||
|
||||
// Create users
|
||||
await User.create({
|
||||
_id: userId,
|
||||
email: 'user@example.com',
|
||||
emailVerified: true,
|
||||
provider: 'local',
|
||||
role: 'ADMIN', // User has ADMIN role
|
||||
});
|
||||
|
||||
await User.create({
|
||||
_id: authorId,
|
||||
email: 'author@example.com',
|
||||
emailVerified: true,
|
||||
provider: 'local',
|
||||
});
|
||||
|
||||
// Create files
|
||||
for (const fileId of fileIds) {
|
||||
await createFile({
|
||||
file_id: fileId,
|
||||
user: authorId,
|
||||
filename: `${fileId}.txt`,
|
||||
filepath: `/uploads/${fileId}.txt`,
|
||||
type: 'text/plain',
|
||||
bytes: 100,
|
||||
});
|
||||
}
|
||||
|
||||
// Create agent with files
|
||||
const agent = await createAgent({
|
||||
id: agentId,
|
||||
name: 'Test Agent',
|
||||
author: authorId,
|
||||
model: 'gpt-4',
|
||||
provider: 'openai',
|
||||
tool_resources: {
|
||||
file_search: {
|
||||
file_ids: fileIds,
|
||||
},
|
||||
},
|
||||
});
|
||||
|
||||
// Grant permission to ADMIN role
|
||||
await grantPermission({
|
||||
principalType: PrincipalType.ROLE,
|
||||
principalId: 'ADMIN',
|
||||
resourceType: ResourceType.AGENT,
|
||||
resourceId: agent._id,
|
||||
accessRoleId: AccessRoleIds.AGENT_EDITOR,
|
||||
grantedBy: authorId,
|
||||
});
|
||||
|
||||
// Check access with role provided (should avoid DB query)
|
||||
const { hasAccessToFilesViaAgent } = require('~/server/services/Files/permissions');
|
||||
const accessMapWithRole = await hasAccessToFilesViaAgent({
|
||||
userId: userId,
|
||||
role: 'ADMIN',
|
||||
fileIds,
|
||||
agentId: agent.id,
|
||||
});
|
||||
|
||||
// User should have access through their ADMIN role
|
||||
expect(accessMapWithRole.get(fileIds[0])).toBe(true);
|
||||
expect(accessMapWithRole.get(fileIds[1])).toBe(true);
|
||||
|
||||
// Check access without role (will query DB to get user's role)
|
||||
const accessMapWithoutRole = await hasAccessToFilesViaAgent({
|
||||
userId: userId,
|
||||
fileIds,
|
||||
agentId: agent.id,
|
||||
});
|
||||
|
||||
// Should have same result
|
||||
expect(accessMapWithoutRole.get(fileIds[0])).toBe(true);
|
||||
expect(accessMapWithoutRole.get(fileIds[1])).toBe(true);
|
||||
});
|
||||
|
||||
it('should deny access when user role changes', async () => {
|
||||
const userId = new mongoose.Types.ObjectId();
|
||||
const authorId = new mongoose.Types.ObjectId();
|
||||
const agentId = uuidv4();
|
||||
const fileId = uuidv4();
|
||||
|
||||
// Create users
|
||||
await User.create({
|
||||
_id: userId,
|
||||
email: 'user@example.com',
|
||||
emailVerified: true,
|
||||
provider: 'local',
|
||||
role: 'EDITOR',
|
||||
});
|
||||
|
||||
await User.create({
|
||||
_id: authorId,
|
||||
email: 'author@example.com',
|
||||
emailVerified: true,
|
||||
provider: 'local',
|
||||
});
|
||||
|
||||
// Create file
|
||||
await createFile({
|
||||
file_id: fileId,
|
||||
user: authorId,
|
||||
filename: 'test.txt',
|
||||
filepath: '/uploads/test.txt',
|
||||
type: 'text/plain',
|
||||
bytes: 100,
|
||||
});
|
||||
|
||||
// Create agent
|
||||
const agent = await createAgent({
|
||||
id: agentId,
|
||||
name: 'Test Agent',
|
||||
author: authorId,
|
||||
model: 'gpt-4',
|
||||
provider: 'openai',
|
||||
tool_resources: {
|
||||
file_search: {
|
||||
file_ids: [fileId],
|
||||
},
|
||||
},
|
||||
});
|
||||
|
||||
// Grant permission to EDITOR role only
|
||||
await grantPermission({
|
||||
principalType: PrincipalType.ROLE,
|
||||
principalId: 'EDITOR',
|
||||
resourceType: ResourceType.AGENT,
|
||||
resourceId: agent._id,
|
||||
accessRoleId: AccessRoleIds.AGENT_EDITOR,
|
||||
grantedBy: authorId,
|
||||
});
|
||||
|
||||
const { hasAccessToFilesViaAgent } = require('~/server/services/Files/permissions');
|
||||
|
||||
// Check with EDITOR role - should have access
|
||||
const accessAsEditor = await hasAccessToFilesViaAgent({
|
||||
userId: userId,
|
||||
role: 'EDITOR',
|
||||
fileIds: [fileId],
|
||||
agentId: agent.id,
|
||||
});
|
||||
expect(accessAsEditor.get(fileId)).toBe(true);
|
||||
|
||||
// Simulate role change to USER - should lose access
|
||||
const accessAsUser = await hasAccessToFilesViaAgent({
|
||||
userId: userId,
|
||||
role: SystemRoles.USER,
|
||||
fileIds: [fileId],
|
||||
agentId: agent.id,
|
||||
});
|
||||
expect(accessAsUser.get(fileId)).toBe(false);
|
||||
});
|
||||
});
|
||||
});
|
||||
|
|
@ -1,372 +0,0 @@
|
|||
const { z } = require('zod');
|
||||
const { logger } = require('@librechat/data-schemas');
|
||||
const { createTempChatExpirationDate } = require('@librechat/api');
|
||||
const { Message } = require('~/db/models');
|
||||
|
||||
const idSchema = z.string().uuid();
|
||||
|
||||
/**
|
||||
* Saves a message in the database.
|
||||
*
|
||||
* @async
|
||||
* @function saveMessage
|
||||
* @param {ServerRequest} req - The request object containing user information.
|
||||
* @param {Object} params - The message data object.
|
||||
* @param {string} params.endpoint - The endpoint where the message originated.
|
||||
* @param {string} params.iconURL - The URL of the sender's icon.
|
||||
* @param {string} params.messageId - The unique identifier for the message.
|
||||
* @param {string} params.newMessageId - The new unique identifier for the message (if applicable).
|
||||
* @param {string} params.conversationId - The identifier of the conversation.
|
||||
* @param {string} [params.parentMessageId] - The identifier of the parent message, if any.
|
||||
* @param {string} params.sender - The identifier of the sender.
|
||||
* @param {string} params.text - The text content of the message.
|
||||
* @param {boolean} params.isCreatedByUser - Indicates if the message was created by the user.
|
||||
* @param {string} [params.error] - Any error associated with the message.
|
||||
* @param {boolean} [params.unfinished] - Indicates if the message is unfinished.
|
||||
* @param {Object[]} [params.files] - An array of files associated with the message.
|
||||
* @param {string} [params.finish_reason] - Reason for finishing the message.
|
||||
* @param {number} [params.tokenCount] - The number of tokens in the message.
|
||||
* @param {string} [params.plugin] - Plugin associated with the message.
|
||||
* @param {string[]} [params.plugins] - An array of plugins associated with the message.
|
||||
* @param {string} [params.model] - The model used to generate the message.
|
||||
* @param {Object} [metadata] - Additional metadata for this operation
|
||||
* @param {string} [metadata.context] - The context of the operation
|
||||
* @returns {Promise<TMessage>} The updated or newly inserted message document.
|
||||
* @throws {Error} If there is an error in saving the message.
|
||||
*/
|
||||
async function saveMessage(req, params, metadata) {
|
||||
if (!req?.user?.id) {
|
||||
throw new Error('User not authenticated');
|
||||
}
|
||||
|
||||
const validConvoId = idSchema.safeParse(params.conversationId);
|
||||
if (!validConvoId.success) {
|
||||
logger.warn(`Invalid conversation ID: ${params.conversationId}`);
|
||||
logger.info(`---\`saveMessage\` context: ${metadata?.context}`);
|
||||
logger.info(`---Invalid conversation ID Params: ${JSON.stringify(params, null, 2)}`);
|
||||
return;
|
||||
}
|
||||
|
||||
try {
|
||||
const update = {
|
||||
...params,
|
||||
user: req.user.id,
|
||||
messageId: params.newMessageId || params.messageId,
|
||||
};
|
||||
|
||||
if (req?.body?.isTemporary) {
|
||||
try {
|
||||
const appConfig = req.config;
|
||||
update.expiredAt = createTempChatExpirationDate(appConfig?.interfaceConfig);
|
||||
} catch (err) {
|
||||
logger.error('Error creating temporary chat expiration date:', err);
|
||||
logger.info(`---\`saveMessage\` context: ${metadata?.context}`);
|
||||
update.expiredAt = null;
|
||||
}
|
||||
} else {
|
||||
update.expiredAt = null;
|
||||
}
|
||||
|
||||
if (update.tokenCount != null && isNaN(update.tokenCount)) {
|
||||
logger.warn(
|
||||
`Resetting invalid \`tokenCount\` for message \`${params.messageId}\`: ${update.tokenCount}`,
|
||||
);
|
||||
logger.info(`---\`saveMessage\` context: ${metadata?.context}`);
|
||||
update.tokenCount = 0;
|
||||
}
|
||||
const message = await Message.findOneAndUpdate(
|
||||
{ messageId: params.messageId, user: req.user.id },
|
||||
update,
|
||||
{ upsert: true, new: true },
|
||||
);
|
||||
|
||||
return message.toObject();
|
||||
} catch (err) {
|
||||
logger.error('Error saving message:', err);
|
||||
logger.info(`---\`saveMessage\` context: ${metadata?.context}`);
|
||||
|
||||
// Check if this is a duplicate key error (MongoDB error code 11000)
|
||||
if (err.code === 11000 && err.message.includes('duplicate key error')) {
|
||||
// Log the duplicate key error but don't crash the application
|
||||
logger.warn(`Duplicate messageId detected: ${params.messageId}. Continuing execution.`);
|
||||
|
||||
try {
|
||||
// Try to find the existing message with this ID
|
||||
const existingMessage = await Message.findOne({
|
||||
messageId: params.messageId,
|
||||
user: req.user.id,
|
||||
});
|
||||
|
||||
// If we found it, return it
|
||||
if (existingMessage) {
|
||||
return existingMessage.toObject();
|
||||
}
|
||||
|
||||
// If we can't find it (unlikely but possible in race conditions)
|
||||
return {
|
||||
...params,
|
||||
messageId: params.messageId,
|
||||
user: req.user.id,
|
||||
};
|
||||
} catch (findError) {
|
||||
// If the findOne also fails, log it but don't crash
|
||||
logger.warn(
|
||||
`Could not retrieve existing message with ID ${params.messageId}: ${findError.message}`,
|
||||
);
|
||||
return {
|
||||
...params,
|
||||
messageId: params.messageId,
|
||||
user: req.user.id,
|
||||
};
|
||||
}
|
||||
}
|
||||
|
||||
throw err; // Re-throw other errors
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Saves multiple messages in the database in bulk.
|
||||
*
|
||||
* @async
|
||||
* @function bulkSaveMessages
|
||||
* @param {Object[]} messages - An array of message objects to save.
|
||||
* @param {boolean} [overrideTimestamp=false] - Indicates whether to override the timestamps of the messages. Defaults to false.
|
||||
* @returns {Promise<Object>} The result of the bulk write operation.
|
||||
* @throws {Error} If there is an error in saving messages in bulk.
|
||||
*/
|
||||
async function bulkSaveMessages(messages, overrideTimestamp = false) {
|
||||
try {
|
||||
const bulkOps = messages.map((message) => ({
|
||||
updateOne: {
|
||||
filter: { messageId: message.messageId },
|
||||
update: message,
|
||||
timestamps: !overrideTimestamp,
|
||||
upsert: true,
|
||||
},
|
||||
}));
|
||||
const result = await Message.bulkWrite(bulkOps);
|
||||
return result;
|
||||
} catch (err) {
|
||||
logger.error('Error saving messages in bulk:', err);
|
||||
throw err;
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Records a message in the database.
|
||||
*
|
||||
* @async
|
||||
* @function recordMessage
|
||||
* @param {Object} params - The message data object.
|
||||
* @param {string} params.user - The identifier of the user.
|
||||
* @param {string} params.endpoint - The endpoint where the message originated.
|
||||
* @param {string} params.messageId - The unique identifier for the message.
|
||||
* @param {string} params.conversationId - The identifier of the conversation.
|
||||
* @param {string} [params.parentMessageId] - The identifier of the parent message, if any.
|
||||
* @param {Partial<TMessage>} rest - Any additional properties from the TMessage typedef not explicitly listed.
|
||||
* @returns {Promise<Object>} The updated or newly inserted message document.
|
||||
* @throws {Error} If there is an error in saving the message.
|
||||
*/
|
||||
async function recordMessage({
|
||||
user,
|
||||
endpoint,
|
||||
messageId,
|
||||
conversationId,
|
||||
parentMessageId,
|
||||
...rest
|
||||
}) {
|
||||
try {
|
||||
// No parsing of convoId as may use threadId
|
||||
const message = {
|
||||
user,
|
||||
endpoint,
|
||||
messageId,
|
||||
conversationId,
|
||||
parentMessageId,
|
||||
...rest,
|
||||
};
|
||||
|
||||
return await Message.findOneAndUpdate({ user, messageId }, message, {
|
||||
upsert: true,
|
||||
new: true,
|
||||
});
|
||||
} catch (err) {
|
||||
logger.error('Error recording message:', err);
|
||||
throw err;
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Updates the text of a message.
|
||||
*
|
||||
* @async
|
||||
* @function updateMessageText
|
||||
* @param {Object} params - The update data object.
|
||||
* @param {Object} req - The request object.
|
||||
* @param {string} params.messageId - The unique identifier for the message.
|
||||
* @param {string} params.text - The new text content of the message.
|
||||
* @returns {Promise<void>}
|
||||
* @throws {Error} If there is an error in updating the message text.
|
||||
*/
|
||||
async function updateMessageText(req, { messageId, text }) {
|
||||
try {
|
||||
await Message.updateOne({ messageId, user: req.user.id }, { text });
|
||||
} catch (err) {
|
||||
logger.error('Error updating message text:', err);
|
||||
throw err;
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Updates a message.
|
||||
*
|
||||
* @async
|
||||
* @function updateMessage
|
||||
* @param {Object} req - The request object.
|
||||
* @param {Object} message - The message object containing update data.
|
||||
* @param {string} message.messageId - The unique identifier for the message.
|
||||
* @param {string} [message.text] - The new text content of the message.
|
||||
* @param {Object[]} [message.files] - The files associated with the message.
|
||||
* @param {boolean} [message.isCreatedByUser] - Indicates if the message was created by the user.
|
||||
* @param {string} [message.sender] - The identifier of the sender.
|
||||
* @param {number} [message.tokenCount] - The number of tokens in the message.
|
||||
* @param {Object} [metadata] - The operation metadata
|
||||
* @param {string} [metadata.context] - The operation metadata
|
||||
* @returns {Promise<TMessage>} The updated message document.
|
||||
* @throws {Error} If there is an error in updating the message or if the message is not found.
|
||||
*/
|
||||
async function updateMessage(req, message, metadata) {
|
||||
try {
|
||||
const { messageId, ...update } = message;
|
||||
const updatedMessage = await Message.findOneAndUpdate(
|
||||
{ messageId, user: req.user.id },
|
||||
update,
|
||||
{
|
||||
new: true,
|
||||
},
|
||||
);
|
||||
|
||||
if (!updatedMessage) {
|
||||
throw new Error('Message not found or user not authorized.');
|
||||
}
|
||||
|
||||
return {
|
||||
messageId: updatedMessage.messageId,
|
||||
conversationId: updatedMessage.conversationId,
|
||||
parentMessageId: updatedMessage.parentMessageId,
|
||||
sender: updatedMessage.sender,
|
||||
text: updatedMessage.text,
|
||||
isCreatedByUser: updatedMessage.isCreatedByUser,
|
||||
tokenCount: updatedMessage.tokenCount,
|
||||
feedback: updatedMessage.feedback,
|
||||
};
|
||||
} catch (err) {
|
||||
logger.error('Error updating message:', err);
|
||||
if (metadata && metadata?.context) {
|
||||
logger.info(`---\`updateMessage\` context: ${metadata.context}`);
|
||||
}
|
||||
throw err;
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Deletes messages in a conversation since a specific message.
|
||||
*
|
||||
* @async
|
||||
* @function deleteMessagesSince
|
||||
* @param {Object} params - The parameters object.
|
||||
* @param {Object} req - The request object.
|
||||
* @param {string} params.messageId - The unique identifier for the message.
|
||||
* @param {string} params.conversationId - The identifier of the conversation.
|
||||
* @returns {Promise<Number>} The number of deleted messages.
|
||||
* @throws {Error} If there is an error in deleting messages.
|
||||
*/
|
||||
async function deleteMessagesSince(req, { messageId, conversationId }) {
|
||||
try {
|
||||
const message = await Message.findOne({ messageId, user: req.user.id }).lean();
|
||||
|
||||
if (message) {
|
||||
const query = Message.find({ conversationId, user: req.user.id });
|
||||
return await query.deleteMany({
|
||||
createdAt: { $gt: message.createdAt },
|
||||
});
|
||||
}
|
||||
return undefined;
|
||||
} catch (err) {
|
||||
logger.error('Error deleting messages:', err);
|
||||
throw err;
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Retrieves messages from the database.
|
||||
* @async
|
||||
* @function getMessages
|
||||
* @param {Record<string, unknown>} filter - The filter criteria.
|
||||
* @param {string | undefined} [select] - The fields to select.
|
||||
* @returns {Promise<TMessage[]>} The messages that match the filter criteria.
|
||||
* @throws {Error} If there is an error in retrieving messages.
|
||||
*/
|
||||
async function getMessages(filter, select) {
|
||||
try {
|
||||
if (select) {
|
||||
return await Message.find(filter).select(select).sort({ createdAt: 1 }).lean();
|
||||
}
|
||||
|
||||
return await Message.find(filter).sort({ createdAt: 1 }).lean();
|
||||
} catch (err) {
|
||||
logger.error('Error getting messages:', err);
|
||||
throw err;
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Retrieves a single message from the database.
|
||||
* @async
|
||||
* @function getMessage
|
||||
* @param {{ user: string, messageId: string }} params - The search parameters
|
||||
* @returns {Promise<TMessage | null>} The message that matches the criteria or null if not found
|
||||
* @throws {Error} If there is an error in retrieving the message
|
||||
*/
|
||||
async function getMessage({ user, messageId }) {
|
||||
try {
|
||||
return await Message.findOne({
|
||||
user,
|
||||
messageId,
|
||||
}).lean();
|
||||
} catch (err) {
|
||||
logger.error('Error getting message:', err);
|
||||
throw err;
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Deletes messages from the database.
|
||||
*
|
||||
* @async
|
||||
* @function deleteMessages
|
||||
* @param {import('mongoose').FilterQuery<import('mongoose').Document>} filter - The filter criteria to find messages to delete.
|
||||
* @returns {Promise<import('mongoose').DeleteResult>} The metadata with count of deleted messages.
|
||||
* @throws {Error} If there is an error in deleting messages.
|
||||
*/
|
||||
async function deleteMessages(filter) {
|
||||
try {
|
||||
return await Message.deleteMany(filter);
|
||||
} catch (err) {
|
||||
logger.error('Error deleting messages:', err);
|
||||
throw err;
|
||||
}
|
||||
}
|
||||
|
||||
module.exports = {
|
||||
saveMessage,
|
||||
bulkSaveMessages,
|
||||
recordMessage,
|
||||
updateMessageText,
|
||||
updateMessage,
|
||||
deleteMessagesSince,
|
||||
getMessages,
|
||||
getMessage,
|
||||
deleteMessages,
|
||||
};
|
||||
|
|
@ -1,82 +0,0 @@
|
|||
const { logger } = require('@librechat/data-schemas');
|
||||
const { Preset } = require('~/db/models');
|
||||
|
||||
const getPreset = async (user, presetId) => {
|
||||
try {
|
||||
return await Preset.findOne({ user, presetId }).lean();
|
||||
} catch (error) {
|
||||
logger.error('[getPreset] Error getting single preset', error);
|
||||
return { message: 'Error getting single preset' };
|
||||
}
|
||||
};
|
||||
|
||||
module.exports = {
|
||||
getPreset,
|
||||
getPresets: async (user, filter) => {
|
||||
try {
|
||||
const presets = await Preset.find({ ...filter, user }).lean();
|
||||
const defaultValue = 10000;
|
||||
|
||||
presets.sort((a, b) => {
|
||||
let orderA = a.order !== undefined ? a.order : defaultValue;
|
||||
let orderB = b.order !== undefined ? b.order : defaultValue;
|
||||
|
||||
if (orderA !== orderB) {
|
||||
return orderA - orderB;
|
||||
}
|
||||
|
||||
return b.updatedAt - a.updatedAt;
|
||||
});
|
||||
|
||||
return presets;
|
||||
} catch (error) {
|
||||
logger.error('[getPresets] Error getting presets', error);
|
||||
return { message: 'Error retrieving presets' };
|
||||
}
|
||||
},
|
||||
savePreset: async (user, { presetId, newPresetId, defaultPreset, ...preset }) => {
|
||||
try {
|
||||
const setter = { $set: {} };
|
||||
const { user: _, ...cleanPreset } = preset;
|
||||
const update = { presetId, ...cleanPreset };
|
||||
if (preset.tools && Array.isArray(preset.tools)) {
|
||||
update.tools =
|
||||
preset.tools
|
||||
.map((tool) => tool?.pluginKey ?? tool)
|
||||
.filter((toolName) => typeof toolName === 'string') ?? [];
|
||||
}
|
||||
if (newPresetId) {
|
||||
update.presetId = newPresetId;
|
||||
}
|
||||
|
||||
if (defaultPreset) {
|
||||
update.defaultPreset = defaultPreset;
|
||||
update.order = 0;
|
||||
|
||||
const currentDefault = await Preset.findOne({ defaultPreset: true, user });
|
||||
|
||||
if (currentDefault && currentDefault.presetId !== presetId) {
|
||||
await Preset.findByIdAndUpdate(currentDefault._id, {
|
||||
$unset: { defaultPreset: '', order: '' },
|
||||
});
|
||||
}
|
||||
} else if (defaultPreset === false) {
|
||||
update.defaultPreset = undefined;
|
||||
update.order = undefined;
|
||||
setter['$unset'] = { defaultPreset: '', order: '' };
|
||||
}
|
||||
|
||||
setter.$set = update;
|
||||
return await Preset.findOneAndUpdate({ presetId, user }, setter, { new: true, upsert: true });
|
||||
} catch (error) {
|
||||
logger.error('[savePreset] Error saving preset', error);
|
||||
return { message: 'Error saving preset' };
|
||||
}
|
||||
},
|
||||
deletePresets: async (user, filter) => {
|
||||
// let toRemove = await Preset.find({ ...filter, user }).select('presetId');
|
||||
// const ids = toRemove.map((instance) => instance.presetId);
|
||||
let deleteCount = await Preset.deleteMany({ ...filter, user });
|
||||
return deleteCount;
|
||||
},
|
||||
};
|
||||
|
|
@ -1,133 +0,0 @@
|
|||
const { GLOBAL_PROJECT_NAME } = require('librechat-data-provider').Constants;
|
||||
const { Project } = require('~/db/models');
|
||||
|
||||
/**
|
||||
* Retrieve a project by ID and convert the found project document to a plain object.
|
||||
*
|
||||
* @param {string} projectId - The ID of the project to find and return as a plain object.
|
||||
* @param {string|string[]} [fieldsToSelect] - The fields to include or exclude in the returned document.
|
||||
* @returns {Promise<IMongoProject>} A plain object representing the project document, or `null` if no project is found.
|
||||
*/
|
||||
const getProjectById = async function (projectId, fieldsToSelect = null) {
|
||||
const query = Project.findById(projectId);
|
||||
|
||||
if (fieldsToSelect) {
|
||||
query.select(fieldsToSelect);
|
||||
}
|
||||
|
||||
return await query.lean();
|
||||
};
|
||||
|
||||
/**
|
||||
* Retrieve a project by name and convert the found project document to a plain object.
|
||||
* If the project with the given name doesn't exist and the name is "instance", create it and return the lean version.
|
||||
*
|
||||
* @param {string} projectName - The name of the project to find or create.
|
||||
* @param {string|string[]} [fieldsToSelect] - The fields to include or exclude in the returned document.
|
||||
* @returns {Promise<IMongoProject>} A plain object representing the project document.
|
||||
*/
|
||||
const getProjectByName = async function (projectName, fieldsToSelect = null) {
|
||||
const query = { name: projectName };
|
||||
const update = { $setOnInsert: { name: projectName } };
|
||||
const options = {
|
||||
new: true,
|
||||
upsert: projectName === GLOBAL_PROJECT_NAME,
|
||||
lean: true,
|
||||
select: fieldsToSelect,
|
||||
};
|
||||
|
||||
return await Project.findOneAndUpdate(query, update, options);
|
||||
};
|
||||
|
||||
/**
|
||||
* Add an array of prompt group IDs to a project's promptGroupIds array, ensuring uniqueness.
|
||||
*
|
||||
* @param {string} projectId - The ID of the project to update.
|
||||
* @param {string[]} promptGroupIds - The array of prompt group IDs to add to the project.
|
||||
* @returns {Promise<IMongoProject>} The updated project document.
|
||||
*/
|
||||
const addGroupIdsToProject = async function (projectId, promptGroupIds) {
|
||||
return await Project.findByIdAndUpdate(
|
||||
projectId,
|
||||
{ $addToSet: { promptGroupIds: { $each: promptGroupIds } } },
|
||||
{ new: true },
|
||||
);
|
||||
};
|
||||
|
||||
/**
|
||||
* Remove an array of prompt group IDs from a project's promptGroupIds array.
|
||||
*
|
||||
* @param {string} projectId - The ID of the project to update.
|
||||
* @param {string[]} promptGroupIds - The array of prompt group IDs to remove from the project.
|
||||
* @returns {Promise<IMongoProject>} The updated project document.
|
||||
*/
|
||||
const removeGroupIdsFromProject = async function (projectId, promptGroupIds) {
|
||||
return await Project.findByIdAndUpdate(
|
||||
projectId,
|
||||
{ $pull: { promptGroupIds: { $in: promptGroupIds } } },
|
||||
{ new: true },
|
||||
);
|
||||
};
|
||||
|
||||
/**
|
||||
* Remove a prompt group ID from all projects.
|
||||
*
|
||||
* @param {string} promptGroupId - The ID of the prompt group to remove from projects.
|
||||
* @returns {Promise<void>}
|
||||
*/
|
||||
const removeGroupFromAllProjects = async (promptGroupId) => {
|
||||
await Project.updateMany({}, { $pull: { promptGroupIds: promptGroupId } });
|
||||
};
|
||||
|
||||
/**
|
||||
* Add an array of agent IDs to a project's agentIds array, ensuring uniqueness.
|
||||
*
|
||||
* @param {string} projectId - The ID of the project to update.
|
||||
* @param {string[]} agentIds - The array of agent IDs to add to the project.
|
||||
* @returns {Promise<IMongoProject>} The updated project document.
|
||||
*/
|
||||
const addAgentIdsToProject = async function (projectId, agentIds) {
|
||||
return await Project.findByIdAndUpdate(
|
||||
projectId,
|
||||
{ $addToSet: { agentIds: { $each: agentIds } } },
|
||||
{ new: true },
|
||||
);
|
||||
};
|
||||
|
||||
/**
|
||||
* Remove an array of agent IDs from a project's agentIds array.
|
||||
*
|
||||
* @param {string} projectId - The ID of the project to update.
|
||||
* @param {string[]} agentIds - The array of agent IDs to remove from the project.
|
||||
* @returns {Promise<IMongoProject>} The updated project document.
|
||||
*/
|
||||
const removeAgentIdsFromProject = async function (projectId, agentIds) {
|
||||
return await Project.findByIdAndUpdate(
|
||||
projectId,
|
||||
{ $pull: { agentIds: { $in: agentIds } } },
|
||||
{ new: true },
|
||||
);
|
||||
};
|
||||
|
||||
/**
|
||||
* Remove an agent ID from all projects.
|
||||
*
|
||||
* @param {string} agentId - The ID of the agent to remove from projects.
|
||||
* @returns {Promise<void>}
|
||||
*/
|
||||
const removeAgentFromAllProjects = async (agentId) => {
|
||||
await Project.updateMany({}, { $pull: { agentIds: agentId } });
|
||||
};
|
||||
|
||||
module.exports = {
|
||||
getProjectById,
|
||||
getProjectByName,
|
||||
/* prompts */
|
||||
addGroupIdsToProject,
|
||||
removeGroupIdsFromProject,
|
||||
removeGroupFromAllProjects,
|
||||
/* agents */
|
||||
addAgentIdsToProject,
|
||||
removeAgentIdsFromProject,
|
||||
removeAgentFromAllProjects,
|
||||
};
|
||||
|
|
@ -1,708 +0,0 @@
|
|||
const { ObjectId } = require('mongodb');
|
||||
const { escapeRegExp } = require('@librechat/api');
|
||||
const { logger } = require('@librechat/data-schemas');
|
||||
const {
|
||||
Constants,
|
||||
SystemRoles,
|
||||
ResourceType,
|
||||
SystemCategories,
|
||||
} = require('librechat-data-provider');
|
||||
const {
|
||||
removeGroupFromAllProjects,
|
||||
removeGroupIdsFromProject,
|
||||
addGroupIdsToProject,
|
||||
getProjectByName,
|
||||
} = require('./Project');
|
||||
const { removeAllPermissions } = require('~/server/services/PermissionService');
|
||||
const { PromptGroup, Prompt, AclEntry } = require('~/db/models');
|
||||
|
||||
/**
|
||||
* Create a pipeline for the aggregation to get prompt groups
|
||||
* @param {Object} query
|
||||
* @param {number} skip
|
||||
* @param {number} limit
|
||||
* @returns {[Object]} - The pipeline for the aggregation
|
||||
*/
|
||||
const createGroupPipeline = (query, skip, limit) => {
|
||||
return [
|
||||
{ $match: query },
|
||||
{ $sort: { createdAt: -1 } },
|
||||
{ $skip: skip },
|
||||
{ $limit: limit },
|
||||
{
|
||||
$lookup: {
|
||||
from: 'prompts',
|
||||
localField: 'productionId',
|
||||
foreignField: '_id',
|
||||
as: 'productionPrompt',
|
||||
},
|
||||
},
|
||||
{ $unwind: { path: '$productionPrompt', preserveNullAndEmptyArrays: true } },
|
||||
{
|
||||
$project: {
|
||||
name: 1,
|
||||
numberOfGenerations: 1,
|
||||
oneliner: 1,
|
||||
category: 1,
|
||||
projectIds: 1,
|
||||
productionId: 1,
|
||||
author: 1,
|
||||
authorName: 1,
|
||||
createdAt: 1,
|
||||
updatedAt: 1,
|
||||
'productionPrompt.prompt': 1,
|
||||
// 'productionPrompt._id': 1,
|
||||
// 'productionPrompt.type': 1,
|
||||
},
|
||||
},
|
||||
];
|
||||
};
|
||||
|
||||
/**
|
||||
* Create a pipeline for the aggregation to get all prompt groups
|
||||
* @param {Object} query
|
||||
* @param {Partial<MongoPromptGroup>} $project
|
||||
* @returns {[Object]} - The pipeline for the aggregation
|
||||
*/
|
||||
const createAllGroupsPipeline = (
|
||||
query,
|
||||
$project = {
|
||||
name: 1,
|
||||
oneliner: 1,
|
||||
category: 1,
|
||||
author: 1,
|
||||
authorName: 1,
|
||||
createdAt: 1,
|
||||
updatedAt: 1,
|
||||
command: 1,
|
||||
'productionPrompt.prompt': 1,
|
||||
},
|
||||
) => {
|
||||
return [
|
||||
{ $match: query },
|
||||
{ $sort: { createdAt: -1 } },
|
||||
{
|
||||
$lookup: {
|
||||
from: 'prompts',
|
||||
localField: 'productionId',
|
||||
foreignField: '_id',
|
||||
as: 'productionPrompt',
|
||||
},
|
||||
},
|
||||
{ $unwind: { path: '$productionPrompt', preserveNullAndEmptyArrays: true } },
|
||||
{
|
||||
$project,
|
||||
},
|
||||
];
|
||||
};
|
||||
|
||||
/**
|
||||
* Get all prompt groups with filters
|
||||
* @param {ServerRequest} req
|
||||
* @param {TPromptGroupsWithFilterRequest} filter
|
||||
* @returns {Promise<PromptGroupListResponse>}
|
||||
*/
|
||||
const getAllPromptGroups = async (req, filter) => {
|
||||
try {
|
||||
const { name, ...query } = filter;
|
||||
|
||||
let searchShared = true;
|
||||
let searchSharedOnly = false;
|
||||
if (name) {
|
||||
query.name = new RegExp(escapeRegExp(name), 'i');
|
||||
}
|
||||
if (!query.category) {
|
||||
delete query.category;
|
||||
} else if (query.category === SystemCategories.MY_PROMPTS) {
|
||||
searchShared = false;
|
||||
delete query.category;
|
||||
} else if (query.category === SystemCategories.NO_CATEGORY) {
|
||||
query.category = '';
|
||||
} else if (query.category === SystemCategories.SHARED_PROMPTS) {
|
||||
searchSharedOnly = true;
|
||||
delete query.category;
|
||||
}
|
||||
|
||||
let combinedQuery = query;
|
||||
|
||||
if (searchShared) {
|
||||
const project = await getProjectByName(Constants.GLOBAL_PROJECT_NAME, 'promptGroupIds');
|
||||
if (project && project.promptGroupIds && project.promptGroupIds.length > 0) {
|
||||
const projectQuery = { _id: { $in: project.promptGroupIds }, ...query };
|
||||
delete projectQuery.author;
|
||||
combinedQuery = searchSharedOnly ? projectQuery : { $or: [projectQuery, query] };
|
||||
}
|
||||
}
|
||||
|
||||
const promptGroupsPipeline = createAllGroupsPipeline(combinedQuery);
|
||||
return await PromptGroup.aggregate(promptGroupsPipeline).exec();
|
||||
} catch (error) {
|
||||
console.error('Error getting all prompt groups', error);
|
||||
return { message: 'Error getting all prompt groups' };
|
||||
}
|
||||
};
|
||||
|
||||
/**
|
||||
* Get prompt groups with filters
|
||||
* @param {ServerRequest} req
|
||||
* @param {TPromptGroupsWithFilterRequest} filter
|
||||
* @returns {Promise<PromptGroupListResponse>}
|
||||
*/
|
||||
const getPromptGroups = async (req, filter) => {
|
||||
try {
|
||||
const { pageNumber = 1, pageSize = 10, name, ...query } = filter;
|
||||
|
||||
const validatedPageNumber = Math.max(parseInt(pageNumber, 10), 1);
|
||||
const validatedPageSize = Math.max(parseInt(pageSize, 10), 1);
|
||||
|
||||
let searchShared = true;
|
||||
let searchSharedOnly = false;
|
||||
if (name) {
|
||||
query.name = new RegExp(escapeRegExp(name), 'i');
|
||||
}
|
||||
if (!query.category) {
|
||||
delete query.category;
|
||||
} else if (query.category === SystemCategories.MY_PROMPTS) {
|
||||
searchShared = false;
|
||||
delete query.category;
|
||||
} else if (query.category === SystemCategories.NO_CATEGORY) {
|
||||
query.category = '';
|
||||
} else if (query.category === SystemCategories.SHARED_PROMPTS) {
|
||||
searchSharedOnly = true;
|
||||
delete query.category;
|
||||
}
|
||||
|
||||
let combinedQuery = query;
|
||||
|
||||
if (searchShared) {
|
||||
// const projects = req.user.projects || []; // TODO: handle multiple projects
|
||||
const project = await getProjectByName(Constants.GLOBAL_PROJECT_NAME, 'promptGroupIds');
|
||||
if (project && project.promptGroupIds && project.promptGroupIds.length > 0) {
|
||||
const projectQuery = { _id: { $in: project.promptGroupIds }, ...query };
|
||||
delete projectQuery.author;
|
||||
combinedQuery = searchSharedOnly ? projectQuery : { $or: [projectQuery, query] };
|
||||
}
|
||||
}
|
||||
|
||||
const skip = (validatedPageNumber - 1) * validatedPageSize;
|
||||
const limit = validatedPageSize;
|
||||
|
||||
const promptGroupsPipeline = createGroupPipeline(combinedQuery, skip, limit);
|
||||
const totalPromptGroupsPipeline = [{ $match: combinedQuery }, { $count: 'total' }];
|
||||
|
||||
const [promptGroupsResults, totalPromptGroupsResults] = await Promise.all([
|
||||
PromptGroup.aggregate(promptGroupsPipeline).exec(),
|
||||
PromptGroup.aggregate(totalPromptGroupsPipeline).exec(),
|
||||
]);
|
||||
|
||||
const promptGroups = promptGroupsResults;
|
||||
const totalPromptGroups =
|
||||
totalPromptGroupsResults.length > 0 ? totalPromptGroupsResults[0].total : 0;
|
||||
|
||||
return {
|
||||
promptGroups,
|
||||
pageNumber: validatedPageNumber.toString(),
|
||||
pageSize: validatedPageSize.toString(),
|
||||
pages: Math.ceil(totalPromptGroups / validatedPageSize).toString(),
|
||||
};
|
||||
} catch (error) {
|
||||
console.error('Error getting prompt groups', error);
|
||||
return { message: 'Error getting prompt groups' };
|
||||
}
|
||||
};
|
||||
|
||||
/**
 * Delete a prompt group along with its prompts, project links, and ACL entries.
 * @param {Object} fields
 * @param {string} fields._id - ID of the prompt group to delete.
 * @param {string} fields.author - Author ID (legacy filter; optional under ACL).
 * @param {string} fields.role - Role of the requesting user.
 * @returns {Promise<TDeletePromptGroupResponse>}
 */
const deletePromptGroup = async ({ _id, author, role }) => {
  // With ACL in place the author filter is optional; keep it for backward
  // compatibility unless the requester is an admin.
  const isAuthorScoped = Boolean(author) && role !== SystemRoles.ADMIN;
  const groupFilter = isAuthorScoped ? { _id, author } : { _id };
  const promptFilter = isAuthorScoped
    ? { groupId: new ObjectId(_id), author }
    : { groupId: new ObjectId(_id) };

  const deleteResult = await PromptGroup.deleteOne(groupFilter);
  if (!deleteResult || deleteResult.deletedCount === 0) {
    throw new Error('Prompt group not found');
  }

  await Prompt.deleteMany(promptFilter);
  await removeGroupFromAllProjects(_id);

  try {
    await removeAllPermissions({ resourceType: ResourceType.PROMPTGROUP, resourceId: _id });
  } catch (error) {
    // Permission cleanup is best-effort; the group itself is already gone.
    logger.error('Error removing promptGroup permissions:', error);
  }

  return { message: 'Prompt group deleted successfully' };
};
|
||||
|
||||
/**
 * Get prompt groups by accessible IDs with optional cursor-based pagination.
 * @param {Object} params - The parameters for getting accessible prompt groups.
 * @param {Array} [params.accessibleIds] - Array of prompt group ObjectIds the user has ACL access to.
 * @param {Object} [params.otherParams] - Additional query parameters (including author filter).
 * @param {number} [params.limit] - Number of prompt groups to return (max 100). If not provided, returns all prompt groups.
 * @param {string} [params.after] - Cursor for pagination - get prompt groups after this cursor; a base64 encoded JSON string with updatedAt and _id.
 * @returns {Promise<Object>} A promise that resolves to an object containing the prompt groups data and pagination info.
 */
async function getListPromptGroupsByAccess({
  accessibleIds = [],
  otherParams = {},
  limit = null,
  after = null,
}) {
  const isPaginated = limit !== null && limit !== undefined;
  // Clamp the page size to [1, 100]; explicit radix 10 avoids parseInt
  // surprises with inputs like '08' or '0x10'.
  const normalizedLimit = isPaginated
    ? Math.min(Math.max(1, parseInt(limit, 10) || 20), 100)
    : null;

  // Build base query combining ACL accessible prompt groups with other filters
  const baseQuery = { ...otherParams, _id: { $in: accessibleIds } };

  // Add cursor condition; the cursor is a base64-encoded { updatedAt, _id } pair
  if (after && typeof after === 'string' && after !== 'undefined' && after !== 'null') {
    try {
      const cursor = JSON.parse(Buffer.from(after, 'base64').toString('utf8'));
      const { updatedAt, _id } = cursor;

      // Selects documents strictly "after" the cursor in the
      // (updatedAt desc, _id asc) sort order used below.
      const cursorCondition = {
        $or: [
          { updatedAt: { $lt: new Date(updatedAt) } },
          { updatedAt: new Date(updatedAt), _id: { $gt: new ObjectId(_id) } },
        ],
      };

      // Merge cursor condition with base query
      if (Object.keys(baseQuery).length > 0) {
        baseQuery.$and = [{ ...baseQuery }, cursorCondition];
        // Remove the original conditions from baseQuery to avoid duplication
        Object.keys(baseQuery).forEach((key) => {
          if (key !== '$and') delete baseQuery[key];
        });
      } else {
        Object.assign(baseQuery, cursorCondition);
      }
    } catch (error) {
      // An unparsable cursor degrades gracefully to a first-page query.
      logger.warn('Invalid cursor:', error.message);
    }
  }

  // The sort must stay in lockstep with the cursor condition above.
  const pipeline = [{ $match: baseQuery }, { $sort: { updatedAt: -1, _id: 1 } }];

  // Fetch one extra document so we can detect whether another page exists.
  if (isPaginated) {
    pipeline.push({ $limit: normalizedLimit + 1 });
  }

  // Join each group's production prompt and project only the exposed fields.
  pipeline.push(
    {
      $lookup: {
        from: 'prompts',
        localField: 'productionId',
        foreignField: '_id',
        as: 'productionPrompt',
      },
    },
    { $unwind: { path: '$productionPrompt', preserveNullAndEmptyArrays: true } },
    {
      $project: {
        name: 1,
        numberOfGenerations: 1,
        oneliner: 1,
        category: 1,
        projectIds: 1,
        productionId: 1,
        author: 1,
        authorName: 1,
        createdAt: 1,
        updatedAt: 1,
        'productionPrompt.prompt': 1,
      },
    },
  );

  const promptGroups = await PromptGroup.aggregate(pipeline).exec();

  const hasMore = isPaginated ? promptGroups.length > normalizedLimit : false;
  const data = (isPaginated ? promptGroups.slice(0, normalizedLimit) : promptGroups).map(
    (group) => {
      if (group.author) {
        group.author = group.author.toString();
      }
      return group;
    },
  );

  // Generate next cursor only if paginated and another page exists
  let nextCursor = null;
  if (isPaginated && hasMore && data.length > 0) {
    const lastGroup = promptGroups[normalizedLimit - 1];
    nextCursor = Buffer.from(
      JSON.stringify({
        updatedAt: lastGroup.updatedAt.toISOString(),
        _id: lastGroup._id.toString(),
      }),
    ).toString('base64');
  }

  return {
    object: 'list',
    data,
    first_id: data.length > 0 ? data[0]._id.toString() : null,
    last_id: data.length > 0 ? data[data.length - 1]._id.toString() : null,
    has_more: hasMore,
    after: nextCursor,
  };
}
|
||||
|
||||
module.exports = {
  getPromptGroups,
  deletePromptGroup,
  getAllPromptGroups,
  getListPromptGroupsByAccess,
  /**
   * Create a prompt and its respective group.
   * Three sequential writes: upsert the group, upsert its first prompt,
   * then backfill the group's productionId once the prompt _id is known.
   * @param {TCreatePromptRecord} saveData
   * @returns {Promise<TCreatePromptResponse>}
   * @throws {Error} 'Error saving prompt group' on any failure (original error is logged).
   */
  createPromptGroup: async (saveData) => {
    try {
      const { prompt, group, author, authorName } = saveData;

      // Upsert: the full document doubles as the filter, so an identical
      // pre-existing group is reused instead of duplicated.
      let newPromptGroup = await PromptGroup.findOneAndUpdate(
        { ...group, author, authorName, productionId: null },
        { $setOnInsert: { ...group, author, authorName, productionId: null } },
        { new: true, upsert: true },
      )
        .lean()
        .select('-__v')
        .exec();

      // Upsert the initial prompt, attached to the group created above.
      const newPrompt = await Prompt.findOneAndUpdate(
        { ...prompt, author, groupId: newPromptGroup._id },
        { $setOnInsert: { ...prompt, author, groupId: newPromptGroup._id } },
        { new: true, upsert: true },
      )
        .lean()
        .select('-__v')
        .exec();

      // Backfill productionId now that the prompt's _id exists.
      newPromptGroup = await PromptGroup.findByIdAndUpdate(
        newPromptGroup._id,
        { productionId: newPrompt._id },
        { new: true },
      )
        .lean()
        .select('-__v')
        .exec();

      return {
        prompt: newPrompt,
        group: {
          ...newPromptGroup,
          productionPrompt: { prompt: newPrompt.prompt },
        },
      };
    } catch (error) {
      logger.error('Error saving prompt group', error);
      throw new Error('Error saving prompt group');
    }
  },
|
||||
  /**
   * Save a prompt.
   * @param {TCreatePromptRecord} saveData
   * @returns {Promise<TCreatePromptResponse>} `{ prompt }` on success,
   *   `{ message }` on failure (errors are logged, not rethrown).
   */
  savePrompt: async (saveData) => {
    try {
      const { prompt, author } = saveData;
      const newPromptData = {
        ...prompt,
        author,
      };

      /** @type {TPrompt} */
      let newPrompt;
      try {
        newPrompt = await Prompt.create(newPromptData);
      } catch (error) {
        // A 'groupId_1_version_1' error indicates a stale unique index —
        // presumably left over from an older schema (judging by its name;
        // TODO confirm). Drop it once and retry the insert; any other
        // error is rethrown to the outer handler.
        if (error?.message?.includes('groupId_1_version_1')) {
          await Prompt.db.collection('prompts').dropIndex('groupId_1_version_1');
        } else {
          throw error;
        }
        newPrompt = await Prompt.create(newPromptData);
      }

      return { prompt: newPrompt };
    } catch (error) {
      logger.error('Error saving prompt', error);
      return { message: 'Error saving prompt' };
    }
  },
|
||||
getPrompts: async (filter) => {
|
||||
try {
|
||||
return await Prompt.find(filter).sort({ createdAt: -1 }).lean();
|
||||
} catch (error) {
|
||||
logger.error('Error getting prompts', error);
|
||||
return { message: 'Error getting prompts' };
|
||||
}
|
||||
},
|
||||
getPrompt: async (filter) => {
|
||||
try {
|
||||
if (filter.groupId) {
|
||||
filter.groupId = new ObjectId(filter.groupId);
|
||||
}
|
||||
return await Prompt.findOne(filter).lean();
|
||||
} catch (error) {
|
||||
logger.error('Error getting prompt', error);
|
||||
return { message: 'Error getting prompt' };
|
||||
}
|
||||
},
|
||||
/**
|
||||
* Get prompt groups with filters
|
||||
* @param {TGetRandomPromptsRequest} filter
|
||||
* @returns {Promise<TGetRandomPromptsResponse>}
|
||||
*/
|
||||
getRandomPromptGroups: async (filter) => {
|
||||
try {
|
||||
const result = await PromptGroup.aggregate([
|
||||
{
|
||||
$match: {
|
||||
category: { $ne: '' },
|
||||
},
|
||||
},
|
||||
{
|
||||
$group: {
|
||||
_id: '$category',
|
||||
promptGroup: { $first: '$$ROOT' },
|
||||
},
|
||||
},
|
||||
{
|
||||
$replaceRoot: { newRoot: '$promptGroup' },
|
||||
},
|
||||
{
|
||||
$sample: { size: +filter.limit + +filter.skip },
|
||||
},
|
||||
{
|
||||
$skip: +filter.skip,
|
||||
},
|
||||
{
|
||||
$limit: +filter.limit,
|
||||
},
|
||||
]);
|
||||
return { prompts: result };
|
||||
} catch (error) {
|
||||
logger.error('Error getting prompt groups', error);
|
||||
return { message: 'Error getting prompt groups' };
|
||||
}
|
||||
},
|
||||
getPromptGroupsWithPrompts: async (filter) => {
|
||||
try {
|
||||
return await PromptGroup.findOne(filter)
|
||||
.populate({
|
||||
path: 'prompts',
|
||||
select: '-_id -__v -user',
|
||||
})
|
||||
.select('-_id -__v -user')
|
||||
.lean();
|
||||
} catch (error) {
|
||||
logger.error('Error getting prompt groups', error);
|
||||
return { message: 'Error getting prompt groups' };
|
||||
}
|
||||
},
|
||||
getPromptGroup: async (filter) => {
|
||||
try {
|
||||
return await PromptGroup.findOne(filter).lean();
|
||||
} catch (error) {
|
||||
logger.error('Error getting prompt group', error);
|
||||
return { message: 'Error getting prompt group' };
|
||||
}
|
||||
},
|
||||
/**
|
||||
* Deletes a prompt and its corresponding prompt group if it is the last prompt in the group.
|
||||
*
|
||||
* @param {Object} options - The options for deleting the prompt.
|
||||
* @param {ObjectId|string} options.promptId - The ID of the prompt to delete.
|
||||
* @param {ObjectId|string} options.groupId - The ID of the prompt's group.
|
||||
* @param {ObjectId|string} options.author - The ID of the prompt's author.
|
||||
* @param {string} options.role - The role of the prompt's author.
|
||||
* @return {Promise<TDeletePromptResponse>} An object containing the result of the deletion.
|
||||
* If the prompt was deleted successfully, the object will have a property 'prompt' with the value 'Prompt deleted successfully'.
|
||||
* If the prompt group was deleted successfully, the object will have a property 'promptGroup' with the message 'Prompt group deleted successfully' and id of the deleted group.
|
||||
* If there was an error deleting the prompt, the object will have a property 'message' with the value 'Error deleting prompt'.
|
||||
*/
|
||||
deletePrompt: async ({ promptId, groupId, author, role }) => {
|
||||
const query = { _id: promptId, groupId, author };
|
||||
if (role === SystemRoles.ADMIN) {
|
||||
delete query.author;
|
||||
}
|
||||
const { deletedCount } = await Prompt.deleteOne(query);
|
||||
if (deletedCount === 0) {
|
||||
throw new Error('Failed to delete the prompt');
|
||||
}
|
||||
|
||||
const remainingPrompts = await Prompt.find({ groupId })
|
||||
.select('_id')
|
||||
.sort({ createdAt: 1 })
|
||||
.lean();
|
||||
|
||||
if (remainingPrompts.length === 0) {
|
||||
// Remove all ACL entries for the promptGroup when deleting the last prompt
|
||||
try {
|
||||
await removeAllPermissions({
|
||||
resourceType: ResourceType.PROMPTGROUP,
|
||||
resourceId: groupId,
|
||||
});
|
||||
} catch (error) {
|
||||
logger.error('Error removing promptGroup permissions:', error);
|
||||
}
|
||||
|
||||
await PromptGroup.deleteOne({ _id: groupId });
|
||||
await removeGroupFromAllProjects(groupId);
|
||||
|
||||
return {
|
||||
prompt: 'Prompt deleted successfully',
|
||||
promptGroup: {
|
||||
message: 'Prompt group deleted successfully',
|
||||
id: groupId,
|
||||
},
|
||||
};
|
||||
} else {
|
||||
const promptGroup = await PromptGroup.findById(groupId).lean();
|
||||
if (promptGroup.productionId.toString() === promptId.toString()) {
|
||||
await PromptGroup.updateOne(
|
||||
{ _id: groupId },
|
||||
{ productionId: remainingPrompts[remainingPrompts.length - 1]._id },
|
||||
);
|
||||
}
|
||||
|
||||
return { prompt: 'Prompt deleted successfully' };
|
||||
}
|
||||
},
|
||||
/**
|
||||
* Delete all prompts and prompt groups created by a specific user.
|
||||
* @param {ServerRequest} req - The server request object.
|
||||
* @param {string} userId - The ID of the user whose prompts and prompt groups are to be deleted.
|
||||
*/
|
||||
deleteUserPrompts: async (req, userId) => {
|
||||
try {
|
||||
const promptGroups = await getAllPromptGroups(req, { author: new ObjectId(userId) });
|
||||
|
||||
if (promptGroups.length === 0) {
|
||||
return;
|
||||
}
|
||||
|
||||
const groupIds = promptGroups.map((group) => group._id);
|
||||
|
||||
for (const groupId of groupIds) {
|
||||
await removeGroupFromAllProjects(groupId);
|
||||
}
|
||||
|
||||
await AclEntry.deleteMany({
|
||||
resourceType: ResourceType.PROMPTGROUP,
|
||||
resourceId: { $in: groupIds },
|
||||
});
|
||||
|
||||
await PromptGroup.deleteMany({ author: new ObjectId(userId) });
|
||||
await Prompt.deleteMany({ author: new ObjectId(userId) });
|
||||
} catch (error) {
|
||||
logger.error('[deleteUserPrompts] General error:', error);
|
||||
}
|
||||
},
|
||||
  /**
   * Update prompt group.
   * Project membership is tracked on both sides (the Project documents and
   * this group's `projectIds` array); the add/remove branches below keep the
   * two in sync before applying the remaining field updates.
   * @param {Partial<MongoPromptGroup>} filter - Filter to find prompt group
   * @param {Partial<MongoPromptGroup>} data - Data to update; `projectIds` /
   *   `removeProjectIds` are consumed here (and deleted from `data`) rather
   *   than written as plain fields.
   * @returns {Promise<TUpdatePromptGroupResponse>} Updated document, or
   *   `{ message }` on failure (including "not found").
   */
  updatePromptGroup: async (filter, data) => {
    try {
      const updateOps = {};
      if (data.removeProjectIds) {
        // Detach this group from each project document first.
        for (const projectId of data.removeProjectIds) {
          await removeGroupIdsFromProject(projectId, [filter._id]);
        }

        updateOps.$pull = { projectIds: { $in: data.removeProjectIds } };
        delete data.removeProjectIds;
      }

      if (data.projectIds) {
        // Attach this group to each project document first.
        for (const projectId of data.projectIds) {
          await addGroupIdsToProject(projectId, [filter._id]);
        }

        updateOps.$addToSet = { projectIds: { $each: data.projectIds } };
        delete data.projectIds;
      }

      // Remaining `data` keys are plain field updates, merged with the
      // $pull/$addToSet operators built above.
      const updateData = { ...data, ...updateOps };
      const updatedDoc = await PromptGroup.findOneAndUpdate(filter, updateData, {
        new: true,
        upsert: false,
      });

      if (!updatedDoc) {
        throw new Error('Prompt group not found');
      }

      return updatedDoc;
    } catch (error) {
      logger.error('Error updating prompt group', error);
      return { message: 'Error updating prompt group' };
    }
  },
|
||||
/**
|
||||
* Function to make a prompt production based on its ID.
|
||||
* @param {String} promptId - The ID of the prompt to make production.
|
||||
* @returns {Object} The result of the production operation.
|
||||
*/
|
||||
makePromptProduction: async (promptId) => {
|
||||
try {
|
||||
const prompt = await Prompt.findById(promptId).lean();
|
||||
|
||||
if (!prompt) {
|
||||
throw new Error('Prompt not found');
|
||||
}
|
||||
|
||||
await PromptGroup.findByIdAndUpdate(
|
||||
prompt.groupId,
|
||||
{ productionId: prompt._id },
|
||||
{ new: true },
|
||||
)
|
||||
.lean()
|
||||
.exec();
|
||||
|
||||
return {
|
||||
message: 'Prompt production made successfully',
|
||||
};
|
||||
} catch (error) {
|
||||
logger.error('Error making prompt production', error);
|
||||
return { message: 'Error making prompt production' };
|
||||
}
|
||||
},
|
||||
updatePromptLabels: async (_id, labels) => {
|
||||
try {
|
||||
const response = await Prompt.updateOne({ _id }, { $set: { labels } });
|
||||
if (response.matchedCount === 0) {
|
||||
return { message: 'Prompt not found' };
|
||||
}
|
||||
return { message: 'Prompt labels updated successfully' };
|
||||
} catch (error) {
|
||||
logger.error('Error updating prompt labels', error);
|
||||
return { message: 'Error updating prompt labels' };
|
||||
}
|
||||
},
|
||||
};
|
||||
|
|
@ -1,564 +0,0 @@
|
|||
// Integration tests for prompt ACL permissions, backed by an in-memory MongoDB.
const mongoose = require('mongoose');
const { ObjectId } = require('mongodb');
const { logger } = require('@librechat/data-schemas');
const { MongoMemoryServer } = require('mongodb-memory-server');
const {
  SystemRoles,
  ResourceType,
  AccessRoleIds,
  PrincipalType,
  PermissionBits,
} = require('librechat-data-provider');

// Mock the config/connect module to prevent connection attempts during tests
jest.mock('../../config/connect', () => jest.fn().mockResolvedValue(true));

const dbModels = require('~/db/models');

// Disable console for tests
logger.silent = true;

// Shared suite state, populated in beforeAll/setupTestData.
let mongoServer;
let Prompt, PromptGroup, AclEntry, AccessRole, User, Group, Project;
let promptFns, permissionService;
let testUsers, testGroups, testRoles;
|
||||
beforeAll(async () => {
  // Set up MongoDB memory server
  mongoServer = await MongoMemoryServer.create();
  const mongoUri = mongoServer.getUri();
  await mongoose.connect(mongoUri);

  // Initialize models
  Prompt = dbModels.Prompt;
  PromptGroup = dbModels.PromptGroup;
  AclEntry = dbModels.AclEntry;
  AccessRole = dbModels.AccessRole;
  User = dbModels.User;
  Group = dbModels.Group;
  Project = dbModels.Project;

  // Required after mongoose.connect — presumably so module initialization
  // sees the live connection; TODO confirm.
  promptFns = require('~/models/Prompt');
  permissionService = require('~/server/services/PermissionService');

  // Create test data
  await setupTestData();
});
|
||||
|
||||
afterAll(async () => {
  // Tear down the in-memory database and restore all jest mocks.
  await mongoose.disconnect();
  await mongoServer.stop();
  jest.clearAllMocks();
});
|
||||
|
||||
// Seeds the shared fixtures used across all suites: ACL roles for
// promptGroups (viewer/editor/owner), users covering each access level,
// two principal groups, and the Global project.
async function setupTestData() {
  // Create access roles for promptGroups
  testRoles = {
    viewer: await AccessRole.create({
      accessRoleId: AccessRoleIds.PROMPTGROUP_VIEWER,
      name: 'Viewer',
      description: 'Can view promptGroups',
      resourceType: ResourceType.PROMPTGROUP,
      permBits: PermissionBits.VIEW,
    }),
    editor: await AccessRole.create({
      accessRoleId: AccessRoleIds.PROMPTGROUP_EDITOR,
      name: 'Editor',
      description: 'Can view and edit promptGroups',
      resourceType: ResourceType.PROMPTGROUP,
      permBits: PermissionBits.VIEW | PermissionBits.EDIT,
    }),
    owner: await AccessRole.create({
      accessRoleId: AccessRoleIds.PROMPTGROUP_OWNER,
      name: 'Owner',
      description: 'Full control over promptGroups',
      resourceType: ResourceType.PROMPTGROUP,
      // Owner holds every permission bit defined for promptGroups.
      permBits:
        PermissionBits.VIEW | PermissionBits.EDIT | PermissionBits.DELETE | PermissionBits.SHARE,
    }),
  };

  // Create test users
  testUsers = {
    owner: await User.create({
      name: 'Prompt Owner',
      email: 'owner@example.com',
      role: SystemRoles.USER,
    }),
    editor: await User.create({
      name: 'Prompt Editor',
      email: 'editor@example.com',
      role: SystemRoles.USER,
    }),
    viewer: await User.create({
      name: 'Prompt Viewer',
      email: 'viewer@example.com',
      role: SystemRoles.USER,
    }),
    admin: await User.create({
      name: 'Admin User',
      email: 'admin@example.com',
      role: SystemRoles.ADMIN,
    }),
    // User with no grants at all, for negative-access assertions.
    noAccess: await User.create({
      name: 'No Access User',
      email: 'noaccess@example.com',
      role: SystemRoles.USER,
    }),
  };

  // Create test groups
  testGroups = {
    editors: await Group.create({
      name: 'Prompt Editors',
      description: 'Group with editor access',
    }),
    viewers: await Group.create({
      name: 'Prompt Viewers',
      description: 'Group with viewer access',
    }),
  };

  await Project.create({
    name: 'Global',
    description: 'Global project',
    promptGroupIds: [],
  });
}
|
||||
|
||||
describe('Prompt ACL Permissions', () => {
|
||||
describe('Creating Prompts with Permissions', () => {
|
||||
it('should grant owner permissions when creating a prompt', async () => {
|
||||
// First create a group
|
||||
const testGroup = await PromptGroup.create({
|
||||
name: 'Test Group',
|
||||
category: 'testing',
|
||||
author: testUsers.owner._id,
|
||||
authorName: testUsers.owner.name,
|
||||
productionId: new mongoose.Types.ObjectId(),
|
||||
});
|
||||
|
||||
const promptData = {
|
||||
prompt: {
|
||||
prompt: 'Test prompt content',
|
||||
name: 'Test Prompt',
|
||||
type: 'text',
|
||||
groupId: testGroup._id,
|
||||
},
|
||||
author: testUsers.owner._id,
|
||||
};
|
||||
|
||||
await promptFns.savePrompt(promptData);
|
||||
|
||||
// Manually grant permissions as would happen in the route
|
||||
await permissionService.grantPermission({
|
||||
principalType: PrincipalType.USER,
|
||||
principalId: testUsers.owner._id,
|
||||
resourceType: ResourceType.PROMPTGROUP,
|
||||
resourceId: testGroup._id,
|
||||
accessRoleId: AccessRoleIds.PROMPTGROUP_OWNER,
|
||||
grantedBy: testUsers.owner._id,
|
||||
});
|
||||
|
||||
// Check ACL entry
|
||||
const aclEntry = await AclEntry.findOne({
|
||||
resourceType: ResourceType.PROMPTGROUP,
|
||||
resourceId: testGroup._id,
|
||||
principalType: PrincipalType.USER,
|
||||
principalId: testUsers.owner._id,
|
||||
});
|
||||
|
||||
expect(aclEntry).toBeTruthy();
|
||||
expect(aclEntry.permBits).toBe(testRoles.owner.permBits);
|
||||
});
|
||||
});
|
||||
|
||||
describe('Accessing Prompts', () => {
|
||||
let testPromptGroup;
|
||||
|
||||
beforeEach(async () => {
|
||||
// Create a prompt group
|
||||
testPromptGroup = await PromptGroup.create({
|
||||
name: 'Test Group',
|
||||
author: testUsers.owner._id,
|
||||
authorName: testUsers.owner.name,
|
||||
productionId: new ObjectId(),
|
||||
});
|
||||
|
||||
// Create a prompt
|
||||
await Prompt.create({
|
||||
prompt: 'Test prompt for access control',
|
||||
name: 'Access Test Prompt',
|
||||
author: testUsers.owner._id,
|
||||
groupId: testPromptGroup._id,
|
||||
type: 'text',
|
||||
});
|
||||
|
||||
// Grant owner permissions
|
||||
await permissionService.grantPermission({
|
||||
principalType: PrincipalType.USER,
|
||||
principalId: testUsers.owner._id,
|
||||
resourceType: ResourceType.PROMPTGROUP,
|
||||
resourceId: testPromptGroup._id,
|
||||
accessRoleId: AccessRoleIds.PROMPTGROUP_OWNER,
|
||||
grantedBy: testUsers.owner._id,
|
||||
});
|
||||
});
|
||||
|
||||
afterEach(async () => {
|
||||
await Prompt.deleteMany({});
|
||||
await PromptGroup.deleteMany({});
|
||||
await AclEntry.deleteMany({});
|
||||
});
|
||||
|
||||
it('owner should have full access to their prompt', async () => {
|
||||
const hasAccess = await permissionService.checkPermission({
|
||||
userId: testUsers.owner._id,
|
||||
resourceType: ResourceType.PROMPTGROUP,
|
||||
resourceId: testPromptGroup._id,
|
||||
requiredPermission: PermissionBits.VIEW,
|
||||
});
|
||||
|
||||
expect(hasAccess).toBe(true);
|
||||
|
||||
const canEdit = await permissionService.checkPermission({
|
||||
userId: testUsers.owner._id,
|
||||
resourceType: ResourceType.PROMPTGROUP,
|
||||
resourceId: testPromptGroup._id,
|
||||
requiredPermission: PermissionBits.EDIT,
|
||||
});
|
||||
|
||||
expect(canEdit).toBe(true);
|
||||
});
|
||||
|
||||
it('user with viewer role should only have view access', async () => {
|
||||
// Grant viewer permissions
|
||||
await permissionService.grantPermission({
|
||||
principalType: PrincipalType.USER,
|
||||
principalId: testUsers.viewer._id,
|
||||
resourceType: ResourceType.PROMPTGROUP,
|
||||
resourceId: testPromptGroup._id,
|
||||
accessRoleId: AccessRoleIds.PROMPTGROUP_VIEWER,
|
||||
grantedBy: testUsers.owner._id,
|
||||
});
|
||||
|
||||
const canView = await permissionService.checkPermission({
|
||||
userId: testUsers.viewer._id,
|
||||
resourceType: ResourceType.PROMPTGROUP,
|
||||
resourceId: testPromptGroup._id,
|
||||
requiredPermission: PermissionBits.VIEW,
|
||||
});
|
||||
|
||||
const canEdit = await permissionService.checkPermission({
|
||||
userId: testUsers.viewer._id,
|
||||
resourceType: ResourceType.PROMPTGROUP,
|
||||
resourceId: testPromptGroup._id,
|
||||
requiredPermission: PermissionBits.EDIT,
|
||||
});
|
||||
|
||||
expect(canView).toBe(true);
|
||||
expect(canEdit).toBe(false);
|
||||
});
|
||||
|
||||
it('user without permissions should have no access', async () => {
|
||||
const hasAccess = await permissionService.checkPermission({
|
||||
userId: testUsers.noAccess._id,
|
||||
resourceType: ResourceType.PROMPTGROUP,
|
||||
resourceId: testPromptGroup._id,
|
||||
requiredPermission: PermissionBits.VIEW,
|
||||
});
|
||||
|
||||
expect(hasAccess).toBe(false);
|
||||
});
|
||||
|
||||
it('admin should have access regardless of permissions', async () => {
|
||||
// Admin users should work through normal permission system
|
||||
// The middleware layer handles admin bypass, not the permission service
|
||||
const hasAccess = await permissionService.checkPermission({
|
||||
userId: testUsers.admin._id,
|
||||
resourceType: ResourceType.PROMPTGROUP,
|
||||
resourceId: testPromptGroup._id,
|
||||
requiredPermission: PermissionBits.VIEW,
|
||||
});
|
||||
|
||||
// Without explicit permissions, even admin won't have access at this layer
|
||||
expect(hasAccess).toBe(false);
|
||||
|
||||
// The actual admin bypass happens in the middleware layer (`canAccessPromptViaGroup`/`canAccessPromptGroupResource`)
|
||||
// which checks req.user.role === SystemRoles.ADMIN
|
||||
});
|
||||
});
|
||||
|
||||
describe('Group-based Access', () => {
|
||||
let testPromptGroup;
|
||||
|
||||
beforeEach(async () => {
|
||||
// Create a prompt group first
|
||||
testPromptGroup = await PromptGroup.create({
|
||||
name: 'Group Access Test Group',
|
||||
author: testUsers.owner._id,
|
||||
authorName: testUsers.owner.name,
|
||||
productionId: new ObjectId(),
|
||||
});
|
||||
|
||||
await Prompt.create({
|
||||
prompt: 'Group access test prompt',
|
||||
name: 'Group Test',
|
||||
author: testUsers.owner._id,
|
||||
groupId: testPromptGroup._id,
|
||||
type: 'text',
|
||||
});
|
||||
|
||||
// Add users to groups
|
||||
await User.findByIdAndUpdate(testUsers.editor._id, {
|
||||
$push: { groups: testGroups.editors._id },
|
||||
});
|
||||
|
||||
await User.findByIdAndUpdate(testUsers.viewer._id, {
|
||||
$push: { groups: testGroups.viewers._id },
|
||||
});
|
||||
});
|
||||
|
||||
afterEach(async () => {
|
||||
await Prompt.deleteMany({});
|
||||
await AclEntry.deleteMany({});
|
||||
await User.updateMany({}, { $set: { groups: [] } });
|
||||
});
|
||||
|
||||
it('group members should inherit group permissions', async () => {
|
||||
// Create a prompt group
|
||||
const testPromptGroup = await PromptGroup.create({
|
||||
name: 'Group Test Group',
|
||||
author: testUsers.owner._id,
|
||||
authorName: testUsers.owner.name,
|
||||
productionId: new ObjectId(),
|
||||
});
|
||||
|
||||
const { addUserToGroup } = require('~/models');
|
||||
await addUserToGroup(testUsers.editor._id, testGroups.editors._id);
|
||||
|
||||
const prompt = await promptFns.savePrompt({
|
||||
author: testUsers.owner._id,
|
||||
prompt: {
|
||||
prompt: 'Group test prompt',
|
||||
name: 'Group Test',
|
||||
groupId: testPromptGroup._id,
|
||||
type: 'text',
|
||||
},
|
||||
});
|
||||
|
||||
// Check if savePrompt returned an error
|
||||
if (!prompt || !prompt.prompt) {
|
||||
throw new Error(`Failed to save prompt: ${prompt?.message || 'Unknown error'}`);
|
||||
}
|
||||
|
||||
// Grant edit permissions to the group
|
||||
await permissionService.grantPermission({
|
||||
principalType: PrincipalType.GROUP,
|
||||
principalId: testGroups.editors._id,
|
||||
resourceType: ResourceType.PROMPTGROUP,
|
||||
resourceId: testPromptGroup._id,
|
||||
accessRoleId: AccessRoleIds.PROMPTGROUP_EDITOR,
|
||||
grantedBy: testUsers.owner._id,
|
||||
});
|
||||
|
||||
// Check if group member has access
|
||||
const hasAccess = await permissionService.checkPermission({
|
||||
userId: testUsers.editor._id,
|
||||
resourceType: ResourceType.PROMPTGROUP,
|
||||
resourceId: testPromptGroup._id,
|
||||
requiredPermission: PermissionBits.EDIT,
|
||||
});
|
||||
|
||||
expect(hasAccess).toBe(true);
|
||||
|
||||
// Check that non-member doesn't have access
|
||||
const nonMemberAccess = await permissionService.checkPermission({
|
||||
userId: testUsers.viewer._id,
|
||||
resourceType: ResourceType.PROMPTGROUP,
|
||||
resourceId: testPromptGroup._id,
|
||||
requiredPermission: PermissionBits.EDIT,
|
||||
});
|
||||
|
||||
expect(nonMemberAccess).toBe(false);
|
||||
});
|
||||
});
|
||||
|
||||
describe('Public Access', () => {
|
||||
let publicPromptGroup, privatePromptGroup;
|
||||
|
||||
beforeEach(async () => {
|
||||
// Create separate prompt groups for public and private access
|
||||
publicPromptGroup = await PromptGroup.create({
|
||||
name: 'Public Access Test Group',
|
||||
author: testUsers.owner._id,
|
||||
authorName: testUsers.owner.name,
|
||||
productionId: new ObjectId(),
|
||||
});
|
||||
|
||||
privatePromptGroup = await PromptGroup.create({
|
||||
name: 'Private Access Test Group',
|
||||
author: testUsers.owner._id,
|
||||
authorName: testUsers.owner.name,
|
||||
productionId: new ObjectId(),
|
||||
});
|
||||
|
||||
// Create prompts in their respective groups
|
||||
await Prompt.create({
|
||||
prompt: 'Public prompt',
|
||||
name: 'Public',
|
||||
author: testUsers.owner._id,
|
||||
groupId: publicPromptGroup._id,
|
||||
type: 'text',
|
||||
});
|
||||
|
||||
await Prompt.create({
|
||||
prompt: 'Private prompt',
|
||||
name: 'Private',
|
||||
author: testUsers.owner._id,
|
||||
groupId: privatePromptGroup._id,
|
||||
type: 'text',
|
||||
});
|
||||
|
||||
// Grant public view access to publicPromptGroup
|
||||
await permissionService.grantPermission({
|
||||
principalType: PrincipalType.PUBLIC,
|
||||
principalId: null,
|
||||
resourceType: ResourceType.PROMPTGROUP,
|
||||
resourceId: publicPromptGroup._id,
|
||||
accessRoleId: AccessRoleIds.PROMPTGROUP_VIEWER,
|
||||
grantedBy: testUsers.owner._id,
|
||||
});
|
||||
|
||||
// Grant only owner access to privatePromptGroup
|
||||
await permissionService.grantPermission({
|
||||
principalType: PrincipalType.USER,
|
||||
principalId: testUsers.owner._id,
|
||||
resourceType: ResourceType.PROMPTGROUP,
|
||||
resourceId: privatePromptGroup._id,
|
||||
accessRoleId: AccessRoleIds.PROMPTGROUP_OWNER,
|
||||
grantedBy: testUsers.owner._id,
|
||||
});
|
||||
});
|
||||
|
||||
afterEach(async () => {
|
||||
await Prompt.deleteMany({});
|
||||
await PromptGroup.deleteMany({});
|
||||
await AclEntry.deleteMany({});
|
||||
});
|
||||
|
||||
it('public prompt should be accessible to any user', async () => {
|
||||
const hasAccess = await permissionService.checkPermission({
|
||||
userId: testUsers.noAccess._id,
|
||||
resourceType: ResourceType.PROMPTGROUP,
|
||||
resourceId: publicPromptGroup._id,
|
||||
requiredPermission: PermissionBits.VIEW,
|
||||
includePublic: true,
|
||||
});
|
||||
|
||||
expect(hasAccess).toBe(true);
|
||||
});
|
||||
|
||||
  it('private prompt should not be accessible to unauthorized users', async () => {
    // privatePromptGroup only has an owner grant, so even with public
    // grants considered (includePublic: true) this user must be denied.
    const hasAccess = await permissionService.checkPermission({
      userId: testUsers.noAccess._id,
      resourceType: ResourceType.PROMPTGROUP,
      resourceId: privatePromptGroup._id,
      requiredPermission: PermissionBits.VIEW,
      includePublic: true,
    });

    expect(hasAccess).toBe(false);
  });
|
||||
});
|
||||
|
||||
describe('Prompt Deletion', () => {
  let testPromptGroup;

  it('should remove ACL entries when prompt is deleted', async () => {
    // Arrange: a group, a prompt inside it, and an owner-level ACL grant.
    testPromptGroup = await PromptGroup.create({
      name: 'Deletion Test Group',
      author: testUsers.owner._id,
      authorName: testUsers.owner.name,
      productionId: new ObjectId(),
    });

    const prompt = await promptFns.savePrompt({
      author: testUsers.owner._id,
      prompt: {
        prompt: 'To be deleted',
        name: 'Delete Test',
        groupId: testPromptGroup._id,
        type: 'text',
      },
    });

    // Check if savePrompt returned an error
    // (savePrompt appears to return an object with `.message` on failure — TODO confirm)
    if (!prompt || !prompt.prompt) {
      throw new Error(`Failed to save prompt: ${prompt?.message || 'Unknown error'}`);
    }

    const testPromptId = prompt.prompt._id;
    const promptGroupId = testPromptGroup._id;

    // Grant permission
    await permissionService.grantPermission({
      principalType: PrincipalType.USER,
      principalId: testUsers.owner._id,
      resourceType: ResourceType.PROMPTGROUP,
      resourceId: testPromptGroup._id,
      accessRoleId: AccessRoleIds.PROMPTGROUP_OWNER,
      grantedBy: testUsers.owner._id,
    });

    // Verify ACL entry exists before deletion so the final assertion is meaningful.
    const beforeDelete = await AclEntry.find({
      resourceType: ResourceType.PROMPTGROUP,
      resourceId: testPromptGroup._id,
    });
    expect(beforeDelete).toHaveLength(1);

    // Delete the prompt
    await promptFns.deletePrompt({
      promptId: testPromptId,
      groupId: promptGroupId,
      author: testUsers.owner._id,
      role: SystemRoles.USER,
    });

    // Verify ACL entries are removed (deletion must cascade to ACL).
    const aclEntries = await AclEntry.find({
      resourceType: ResourceType.PROMPTGROUP,
      resourceId: testPromptGroup._id,
    });

    expect(aclEntries).toHaveLength(0);
  });
});
|
||||
|
||||
describe('Backwards Compatibility', () => {
  it('should handle prompts without ACL entries gracefully', async () => {
    // Create a prompt group first
    const promptGroup = await PromptGroup.create({
      name: 'Legacy Test Group',
      author: testUsers.owner._id,
      authorName: testUsers.owner.name,
      productionId: new ObjectId(),
    });

    // Create a prompt without ACL entries (legacy prompt, predating the ACL system)
    const legacyPrompt = await Prompt.create({
      prompt: 'Legacy prompt without ACL',
      name: 'Legacy',
      author: testUsers.owner._id,
      groupId: promptGroup._id,
      type: 'text',
    });

    // The system should handle this gracefully: getPrompt must still resolve
    // the document rather than throwing or returning nothing.
    const prompt = await promptFns.getPrompt({ _id: legacyPrompt._id });
    expect(prompt).toBeTruthy();
    expect(prompt._id.toString()).toBe(legacyPrompt._id.toString());
  });
});
|
||||
});
|
||||
|
|
@ -1,304 +0,0 @@
|
|||
const {
|
||||
CacheKeys,
|
||||
SystemRoles,
|
||||
roleDefaults,
|
||||
permissionsSchema,
|
||||
removeNullishValues,
|
||||
} = require('librechat-data-provider');
|
||||
const { logger } = require('@librechat/data-schemas');
|
||||
const getLogStores = require('~/cache/getLogStores');
|
||||
const { Role } = require('~/db/models');
|
||||
|
||||
/**
 * Retrieve a role by name and convert the found role document to a plain object.
 * If the role with the given name doesn't exist and the name is a system defined role,
 * create it and return the lean version.
 *
 * Results are cached under the role name; a cache hit skips the database
 * entirely, so `fieldsToSelect` only narrows the result on a cache miss.
 *
 * @param {string} roleName - The name of the role to find or create.
 * @param {string|string[]} [fieldsToSelect] - The fields to include or exclude in the returned document.
 * @returns {Promise<IRole>} Role document as a plain object.
 */
const getRoleByName = async function (roleName, fieldsToSelect = null) {
  const cache = getLogStores(CacheKeys.ROLES);
  try {
    const cachedRole = await cache.get(roleName);
    if (cachedRole) {
      return cachedRole;
    }
    let query = Role.findOne({ name: roleName });
    if (fieldsToSelect) {
      query = query.select(fieldsToSelect);
    }
    let role = await query.lean().exec();

    if (!role && SystemRoles[roleName]) {
      const created = await new Role(roleDefaults[roleName]).save();
      // Fix: cache (and return) the plain object rather than the hydrated
      // Mongoose document, so cache hits are shape-consistent with the
      // `.lean()` query path above.
      role = created.toObject();
    }
    await cache.set(roleName, role);
    return role;
  } catch (error) {
    throw new Error(`Failed to retrieve or create role: ${error.message}`);
  }
};
|
||||
|
||||
/**
 * Update role values by name and refresh the roles cache with the result.
 *
 * @param {string} roleName - The name of the role to update.
 * @param {Partial<TRole>} updates - The fields to update.
 * @returns {Promise<TRole>} Updated role document.
 */
const updateRoleByName = async function (roleName, updates) {
  const cache = getLogStores(CacheKeys.ROLES);
  try {
    const query = Role.findOneAndUpdate(
      { name: roleName },
      { $set: updates },
      { new: true, lean: true },
    );
    const updatedRole = await query.select('-__v').lean().exec();
    await cache.set(roleName, updatedRole);
    return updatedRole;
  } catch (err) {
    throw new Error(`Failed to update role: ${err.message}`);
  }
};
|
||||
|
||||
/**
 * Updates access permissions for a specific role and multiple permission types.
 * Also performs two in-place data migrations while it is at it:
 *   1. Old schema: permission-type objects stored at the top level of the role
 *      document are merged into `role.permissions` and `$unset` from the root.
 *   2. Legacy `SHARED_GLOBAL` flag: inherited into `SHARE` (unless the caller
 *      explicitly set SHARE), then removed.
 * Errors are logged, not rethrown.
 *
 * @param {string} roleName - The role to update.
 * @param {Object.<PermissionTypes, Object.<Permissions, boolean>>} permissionsUpdate - Permissions to update and their values.
 * @param {IRole} [roleData] - Optional role data to use instead of fetching from the database.
 */
async function updateAccessPermissions(roleName, permissionsUpdate, roleData) {
  // Filter and clean the permission updates based on our schema definition.
  // Unknown permission types are silently dropped; nullish values are stripped.
  const updates = {};
  for (const [permissionType, permissions] of Object.entries(permissionsUpdate)) {
    if (permissionsSchema.shape && permissionsSchema.shape[permissionType]) {
      updates[permissionType] = removeNullishValues(permissions);
    }
  }
  if (!Object.keys(updates).length) {
    return;
  }

  try {
    const role = roleData ?? (await getRoleByName(roleName));
    if (!role) {
      return;
    }

    const currentPermissions = role.permissions || {};
    const updatedPermissions = { ...currentPermissions };
    let hasChanges = false;

    // Migration 1: detect permission-type objects stored at the document root
    // (pre-`permissions`-object schema) and fold them into updatedPermissions.
    const unsetFields = {};
    const permissionTypes = Object.keys(permissionsSchema.shape || {});
    for (const permType of permissionTypes) {
      if (role[permType] && typeof role[permType] === 'object') {
        logger.info(
          `Migrating '${roleName}' role from old schema: found '${permType}' at top level`,
        );

        updatedPermissions[permType] = {
          ...updatedPermissions[permType],
          ...role[permType],
        };

        unsetFields[permType] = 1;
        hasChanges = true;
      }
    }

    // Migrate legacy SHARED_GLOBAL → SHARE for PROMPTS and AGENTS.
    // SHARED_GLOBAL was removed in favour of SHARE in PR #11283. If the DB still has
    // SHARED_GLOBAL but not SHARE, inherit the value so sharing intent is preserved.
    const legacySharedGlobalTypes = ['PROMPTS', 'AGENTS'];
    for (const legacyPermType of legacySharedGlobalTypes) {
      const existingTypePerms = currentPermissions[legacyPermType];
      if (
        existingTypePerms &&
        'SHARED_GLOBAL' in existingTypePerms &&
        !('SHARE' in existingTypePerms) &&
        updates[legacyPermType] &&
        // Don't override an explicit SHARE value the caller already provided
        !('SHARE' in updates[legacyPermType])
      ) {
        const inheritedValue = existingTypePerms['SHARED_GLOBAL'];
        updates[legacyPermType]['SHARE'] = inheritedValue;
        logger.info(
          `Migrating '${roleName}' role ${legacyPermType}.SHARED_GLOBAL=${inheritedValue} → SHARE`,
        );
      }
    }

    // Apply the caller's requested permission changes, logging each change.
    for (const [permissionType, permissions] of Object.entries(updates)) {
      const currentTypePermissions = currentPermissions[permissionType] || {};
      updatedPermissions[permissionType] = { ...currentTypePermissions };

      for (const [permission, value] of Object.entries(permissions)) {
        if (currentTypePermissions[permission] !== value) {
          updatedPermissions[permissionType][permission] = value;
          hasChanges = true;
          logger.info(
            `Updating '${roleName}' role permission '${permissionType}' '${permission}' from ${currentTypePermissions[permission]} to: ${value}`,
          );
        }
      }
    }

    // Clean up orphaned SHARED_GLOBAL fields left in DB after the schema rename.
    // Since we $set the full permissions object, deleting from updatedPermissions
    // is sufficient to remove the field from MongoDB.
    for (const legacyPermType of legacySharedGlobalTypes) {
      const existingTypePerms = currentPermissions[legacyPermType];
      if (existingTypePerms && 'SHARED_GLOBAL' in existingTypePerms) {
        if (!updates[legacyPermType]) {
          // permType wasn't in the update payload so the migration block above didn't run.
          // Create a writable copy and handle the SHARED_GLOBAL → SHARE inheritance here
          // to avoid removing SHARED_GLOBAL without writing SHARE (data loss).
          updatedPermissions[legacyPermType] = { ...existingTypePerms };
          if (!('SHARE' in existingTypePerms)) {
            updatedPermissions[legacyPermType]['SHARE'] = existingTypePerms['SHARED_GLOBAL'];
            logger.info(
              `Migrating '${roleName}' role ${legacyPermType}.SHARED_GLOBAL=${existingTypePerms['SHARED_GLOBAL']} → SHARE`,
            );
          }
        }
        delete updatedPermissions[legacyPermType]['SHARED_GLOBAL'];
        hasChanges = true;
        logger.info(
          `Removed legacy SHARED_GLOBAL field from '${roleName}' role ${legacyPermType} permissions`,
        );
      }
    }

    if (hasChanges) {
      const updateObj = { permissions: updatedPermissions };

      if (Object.keys(unsetFields).length > 0) {
        // Old-schema migration needs a combined $set/$unset in one update, so
        // we bypass updateRoleByName and refresh the cache manually.
        logger.info(
          `Unsetting old schema fields for '${roleName}' role: ${Object.keys(unsetFields).join(', ')}`,
        );

        try {
          await Role.updateOne(
            { name: roleName },
            {
              $set: updateObj,
              $unset: unsetFields,
            },
          );

          const cache = getLogStores(CacheKeys.ROLES);
          const updatedRole = await Role.findOne({ name: roleName }).select('-__v').lean().exec();
          await cache.set(roleName, updatedRole);

          logger.info(`Updated role '${roleName}' and removed old schema fields`);
        } catch (updateError) {
          logger.error(`Error during role migration update: ${updateError.message}`);
          throw updateError;
        }
      } else {
        // Standard update if no migration needed
        await updateRoleByName(roleName, updateObj);
      }

      logger.info(`Updated '${roleName}' role permissions`);
    } else {
      logger.info(`No changes needed for '${roleName}' role permissions`);
    }
  } catch (error) {
    // Deliberately swallow after logging: permission sync failures should not
    // crash the caller (typically startup config application).
    logger.error(`Failed to update ${roleName} role permissions:`, error);
  }
}
|
||||
|
||||
/**
 * Migrates roles from old schema to new schema structure.
 * This can be called directly to fix existing roles.
 *
 * Old schema: permission-type objects stored at the top level of the role
 * document. New schema: all of them nested under `role.permissions`.
 * Per-role failures are logged and skipped; only setup/query errors rethrow.
 *
 * @param {string} [roleName] - Optional specific role to migrate. If not provided, migrates all roles.
 * @returns {Promise<number>} Number of roles migrated.
 */
const migrateRoleSchema = async function (roleName) {
  try {
    // Get roles to migrate
    let roles;
    if (roleName) {
      const role = await Role.findOne({ name: roleName });
      roles = role ? [role] : [];
    } else {
      roles = await Role.find({});
    }

    logger.info(`Migrating ${roles.length} roles to new schema structure`);
    let migratedCount = 0;

    for (const role of roles) {
      const permissionTypes = Object.keys(permissionsSchema.shape || {});
      const unsetFields = {};
      let hasOldSchema = false;

      // Check for old schema fields
      for (const permType of permissionTypes) {
        if (role[permType] && typeof role[permType] === 'object') {
          hasOldSchema = true;

          // Ensure permissions object exists
          role.permissions = role.permissions || {};

          // Migrate permissions from old location to new
          // (old top-level values win over any existing nested values)
          role.permissions[permType] = {
            ...role.permissions[permType],
            ...role[permType],
          };

          // Mark field for removal
          unsetFields[permType] = 1;
        }
      }

      if (hasOldSchema) {
        try {
          logger.info(`Migrating role '${role.name}' from old schema structure`);

          // Simple update operation: write the merged permissions object and
          // drop the legacy top-level fields in one atomic update.
          await Role.updateOne(
            { _id: role._id },
            {
              $set: { permissions: role.permissions },
              $unset: unsetFields,
            },
          );

          // Refresh cache so subsequent getRoleByName calls see the new shape
          const cache = getLogStores(CacheKeys.ROLES);
          const updatedRole = await Role.findById(role._id).lean().exec();
          await cache.set(role.name, updatedRole);

          migratedCount++;
          logger.info(`Migrated role '${role.name}'`);
        } catch (error) {
          // Continue with remaining roles on a per-role failure
          logger.error(`Failed to migrate role '${role.name}': ${error.message}`);
        }
      }
    }

    logger.info(`Migration complete: ${migratedCount} roles migrated`);
    return migratedCount;
  } catch (error) {
    logger.error(`Role schema migration failed: ${error.message}`);
    throw error;
  }
};
|
||||
|
||||
// Public API: role retrieval/creation, updates, and legacy-schema migration.
module.exports = {
  getRoleByName,
  updateRoleByName,
  migrateRoleSchema,
  updateAccessPermissions,
};
|
||||
|
|
@ -1,96 +0,0 @@
|
|||
const { ToolCall } = require('~/db/models');
|
||||
|
||||
/**
 * Create a new tool call
 * @param {IToolCallData} toolCallData - The tool call data
 * @returns {Promise<IToolCallData>} The created tool call document
 * @throws {Error} Wraps any persistence error, preserving the original via `cause`.
 */
async function createToolCall(toolCallData) {
  try {
    return await ToolCall.create(toolCallData);
  } catch (error) {
    // Preserve the original error (and its stack) for debugging via `cause`.
    throw new Error(`Error creating tool call: ${error.message}`, { cause: error });
  }
}
|
||||
|
||||
/**
 * Get a tool call by ID
 * @param {string} id - The tool call document ID
 * @returns {Promise<IToolCallData|null>} The tool call document or null if not found
 */
async function getToolCallById(id) {
  try {
    const toolCall = await ToolCall.findById(id).lean();
    return toolCall;
  } catch (err) {
    throw new Error(`Error fetching tool call: ${err.message}`);
  }
}
|
||||
|
||||
/**
 * Get tool calls by message ID and user
 * @param {string} messageId - The message ID
 * @param {string} userId - The user's ObjectId
 * @returns {Promise<Array>} Array of tool call documents
 */
async function getToolCallsByMessage(messageId, userId) {
  try {
    const filter = { messageId, user: userId };
    return await ToolCall.find(filter).lean();
  } catch (err) {
    throw new Error(`Error fetching tool calls: ${err.message}`);
  }
}
|
||||
|
||||
/**
 * Get tool calls by conversation ID and user
 * @param {string} conversationId - The conversation ID
 * @param {string} userId - The user's ObjectId
 * @returns {Promise<IToolCallData[]>} Array of tool call documents
 */
async function getToolCallsByConvo(conversationId, userId) {
  try {
    const filter = { conversationId, user: userId };
    return await ToolCall.find(filter).lean();
  } catch (err) {
    throw new Error(`Error fetching tool calls: ${err.message}`);
  }
}
|
||||
|
||||
/**
 * Update a tool call
 * @param {string} id - The tool call document ID
 * @param {Partial<IToolCallData>} updateData - The data to update
 * @returns {Promise<IToolCallData|null>} The updated tool call document or null if not found
 */
async function updateToolCall(id, updateData) {
  try {
    const updated = await ToolCall.findByIdAndUpdate(id, updateData, { new: true }).lean();
    return updated;
  } catch (err) {
    throw new Error(`Error updating tool call: ${err.message}`);
  }
}
|
||||
|
||||
/**
 * Delete tool calls for a user, optionally scoped to one conversation.
 * @param {string} userId - The related user's ObjectId
 * @param {string} [conversationId] - The tool call conversation ID; when omitted, deletes all of the user's tool calls
 * @returns {Promise<{ ok?: number; n?: number; deletedCount?: number }>} The result of the delete operation
 * @throws {Error} Wraps any persistence error, preserving the original via `cause`.
 */
async function deleteToolCalls(userId, conversationId) {
  try {
    const query = { user: userId };
    if (conversationId) {
      query.conversationId = conversationId;
    }
    return await ToolCall.deleteMany(query);
  } catch (error) {
    // Preserve the original error (and its stack) for debugging via `cause`.
    throw new Error(`Error deleting tool call: ${error.message}`, { cause: error });
  }
}
|
||||
|
||||
// Public API: CRUD helpers over the ToolCall collection.
module.exports = {
  createToolCall,
  updateToolCall,
  deleteToolCalls,
  getToolCallById,
  getToolCallsByConvo,
  getToolCallsByMessage,
};
|
||||
|
|
@ -1,223 +0,0 @@
|
|||
const { logger, CANCEL_RATE } = require('@librechat/data-schemas');
|
||||
const { getMultiplier, getCacheMultiplier } = require('./tx');
|
||||
const { Transaction } = require('~/db/models');
|
||||
const { updateBalance } = require('~/models');
|
||||
|
||||
/**
 * Method to calculate and set the tokenValue for a transaction.
 * Mutates `txn` in place: sets `txn.rate` (absolute multiplier) and
 * `txn.tokenValue` (rawAmount * rate). For completions cancelled mid-stream
 * (context === 'incomplete'), both are discounted by CANCEL_RATE.
 */
function calculateTokenValue(txn) {
  const { valueKey, tokenType, model, endpointTokenConfig, inputTokenCount } = txn;
  const multiplier = Math.abs(
    getMultiplier({ valueKey, tokenType, model, endpointTokenConfig, inputTokenCount }),
  );
  txn.rate = multiplier;
  txn.tokenValue = txn.rawAmount * multiplier;
  if (txn.context && txn.tokenType === 'completion' && txn.context === 'incomplete') {
    // Math.ceil on a negative tokenValue rounds toward zero, shrinking the charge.
    txn.tokenValue = Math.ceil(txn.tokenValue * CANCEL_RATE);
    txn.rate *= CANCEL_RATE;
  }
}
|
||||
|
||||
/**
 * Creates an auto-refill transaction AND applies the refill to the user's
 * balance (increments tokenCredits and stamps `lastRefill`).
 * NOTE(review): a previous version of this doc claimed no balance update is
 * triggered — the code below clearly calls `updateBalance`.
 *
 * @param {object} txData - Transaction data.
 * @param {string} txData.user - The user ID.
 * @param {string} txData.tokenType - The type of token.
 * @param {string} txData.context - The context of the transaction.
 * @param {number} txData.rawAmount - The raw amount of tokens.
 * @returns {Promise<object|undefined>} - The created transaction summary
 *   ({ rate, user, balance, transaction }), or undefined when rawAmount is NaN.
 */
async function createAutoRefillTransaction(txData) {
  if (txData.rawAmount != null && isNaN(txData.rawAmount)) {
    return;
  }
  const transaction = new Transaction(txData);
  transaction.endpointTokenConfig = txData.endpointTokenConfig;
  transaction.inputTokenCount = txData.inputTokenCount;
  calculateTokenValue(transaction);
  await transaction.save();

  // Apply the refill: increment balance and record when it happened.
  const balanceResponse = await updateBalance({
    user: transaction.user,
    incrementValue: txData.rawAmount,
    setValues: { lastRefill: new Date() },
  });
  const result = {
    rate: transaction.rate,
    user: transaction.user.toString(),
    balance: balanceResponse.tokenCredits,
  };
  logger.debug('[Balance.check] Auto-refill performed', result);
  result.transaction = transaction;
  return result;
}
|
||||
|
||||
/**
 * Create a transaction record and, when balance tracking is enabled,
 * apply its token value to the user's balance.
 * @param {txData} _txData - Transaction data (may carry `balance` / `transactions` config).
 * @returns {Promise<object|undefined>} Balance summary, or undefined when
 *   skipped (NaN amount, transactions disabled, or balance disabled).
 */
async function createTransaction(_txData) {
  const { balance, transactions, ...txData } = _txData;

  const invalidAmount = txData.rawAmount != null && isNaN(txData.rawAmount);
  if (invalidAmount) {
    return;
  }
  if (transactions?.enabled === false) {
    return;
  }

  const txn = new Transaction(txData);
  txn.endpointTokenConfig = txData.endpointTokenConfig;
  txn.inputTokenCount = txData.inputTokenCount;
  calculateTokenValue(txn);
  await txn.save();

  if (!balance?.enabled) {
    return;
  }

  const incrementValue = txn.tokenValue;
  const balanceResponse = await updateBalance({
    user: txn.user,
    incrementValue,
  });

  return {
    rate: txn.rate,
    user: txn.user.toString(),
    balance: balanceResponse.tokenCredits,
    [txn.tokenType]: incrementValue,
  };
}
|
||||
|
||||
/**
 * Create a structured (cache-aware) transaction and, when balance tracking is
 * enabled, apply its token value to the user's balance.
 * @param {txData} _txData - Transaction data (may carry `balance` / `transactions` config).
 * @returns {Promise<object|undefined>} Balance summary, or undefined when skipped.
 */
async function createStructuredTransaction(_txData) {
  const { balance, transactions, ...txData } = _txData;

  if (transactions?.enabled === false) {
    return;
  }

  const txn = new Transaction(txData);
  txn.endpointTokenConfig = txData.endpointTokenConfig;
  txn.inputTokenCount = txData.inputTokenCount;
  calculateStructuredTokenValue(txn);
  await txn.save();

  if (!balance?.enabled) {
    return;
  }

  const incrementValue = txn.tokenValue;
  const balanceResponse = await updateBalance({
    user: txn.user,
    incrementValue,
  });

  return {
    rate: txn.rate,
    user: txn.user.toString(),
    balance: balanceResponse.tokenCredits,
    [txn.tokenType]: incrementValue,
  };
}
|
||||
|
||||
/**
 * Method to calculate token value for structured tokens.
 * Mutates `txn` in place. For 'prompt' transactions it blends input/cache-write/
 * cache-read multipliers into a weighted `rate`, sets a per-kind `rateDetail`,
 * and negates `tokenValue`/`rawAmount` (spends are negative). For 'completion'
 * it applies the single completion multiplier. Cancelled completions are
 * discounted by CANCEL_RATE.
 */
function calculateStructuredTokenValue(txn) {
  if (!txn.tokenType) {
    // No token type: pass the raw amount through unpriced.
    txn.tokenValue = txn.rawAmount;
    return;
  }

  const { model, endpointTokenConfig, inputTokenCount } = txn;

  if (txn.tokenType === 'prompt') {
    const inputMultiplier = getMultiplier({
      tokenType: 'prompt',
      model,
      endpointTokenConfig,
      inputTokenCount,
    });
    // Cache multipliers fall back to the plain input rate when unconfigured.
    const writeMultiplier =
      getCacheMultiplier({ cacheType: 'write', model, endpointTokenConfig }) ?? inputMultiplier;
    const readMultiplier =
      getCacheMultiplier({ cacheType: 'read', model, endpointTokenConfig }) ?? inputMultiplier;

    txn.rateDetail = {
      input: inputMultiplier,
      write: writeMultiplier,
      read: readMultiplier,
    };

    const totalPromptTokens =
      Math.abs(txn.inputTokens || 0) +
      Math.abs(txn.writeTokens || 0) +
      Math.abs(txn.readTokens || 0);

    if (totalPromptTokens > 0) {
      // Weighted average rate across the three token kinds.
      txn.rate =
        (Math.abs(inputMultiplier * (txn.inputTokens || 0)) +
          Math.abs(writeMultiplier * (txn.writeTokens || 0)) +
          Math.abs(readMultiplier * (txn.readTokens || 0))) /
        totalPromptTokens;
    } else {
      txn.rate = Math.abs(inputMultiplier); // Default to input rate if no tokens
    }

    // Spends are stored negative.
    txn.tokenValue = -(
      Math.abs(txn.inputTokens || 0) * inputMultiplier +
      Math.abs(txn.writeTokens || 0) * writeMultiplier +
      Math.abs(txn.readTokens || 0) * readMultiplier
    );

    txn.rawAmount = -totalPromptTokens;
  } else if (txn.tokenType === 'completion') {
    const multiplier = getMultiplier({
      tokenType: txn.tokenType,
      model,
      endpointTokenConfig,
      inputTokenCount,
    });
    txn.rate = Math.abs(multiplier);
    txn.tokenValue = -Math.abs(txn.rawAmount) * multiplier;
    txn.rawAmount = -Math.abs(txn.rawAmount);
  }

  if (txn.context && txn.tokenType === 'completion' && txn.context === 'incomplete') {
    // Cancelled completion: discount value, rate, and every rateDetail entry.
    txn.tokenValue = Math.ceil(txn.tokenValue * CANCEL_RATE);
    txn.rate *= CANCEL_RATE;
    if (txn.rateDetail) {
      txn.rateDetail = Object.fromEntries(
        Object.entries(txn.rateDetail).map(([k, v]) => [k, v * CANCEL_RATE]),
      );
    }
  }
}
|
||||
|
||||
/**
 * Queries and retrieves transactions based on a given filter.
 * @async
 * @function getTransactions
 * @param {Object} filter - MongoDB filter object to apply when querying transactions.
 * @returns {Promise<Array>} A promise that resolves to an array of matched transactions.
 * @throws {Error} Throws an error if querying the database fails.
 */
async function getTransactions(filter) {
  try {
    const results = await Transaction.find(filter).lean();
    return results;
  } catch (err) {
    logger.error('Error querying transactions:', err);
    throw err;
  }
}
|
||||
|
||||
// Public API: transaction creation (plain, structured, auto-refill) and querying.
module.exports = {
  getTransactions,
  createTransaction,
  createAutoRefillTransaction,
  createStructuredTransaction,
};
|
||||
|
|
@ -1,156 +0,0 @@
|
|||
const { logger } = require('@librechat/data-schemas');
|
||||
const { ViolationTypes } = require('librechat-data-provider');
|
||||
const { createAutoRefillTransaction } = require('./Transaction');
|
||||
const { logViolation } = require('~/cache');
|
||||
const { getMultiplier } = require('./tx');
|
||||
const { Balance } = require('~/db/models');
|
||||
|
||||
/**
 * Returns true when `date` is an invalid Date (its timestamp is NaN).
 * Uses Number.isNaN on the epoch value instead of the coercing global isNaN.
 * @param {Date} date
 * @returns {boolean}
 */
function isInvalidDate(date) {
  return Number.isNaN(date.getTime());
}
|
||||
|
||||
/**
 * Simple check method that calculates token cost and returns balance info.
 * The auto-refill logic has been moved to balanceMethods.js to prevent circular dependencies.
 *
 * Flow: price the request (amount * multiplier), load the user's balance
 * record, and — only when this spend would zero/overdraw the balance and
 * auto-refill is configured and due — perform a refill before deciding.
 *
 * @returns {Promise<{canSpend: boolean, balance: number, tokenCost: number}>}
 */
const checkBalanceRecord = async function ({
  user,
  model,
  endpoint,
  valueKey,
  tokenType,
  amount,
  endpointTokenConfig,
}) {
  const multiplier = getMultiplier({ valueKey, tokenType, model, endpoint, endpointTokenConfig });
  const tokenCost = amount * multiplier;

  // Retrieve the balance record; no record means the user cannot spend.
  let record = await Balance.findOne({ user }).lean();
  if (!record) {
    logger.debug('[Balance.check] No balance record found for user', { user });
    return {
      canSpend: false,
      balance: 0,
      tokenCost,
    };
  }
  let balance = record.tokenCredits;

  logger.debug('[Balance.check] Initial state', {
    user,
    model,
    endpoint,
    valueKey,
    tokenType,
    amount,
    balance,
    multiplier,
    endpointTokenConfig: !!endpointTokenConfig,
  });

  // Only perform auto-refill if spending would bring the balance to 0 or below
  if (balance - tokenCost <= 0 && record.autoRefillEnabled && record.refillAmount > 0) {
    const lastRefillDate = new Date(record.lastRefill);
    const now = new Date();
    // Refill when lastRefill is unset/invalid, or the refill interval elapsed.
    if (
      isInvalidDate(lastRefillDate) ||
      now >=
        addIntervalToDate(lastRefillDate, record.refillIntervalValue, record.refillIntervalUnit)
    ) {
      try {
        /** @type {{ rate: number, user: string, balance: number, transaction: import('@librechat/data-schemas').ITransaction}} */
        const result = await createAutoRefillTransaction({
          user: user,
          tokenType: 'credits',
          context: 'autoRefill',
          rawAmount: record.refillAmount,
        });
        balance = result.balance;
      } catch (error) {
        // Best-effort: a failed refill falls through to the pre-refill balance.
        logger.error('[Balance.check] Failed to record transaction for auto-refill', error);
      }
    }
  }

  logger.debug('[Balance.check] Token cost', { tokenCost });
  return { canSpend: balance >= tokenCost, balance, tokenCost };
};
|
||||
|
||||
/**
 * Adds a time interval to a given date without mutating the input.
 * Unknown units return an unmodified copy of the starting date.
 * @param {Date} date - The starting date.
 * @param {number} value - The numeric value of the interval.
 * @param {'seconds'|'minutes'|'hours'|'days'|'weeks'|'months'} unit - The unit of time.
 * @returns {Date} A new Date representing the starting date plus the interval.
 */
const addIntervalToDate = (date, value, unit) => {
  const next = new Date(date);
  const appliers = {
    seconds: () => next.setSeconds(next.getSeconds() + value),
    minutes: () => next.setMinutes(next.getMinutes() + value),
    hours: () => next.setHours(next.getHours() + value),
    days: () => next.setDate(next.getDate() + value),
    weeks: () => next.setDate(next.getDate() + value * 7),
    months: () => next.setMonth(next.getMonth() + value),
  };
  if (Object.hasOwn(appliers, unit)) {
    appliers[unit]();
  }
  return next;
};
|
||||
|
||||
/**
 * Checks the balance for a user and determines if they can spend a certain amount.
 * If the user cannot spend the amount, it logs a violation and denies the request.
 *
 * @async
 * @function
 * @param {Object} params - The function parameters.
 * @param {ServerRequest} params.req - The Express request object.
 * @param {Express.Response} params.res - The Express response object.
 * @param {Object} params.txData - The transaction data.
 * @param {string} params.txData.user - The user ID or identifier.
 * @param {('prompt' | 'completion')} params.txData.tokenType - The type of token.
 * @param {number} params.txData.amount - The amount of tokens.
 * @param {string} params.txData.model - The model name or identifier.
 * @param {string} [params.txData.endpointTokenConfig] - The token configuration for the endpoint.
 * @returns {Promise<boolean>} Throws error if the user cannot spend the amount.
 * @throws {Error} Throws an error if there's an issue with the balance check.
 */
const checkBalance = async ({ req, res, txData }) => {
  const { canSpend, balance, tokenCost } = await checkBalanceRecord(txData);
  if (canSpend) {
    return true;
  }

  const type = ViolationTypes.TOKEN_BALANCE;
  const errorMessage = {
    type,
    balance,
    tokenCost,
    promptTokens: txData.amount,
  };
  const generations = txData.generations;
  if (generations && generations.length > 0) {
    errorMessage.generations = generations;
  }

  await logViolation(req, res, type, errorMessage, 0);
  throw new Error(JSON.stringify(errorMessage));
};
|
||||
|
||||
// Public API: balance verification gate used before charging a request.
module.exports = {
  checkBalance,
};
|
||||
|
|
@ -1,48 +1,22 @@
|
|||
const mongoose = require('mongoose');
|
||||
const { createMethods } = require('@librechat/data-schemas');
|
||||
const methods = createMethods(mongoose);
|
||||
const { comparePassword } = require('./userMethods');
|
||||
const {
|
||||
getMessage,
|
||||
getMessages,
|
||||
saveMessage,
|
||||
recordMessage,
|
||||
updateMessage,
|
||||
deleteMessagesSince,
|
||||
deleteMessages,
|
||||
} = require('./Message');
|
||||
const { getConvoTitle, getConvo, saveConvo, deleteConvos } = require('./Conversation');
|
||||
const { getPreset, getPresets, savePreset, deletePresets } = require('./Preset');
|
||||
const { File } = require('~/db/models');
|
||||
const { matchModelName, findMatchingPattern } = require('@librechat/api');
|
||||
const getLogStores = require('~/cache/getLogStores');
|
||||
|
||||
const methods = createMethods(mongoose, {
|
||||
matchModelName,
|
||||
findMatchingPattern,
|
||||
getCache: getLogStores,
|
||||
});
|
||||
|
||||
const seedDatabase = async () => {
|
||||
await methods.initializeRoles();
|
||||
await methods.seedDefaultRoles();
|
||||
await methods.ensureDefaultCategories();
|
||||
await methods.seedSystemGrants();
|
||||
};
|
||||
|
||||
module.exports = {
|
||||
...methods,
|
||||
seedDatabase,
|
||||
comparePassword,
|
||||
|
||||
getMessage,
|
||||
getMessages,
|
||||
saveMessage,
|
||||
recordMessage,
|
||||
updateMessage,
|
||||
deleteMessagesSince,
|
||||
deleteMessages,
|
||||
|
||||
getConvoTitle,
|
||||
getConvo,
|
||||
saveConvo,
|
||||
deleteConvos,
|
||||
|
||||
getPreset,
|
||||
getPresets,
|
||||
savePreset,
|
||||
deletePresets,
|
||||
|
||||
Files: File,
|
||||
};
|
||||
|
|
|
|||
|
|
@ -1,24 +0,0 @@
|
|||
const { logger } = require('@librechat/data-schemas');
|
||||
const { updateInterfacePermissions: updateInterfacePerms } = require('@librechat/api');
|
||||
const { getRoleByName, updateAccessPermissions } = require('./Role');
|
||||
|
||||
/**
|
||||
* Update interface permissions based on app configuration.
|
||||
* Must be done independently from loading the app config.
|
||||
* @param {AppConfig} appConfig
|
||||
*/
|
||||
async function updateInterfacePermissions(appConfig) {
|
||||
try {
|
||||
await updateInterfacePerms({
|
||||
appConfig,
|
||||
getRoleByName,
|
||||
updateAccessPermissions,
|
||||
});
|
||||
} catch (error) {
|
||||
logger.error('Error updating interface permissions:', error);
|
||||
}
|
||||
}
|
||||
|
||||
module.exports = {
|
||||
updateInterfacePermissions,
|
||||
};
|
||||
|
|
@ -1,68 +0,0 @@
|
|||
const mongoose = require('mongoose');
const { logger, hashToken, getRandomValues } = require('@librechat/data-schemas');
const { createToken, findToken } = require('~/models');

/**
 * @module inviteUser
 * @description This module provides functions to create and get user invites
 */

/**
 * @function createInvite
 * @description Creates a new user invite: generates a random token, persists only its
 * hash, and returns the URL-encoded plaintext token for use in the invite link.
 * @param {string} email - The email of the user to invite
 * @returns {Promise<string|{message: string}>} A promise that resolves to the URL-encoded
 * invite token, or an error object `{ message }` if creation fails (the error is logged,
 * not rethrown)
 */
const createInvite = async (email) => {
  try {
    const token = await getRandomValues(32);
    // Only the hash is stored; the plaintext token travels in the invite link.
    const hash = await hashToken(token);
    const encodedToken = encodeURIComponent(token);

    // Token records require a userId, but the invitee has no account yet — use a placeholder.
    const fakeUserId = new mongoose.Types.ObjectId();

    await createToken({
      userId: fakeUserId,
      email,
      token: hash,
      createdAt: Date.now(),
      // 7 days, in seconds
      expiresIn: 604800,
    });

    return encodedToken;
  } catch (error) {
    logger.error('[createInvite] Error creating invite', error);
    return { message: 'Error creating invite' };
  }
};

/**
 * @function getInvite
 * @description Retrieves a user invite by re-hashing the supplied token and matching
 * it (together with the email) against stored token records
 * @param {string} encodedToken - The URL-encoded token of the invite to retrieve
 * @param {string} email - The email of the user to validate
 * @returns {Promise<Object|{error: true, message: string}>} A promise that resolves to
 * the invite document, or an error object when the invite does not exist, the email
 * does not match, or the lookup fails (the error is logged, not rethrown)
 */
const getInvite = async (encodedToken, email) => {
  try {
    const token = decodeURIComponent(encodedToken);
    const hash = await hashToken(token);
    const invite = await findToken({ token: hash, email });

    if (!invite) {
      throw new Error('Invite not found or email does not match');
    }

    return invite;
  } catch (error) {
    logger.error('[getInvite] Error getting invite:', error);
    return { error: true, message: error.message };
  }
};

module.exports = {
  createInvite,
  getInvite,
};
||||
|
|
@ -1,218 +0,0 @@
|
|||
const { logger } = require('@librechat/data-schemas');
const { getCustomEndpointConfig } = require('@librechat/api');
const {
  Tools,
  Constants,
  isAgentsEndpoint,
  isEphemeralAgentId,
  appendAgentIdSuffix,
  encodeEphemeralAgentId,
} = require('librechat-data-provider');
const { getMCPServerTools } = require('~/server/services/Config');

const { mcp_all, mcp_delimiter } = Constants;

/**
 * Constant for added conversation agent ID
 */
const ADDED_AGENT_ID = 'added_agent';

/**
 * Get an agent document based on the provided ID.
 * Injected via {@link setGetAgent}; remains undefined until that is called.
 * @param {Object} searchParameter - The search parameters to find the agent.
 * @param {string} searchParameter.id - The ID of the agent.
 * @returns {Promise<import('librechat-data-provider').Agent|null>}
 */
let getAgent;

/**
 * Set the getAgent function (dependency injection to avoid circular imports)
 * @param {Function} fn
 */
const setGetAgent = (fn) => {
  getAgent = fn;
};

/**
 * Load an agent from an added conversation (TConversation).
 * Used for multi-convo parallel agent execution.
 *
 * Resolution order: (1) a persisted agent referenced by `conversation.agent_id`;
 * (2) when the primary agent is ephemeral, an ephemeral agent that duplicates the
 * primary's tools; (3) otherwise an ephemeral agent built from the conversation's
 * own settings (model params, MCP servers, capability toggles).
 *
 * @param {Object} params
 * @param {import('express').Request} params.req
 * @param {import('librechat-data-provider').TConversation} params.conversation - The added conversation
 * @param {import('librechat-data-provider').Agent} [params.primaryAgent] - The primary agent (used to duplicate tools when both are ephemeral)
 * @returns {Promise<import('librechat-data-provider').Agent|null>} The agent config as a plain object, or null if invalid.
 */
const loadAddedAgent = async ({ req, conversation, primaryAgent }) => {
  if (!conversation) {
    return null;
  }

  // If there's an agent_id, load the existing agent
  if (conversation.agent_id && !isEphemeralAgentId(conversation.agent_id)) {
    if (!getAgent) {
      throw new Error('getAgent not initialized - call setGetAgent first');
    }
    const agent = await getAgent({
      id: conversation.agent_id,
    });

    if (!agent) {
      logger.warn(`[loadAddedAgent] Agent ${conversation.agent_id} not found`);
      return null;
    }

    agent.version = agent.versions ? agent.versions.length : 0;
    // Append suffix to distinguish from primary agent (matches ephemeral format)
    // This is needed when both agents have the same ID or for consistent parallel content attribution
    agent.id = appendAgentIdSuffix(agent.id, 1);
    return agent;
  }

  // Otherwise, create an ephemeral agent config from the conversation
  const { model, endpoint, promptPrefix, spec, ...rest } = conversation;

  if (!endpoint || !model) {
    logger.warn('[loadAddedAgent] Missing required endpoint or model for ephemeral agent');
    return null;
  }

  // If both primary and added agents are ephemeral, duplicate tools from primary agent
  const primaryIsEphemeral = primaryAgent && isEphemeralAgentId(primaryAgent.id);
  if (primaryIsEphemeral && Array.isArray(primaryAgent.tools)) {
    // Get endpoint config and model spec for display name fallbacks
    const appConfig = req.config;
    let endpointConfig = appConfig?.endpoints?.[endpoint];
    if (!isAgentsEndpoint(endpoint) && !endpointConfig) {
      try {
        endpointConfig = getCustomEndpointConfig({ endpoint, appConfig });
      } catch (err) {
        // Best-effort: display-name fallback simply degrades to empty string below.
        logger.error('[loadAddedAgent] Error getting custom endpoint config', err);
      }
    }

    // Look up model spec for label fallback
    const modelSpecs = appConfig?.modelSpecs?.list;
    const modelSpec = spec != null && spec !== '' ? modelSpecs?.find((s) => s.name === spec) : null;

    // For ephemeral agents, use modelLabel if provided, then model spec's label,
    // then modelDisplayLabel from endpoint config, otherwise empty string to show model name
    const sender = rest.modelLabel ?? modelSpec?.label ?? endpointConfig?.modelDisplayLabel ?? '';

    const ephemeralId = encodeEphemeralAgentId({ endpoint, model, sender, index: 1 });

    // Early return: tools are cloned from the primary agent, so no capability or
    // MCP resolution is performed for this branch.
    return {
      id: ephemeralId,
      instructions: promptPrefix || '',
      provider: endpoint,
      model_parameters: {},
      model,
      tools: [...primaryAgent.tools],
    };
  }

  // Extract ephemeral agent options from conversation if present
  const ephemeralAgent = rest.ephemeralAgent;
  const mcpServers = new Set(ephemeralAgent?.mcp);
  const userId = req.user?.id;

  // Check model spec for MCP servers
  const modelSpecs = req.config?.modelSpecs?.list;
  let modelSpec = null;
  if (spec != null && spec !== '') {
    modelSpec = modelSpecs?.find((s) => s.name === spec) || null;
  }
  if (modelSpec?.mcpServers) {
    for (const mcpServer of modelSpec.mcpServers) {
      mcpServers.add(mcpServer);
    }
  }

  // Capability toggles: enabled by either the ephemeral-agent options or the model spec.
  /** @type {string[]} */
  const tools = [];
  if (ephemeralAgent?.execute_code === true || modelSpec?.executeCode === true) {
    tools.push(Tools.execute_code);
  }
  if (ephemeralAgent?.file_search === true || modelSpec?.fileSearch === true) {
    tools.push(Tools.file_search);
  }
  if (ephemeralAgent?.web_search === true || modelSpec?.webSearch === true) {
    tools.push(Tools.web_search);
  }

  // Resolve MCP server tools; when a server's tools cannot be listed, fall back to
  // a wildcard "<mcp_all><delimiter><server>" entry instead of individual tool names.
  const addedServers = new Set();
  if (mcpServers.size > 0) {
    for (const mcpServer of mcpServers) {
      if (addedServers.has(mcpServer)) {
        continue;
      }
      const serverTools = await getMCPServerTools(userId, mcpServer);
      if (!serverTools) {
        tools.push(`${mcp_all}${mcp_delimiter}${mcpServer}`);
        addedServers.add(mcpServer);
        continue;
      }
      tools.push(...Object.keys(serverTools));
      addedServers.add(mcpServer);
    }
  }

  // Build model_parameters from conversation fields
  const model_parameters = {};
  const paramKeys = [
    'temperature',
    'top_p',
    'topP',
    'topK',
    'presence_penalty',
    'frequency_penalty',
    'maxOutputTokens',
    'maxTokens',
    'max_tokens',
  ];

  for (const key of paramKeys) {
    if (rest[key] != null) {
      model_parameters[key] = rest[key];
    }
  }

  // Get endpoint config for modelDisplayLabel fallback
  const appConfig = req.config;
  let endpointConfig = appConfig?.endpoints?.[endpoint];
  if (!isAgentsEndpoint(endpoint) && !endpointConfig) {
    try {
      endpointConfig = getCustomEndpointConfig({ endpoint, appConfig });
    } catch (err) {
      // Best-effort: display-name fallback simply degrades to empty string below.
      logger.error('[loadAddedAgent] Error getting custom endpoint config', err);
    }
  }

  // For ephemeral agents, use modelLabel if provided, then model spec's label,
  // then modelDisplayLabel from endpoint config, otherwise empty string to show model name
  const sender = rest.modelLabel ?? modelSpec?.label ?? endpointConfig?.modelDisplayLabel ?? '';

  /** Encoded ephemeral agent ID with endpoint, model, sender, and index=1 to distinguish from primary */
  const ephemeralId = encodeEphemeralAgentId({ endpoint, model, sender, index: 1 });

  const result = {
    id: ephemeralId,
    instructions: promptPrefix || '',
    provider: endpoint,
    model_parameters,
    model,
    tools,
  };

  if (ephemeralAgent?.artifacts != null && ephemeralAgent.artifacts) {
    result.artifacts = ephemeralAgent.artifacts;
  }

  return result;
};

module.exports = {
  ADDED_AGENT_ID,
  loadAddedAgent,
  setGetAgent,
};
|
@ -1,140 +0,0 @@
|
|||
const { logger } = require('@librechat/data-schemas');
const { createTransaction, createStructuredTransaction } = require('./Transaction');
/**
 * Creates up to two transactions to record the spending of tokens.
 * Prompt and completion usage are recorded as separate transactions; a component
 * is skipped entirely when its token count is `undefined`. Spends are recorded as
 * negative `rawAmount` values (debits); negative inputs are clamped to zero.
 * Errors are caught and logged, never rethrown — token accounting is best-effort
 * and must not fail the calling request.
 *
 * @function
 * @async
 * @param {txData} txData - Transaction data.
 * @param {Object} tokenUsage - The number of tokens used.
 * @param {Number} tokenUsage.promptTokens - The number of prompt tokens used.
 * @param {Number} tokenUsage.completionTokens - The number of completion tokens used.
 * @returns {Promise<void>} - Returns nothing.
 * @throws {Error} - Throws an error if there's an issue creating the transactions.
 */
const spendTokens = async (txData, tokenUsage) => {
  const { promptTokens, completionTokens } = tokenUsage;
  logger.debug(
    `[spendTokens] conversationId: ${txData.conversationId}${
      txData?.context ? ` | Context: ${txData?.context}` : ''
    } | Token usage: `,
    {
      promptTokens,
      completionTokens,
    },
  );
  let prompt, completion;
  // Clamp to >= 0: a negative or missing prompt count must not create a credit.
  const normalizedPromptTokens = Math.max(promptTokens ?? 0, 0);
  try {
    if (promptTokens !== undefined) {
      prompt = await createTransaction({
        ...txData,
        tokenType: 'prompt',
        // Explicit 0 keeps a zero-cost record; otherwise negate to record a debit.
        rawAmount: promptTokens === 0 ? 0 : -normalizedPromptTokens,
        inputTokenCount: normalizedPromptTokens,
      });
    }

    if (completionTokens !== undefined) {
      completion = await createTransaction({
        ...txData,
        tokenType: 'completion',
        rawAmount: completionTokens === 0 ? 0 : -Math.max(completionTokens, 0),
        // Completion pricing may depend on the prompt size, so pass it along.
        inputTokenCount: normalizedPromptTokens,
      });
    }

    if (prompt || completion) {
      logger.debug('[spendTokens] Transaction data record against balance:', {
        user: txData.user,
        prompt: prompt?.prompt,
        promptRate: prompt?.rate,
        completion: completion?.completion,
        completionRate: completion?.rate,
        balance: completion?.balance ?? prompt?.balance,
      });
    } else {
      logger.debug('[spendTokens] No transactions incurred against balance');
    }
  } catch (err) {
    // Deliberately swallowed: accounting failures are logged but never block the caller.
    logger.error('[spendTokens]', err);
  }
};
||||
|
||||
/**
 * Creates transactions to record the spending of structured tokens.
 * "Structured" prompt usage is split into input / cache-write / cache-read
 * components; each is clamped to >= 0 and recorded as a negative (debit) amount.
 * Unlike {@link spendTokens}, this returns the created transaction records.
 * Errors are caught and logged, never rethrown — accounting is best-effort.
 *
 * @function
 * @async
 * @param {txData} txData - Transaction data.
 * @param {Object} tokenUsage - The number of tokens used.
 * @param {Object} tokenUsage.promptTokens - The number of prompt tokens used.
 * @param {Number} tokenUsage.promptTokens.input - The number of input tokens.
 * @param {Number} tokenUsage.promptTokens.write - The number of write tokens.
 * @param {Number} tokenUsage.promptTokens.read - The number of read tokens.
 * @param {Number} tokenUsage.completionTokens - The number of completion tokens used.
 * @returns {Promise<{prompt?: Object, completion?: Object}>} - The created transaction
 * records (either may be undefined when that component was not recorded).
 * @throws {Error} - Throws an error if there's an issue creating the transactions.
 */
const spendStructuredTokens = async (txData, tokenUsage) => {
  const { promptTokens, completionTokens } = tokenUsage;
  logger.debug(
    `[spendStructuredTokens] conversationId: ${txData.conversationId}${
      txData?.context ? ` | Context: ${txData?.context}` : ''
    } | Token usage: `,
    {
      promptTokens,
      completionTokens,
    },
  );
  let prompt, completion;
  try {
    if (promptTokens) {
      // Clamp each component so negative/missing counts never create credits.
      const input = Math.max(promptTokens.input ?? 0, 0);
      const write = Math.max(promptTokens.write ?? 0, 0);
      const read = Math.max(promptTokens.read ?? 0, 0);
      const totalInputTokens = input + write + read;
      prompt = await createStructuredTransaction({
        ...txData,
        tokenType: 'prompt',
        inputTokens: -input,
        writeTokens: -write,
        readTokens: -read,
        inputTokenCount: totalInputTokens,
      });
    }

    if (completionTokens) {
      // Recompute the clamped prompt total for completion pricing context;
      // left undefined when no structured prompt usage was supplied.
      const totalInputTokens = promptTokens
        ? Math.max(promptTokens.input ?? 0, 0) +
          Math.max(promptTokens.write ?? 0, 0) +
          Math.max(promptTokens.read ?? 0, 0)
        : undefined;
      completion = await createTransaction({
        ...txData,
        tokenType: 'completion',
        rawAmount: -Math.max(completionTokens, 0),
        inputTokenCount: totalInputTokens,
      });
    }

    if (prompt || completion) {
      logger.debug('[spendStructuredTokens] Transaction data record against balance:', {
        user: txData.user,
        prompt: prompt?.prompt,
        promptRate: prompt?.rate,
        completion: completion?.completion,
        completionRate: completion?.rate,
        balance: completion?.balance ?? prompt?.balance,
      });
    } else {
      logger.debug('[spendStructuredTokens] No transactions incurred against balance');
    }
  } catch (err) {
    // Deliberately swallowed: accounting failures are logged but never block the caller.
    logger.error('[spendStructuredTokens]', err);
  }

  return { prompt, completion };
};

module.exports = { spendTokens, spendStructuredTokens };
||||
|
|
@ -1,31 +0,0 @@
|
|||
const bcrypt = require('bcryptjs');
|
||||
|
||||
/**
|
||||
* Compares the provided password with the user's password.
|
||||
*
|
||||
* @param {IUser} user - The user to compare the password for.
|
||||
* @param {string} candidatePassword - The password to test against the user's password.
|
||||
* @returns {Promise<boolean>} A promise that resolves to a boolean indicating if the password matches.
|
||||
*/
|
||||
/**
 * Compares the provided password with the user's password.
 *
 * @param {IUser} user - The user to compare the password for.
 * @param {string} candidatePassword - The password to test against the user's password.
 * @returns {Promise<boolean>} A promise that resolves to a boolean indicating if the password matches.
 * @throws {Error} When no user is provided, or the user has no stored password
 * (e.g. an account first registered via Social/OIDC login).
 */
const comparePassword = async (user, candidatePassword) => {
  if (!user) {
    throw new Error('No user provided');
  }

  if (!user.password) {
    throw new Error('No password, likely an email first registered via Social/OIDC login');
  }

  return new Promise((resolve, reject) => {
    bcrypt.compare(candidatePassword, user.password, (err, isMatch) => {
      if (err) {
        reject(err);
        // Fix: bail out so resolve() is not also invoked after rejection.
        return;
      }
      resolve(isMatch);
    });
  });
};
|
||||
module.exports = {
|
||||
comparePassword,
|
||||
};
|
||||
|
|
@ -1,6 +1,6 @@
|
|||
{
|
||||
"name": "@librechat/backend",
|
||||
"version": "v0.8.3-rc2",
|
||||
"version": "v0.8.4",
|
||||
"description": "",
|
||||
"scripts": {
|
||||
"start": "echo 'please run this from the root directory'",
|
||||
|
|
@ -44,13 +44,14 @@
|
|||
"@google/genai": "^1.19.0",
|
||||
"@keyv/redis": "^4.3.3",
|
||||
"@langchain/core": "^0.3.80",
|
||||
"@librechat/agents": "^3.1.55",
|
||||
"@librechat/agents": "^3.1.62",
|
||||
"@librechat/api": "*",
|
||||
"@librechat/data-schemas": "*",
|
||||
"@microsoft/microsoft-graph-client": "^3.0.7",
|
||||
"@modelcontextprotocol/sdk": "^1.27.1",
|
||||
"@node-saml/passport-saml": "^5.1.0",
|
||||
"@smithy/node-http-handler": "^4.4.5",
|
||||
"ai-tokenizer": "^1.0.6",
|
||||
"axios": "^1.13.5",
|
||||
"bcryptjs": "^2.4.3",
|
||||
"compression": "^1.8.1",
|
||||
|
|
@ -63,10 +64,10 @@
|
|||
"eventsource": "^3.0.2",
|
||||
"express": "^5.2.1",
|
||||
"express-mongo-sanitize": "^2.2.0",
|
||||
"express-rate-limit": "^8.2.1",
|
||||
"express-rate-limit": "^8.3.0",
|
||||
"express-session": "^1.18.2",
|
||||
"express-static-gzip": "^2.2.0",
|
||||
"file-type": "^18.7.0",
|
||||
"file-type": "^21.3.2",
|
||||
"firebase": "^11.0.2",
|
||||
"form-data": "^4.0.4",
|
||||
"handlebars": "^4.7.7",
|
||||
|
|
@ -106,13 +107,13 @@
|
|||
"pdfjs-dist": "^5.4.624",
|
||||
"rate-limit-redis": "^4.2.0",
|
||||
"sharp": "^0.33.5",
|
||||
"tiktoken": "^1.0.15",
|
||||
"traverse": "^0.6.7",
|
||||
"ua-parser-js": "^1.0.36",
|
||||
"undici": "^7.18.2",
|
||||
"undici": "^7.24.1",
|
||||
"winston": "^3.11.0",
|
||||
"winston-daily-rotate-file": "^5.0.0",
|
||||
"xlsx": "https://cdn.sheetjs.com/xlsx-0.20.3/xlsx-0.20.3.tgz",
|
||||
"yauzl": "^3.2.1",
|
||||
"zod": "^3.22.4"
|
||||
},
|
||||
"devDependencies": {
|
||||
|
|
|
|||
|
|
@ -119,14 +119,8 @@ const refreshController = async (req, res) => {
|
|||
|
||||
const token = setOpenIDAuthTokens(tokenset, req, res, user._id.toString(), refreshToken);
|
||||
|
||||
user.federatedTokens = {
|
||||
access_token: tokenset.access_token,
|
||||
id_token: tokenset.id_token,
|
||||
refresh_token: refreshToken,
|
||||
expires_at: claims.exp,
|
||||
};
|
||||
|
||||
return res.status(200).send({ token, user });
|
||||
const { password: _pw, __v: _v, totpSecret: _ts, backupCodes: _bc, ...safeUser } = user;
|
||||
return res.status(200).send({ token, user: safeUser });
|
||||
} catch (error) {
|
||||
logger.error('[refreshController] OpenID token refresh error', error);
|
||||
return res.status(403).send('Invalid OpenID refresh token');
|
||||
|
|
|
|||
|
|
@ -163,6 +163,16 @@ describe('refreshController – OpenID path', () => {
|
|||
exp: 9999999999,
|
||||
};
|
||||
|
||||
const defaultUser = {
|
||||
_id: 'user-db-id',
|
||||
email: baseClaims.email,
|
||||
openidId: baseClaims.sub,
|
||||
password: '$2b$10$hashedpassword',
|
||||
__v: 0,
|
||||
totpSecret: 'encrypted-totp-secret',
|
||||
backupCodes: ['hashed-code-1', 'hashed-code-2'],
|
||||
};
|
||||
|
||||
let req, res;
|
||||
|
||||
beforeEach(() => {
|
||||
|
|
@ -174,6 +184,7 @@ describe('refreshController – OpenID path', () => {
|
|||
mockTokenset.claims.mockReturnValue(baseClaims);
|
||||
getOpenIdEmail.mockReturnValue(baseClaims.email);
|
||||
setOpenIDAuthTokens.mockReturnValue('new-app-token');
|
||||
findOpenIDUser.mockResolvedValue({ user: { ...defaultUser }, error: null, migration: false });
|
||||
updateUser.mockResolvedValue({});
|
||||
|
||||
req = {
|
||||
|
|
@ -189,13 +200,6 @@ describe('refreshController – OpenID path', () => {
|
|||
});
|
||||
|
||||
it('should call getOpenIdEmail with token claims and use result for findOpenIDUser', async () => {
|
||||
const user = {
|
||||
_id: 'user-db-id',
|
||||
email: baseClaims.email,
|
||||
openidId: baseClaims.sub,
|
||||
};
|
||||
findOpenIDUser.mockResolvedValue({ user, error: null, migration: false });
|
||||
|
||||
await refreshController(req, res);
|
||||
|
||||
expect(getOpenIdEmail).toHaveBeenCalledWith(baseClaims);
|
||||
|
|
@ -229,13 +233,6 @@ describe('refreshController – OpenID path', () => {
|
|||
it('should fall back to claims.email when configured claim is absent from token claims', async () => {
|
||||
getOpenIdEmail.mockReturnValue(baseClaims.email);
|
||||
|
||||
const user = {
|
||||
_id: 'user-db-id',
|
||||
email: baseClaims.email,
|
||||
openidId: baseClaims.sub,
|
||||
};
|
||||
findOpenIDUser.mockResolvedValue({ user, error: null, migration: false });
|
||||
|
||||
await refreshController(req, res);
|
||||
|
||||
expect(findOpenIDUser).toHaveBeenCalledWith(
|
||||
|
|
@ -243,6 +240,25 @@ describe('refreshController – OpenID path', () => {
|
|||
);
|
||||
});
|
||||
|
||||
it('should not expose sensitive fields or federatedTokens in refresh response', async () => {
|
||||
await refreshController(req, res);
|
||||
|
||||
const sentPayload = res.send.mock.calls[0][0];
|
||||
expect(sentPayload).toEqual({
|
||||
token: 'new-app-token',
|
||||
user: expect.objectContaining({
|
||||
_id: 'user-db-id',
|
||||
email: baseClaims.email,
|
||||
openidId: baseClaims.sub,
|
||||
}),
|
||||
});
|
||||
expect(sentPayload.user).not.toHaveProperty('federatedTokens');
|
||||
expect(sentPayload.user).not.toHaveProperty('password');
|
||||
expect(sentPayload.user).not.toHaveProperty('totpSecret');
|
||||
expect(sentPayload.user).not.toHaveProperty('backupCodes');
|
||||
expect(sentPayload.user).not.toHaveProperty('__v');
|
||||
});
|
||||
|
||||
it('should update openidId when migration is triggered on refresh', async () => {
|
||||
const user = { _id: 'user-db-id', email: baseClaims.email, openidId: null };
|
||||
findOpenIDUser.mockResolvedValue({ user, error: null, migration: true });
|
||||
|
|
|
|||
|
|
@ -1,24 +1,22 @@
|
|||
const { findBalanceByUser } = require('~/models');

/**
 * GET handler returning the authenticated user's token-credit balance.
 * Responds 404 when no balance record exists. Strips the internal `_id` and,
 * when auto-refill is disabled, removes all auto-refill-related fields from
 * the response.
 *
 * Note: the previous direct `Balance.findOne` query and its `const balanceData`
 * declaration were a merge leftover duplicating the `findBalanceByUser` path
 * (a redeclaration SyntaxError); only the model-layer call is kept.
 *
 * @param {import('express').Request} req
 * @param {import('express').Response} res
 */
async function balanceController(req, res) {
  const balanceData = await findBalanceByUser(req.user.id);

  if (!balanceData) {
    return res.status(404).json({ error: 'Balance not found' });
  }

  const { _id: _, ...result } = balanceData;

  // If auto-refill is not enabled, remove auto-refill related fields from the response
  if (!result.autoRefillEnabled) {
    delete result.refillIntervalValue;
    delete result.refillIntervalUnit;
    delete result.lastRefill;
    delete result.refillAmount;
  }

  res.status(200).json(result);
}

module.exports = balanceController;
|
|
|
|||
|
|
@ -9,22 +9,17 @@ const { enrichRemoteAgentPrincipals, backfillRemoteAgentPermissions } = require(
|
|||
const {
|
||||
bulkUpdateResourcePermissions,
|
||||
ensureGroupPrincipalExists,
|
||||
getResourcePermissionsMap,
|
||||
findAccessibleResources,
|
||||
getEffectivePermissions,
|
||||
ensurePrincipalExists,
|
||||
getAvailableRoles,
|
||||
findAccessibleResources,
|
||||
getResourcePermissionsMap,
|
||||
} = require('~/server/services/PermissionService');
|
||||
const {
|
||||
searchPrincipals: searchLocalPrincipals,
|
||||
sortPrincipalsByRelevance,
|
||||
calculateRelevanceScore,
|
||||
} = require('~/models');
|
||||
const {
|
||||
entraIdPrincipalFeatureEnabled,
|
||||
searchEntraIdPrincipals,
|
||||
} = require('~/server/services/GraphApiService');
|
||||
const { AclEntry, AccessRole } = require('~/db/models');
|
||||
const db = require('~/models');
|
||||
|
||||
/**
|
||||
* Generic controller for resource permission endpoints
|
||||
|
|
@ -155,6 +150,18 @@ const updateResourcePermissions = async (req, res) => {
|
|||
grantedBy: userId,
|
||||
});
|
||||
|
||||
const isAgentResource =
|
||||
resourceType === ResourceType.AGENT || resourceType === ResourceType.REMOTE_AGENT;
|
||||
const revokedUserIds = results.revoked
|
||||
.filter((p) => p.type === PrincipalType.USER && p.id)
|
||||
.map((p) => p.id);
|
||||
|
||||
if (isAgentResource && revokedUserIds.length > 0) {
|
||||
db.removeAgentFromUserFavorites(resourceId, revokedUserIds).catch((err) => {
|
||||
logger.error('[removeRevokedAgentFromFavorites] Error cleaning up favorites', err);
|
||||
});
|
||||
}
|
||||
|
||||
/** @type {TUpdateResourcePermissionsResponse} */
|
||||
const response = {
|
||||
message: 'Permissions updated successfully',
|
||||
|
|
@ -185,8 +192,7 @@ const getResourcePermissions = async (req, res) => {
|
|||
const { resourceType, resourceId } = req.params;
|
||||
validateResourceType(resourceType);
|
||||
|
||||
// Use aggregation pipeline for efficient single-query data retrieval
|
||||
const results = await AclEntry.aggregate([
|
||||
const results = await db.aggregateAclEntries([
|
||||
// Match ACL entries for this resource
|
||||
{
|
||||
$match: {
|
||||
|
|
@ -282,7 +288,12 @@ const getResourcePermissions = async (req, res) => {
|
|||
}
|
||||
|
||||
if (resourceType === ResourceType.REMOTE_AGENT) {
|
||||
const enricherDeps = { AclEntry, AccessRole, logger };
|
||||
const enricherDeps = {
|
||||
aggregateAclEntries: db.aggregateAclEntries,
|
||||
bulkWriteAclEntries: db.bulkWriteAclEntries,
|
||||
findRoleByIdentifier: db.findRoleByIdentifier,
|
||||
logger,
|
||||
};
|
||||
const enrichResult = await enrichRemoteAgentPrincipals(enricherDeps, resourceId, principals);
|
||||
principals = enrichResult.principals;
|
||||
backfillRemoteAgentPermissions(enricherDeps, resourceId, enrichResult.entriesToBackfill);
|
||||
|
|
@ -399,7 +410,7 @@ const searchPrincipals = async (req, res) => {
|
|||
typeFilters = validTypes.length > 0 ? validTypes : null;
|
||||
}
|
||||
|
||||
const localResults = await searchLocalPrincipals(query.trim(), searchLimit, typeFilters);
|
||||
const localResults = await db.searchPrincipals(query.trim(), searchLimit, typeFilters);
|
||||
let allPrincipals = [...localResults];
|
||||
|
||||
const useEntraId = entraIdPrincipalFeatureEnabled(req.user);
|
||||
|
|
@ -455,10 +466,11 @@ const searchPrincipals = async (req, res) => {
|
|||
}
|
||||
const scoredResults = allPrincipals.map((item) => ({
|
||||
...item,
|
||||
_searchScore: calculateRelevanceScore(item, query.trim()),
|
||||
_searchScore: db.calculateRelevanceScore(item, query.trim()),
|
||||
}));
|
||||
|
||||
const finalResults = sortPrincipalsByRelevance(scoredResults)
|
||||
const finalResults = db
|
||||
.sortPrincipalsByRelevance(scoredResults)
|
||||
.slice(0, searchLimit)
|
||||
.map((result) => {
|
||||
const { _searchScore, ...resultWithoutScore } = result;
|
||||
|
|
|
|||
|
|
@ -1,5 +1,6 @@
|
|||
const { encryptV3, logger } = require('@librechat/data-schemas');
|
||||
const {
|
||||
verifyOTPOrBackupCode,
|
||||
generateBackupCodes,
|
||||
generateTOTPSecret,
|
||||
verifyBackupCode,
|
||||
|
|
@ -13,24 +14,42 @@ const safeAppTitle = (process.env.APP_TITLE || 'LibreChat').replace(/\s+/g, '');
|
|||
/**
|
||||
* Enable 2FA for the user by generating a new TOTP secret and backup codes.
|
||||
* The secret is encrypted and stored, and 2FA is marked as disabled until confirmed.
|
||||
* If 2FA is already enabled, requires OTP or backup code verification to re-enroll.
|
||||
*/
|
||||
const enable2FA = async (req, res) => {
|
||||
try {
|
||||
const userId = req.user.id;
|
||||
const existingUser = await getUserById(
|
||||
userId,
|
||||
'+totpSecret +backupCodes _id twoFactorEnabled email',
|
||||
);
|
||||
|
||||
if (existingUser && existingUser.twoFactorEnabled) {
|
||||
const { token, backupCode } = req.body;
|
||||
const result = await verifyOTPOrBackupCode({
|
||||
user: existingUser,
|
||||
token,
|
||||
backupCode,
|
||||
persistBackupUse: false,
|
||||
});
|
||||
|
||||
if (!result.verified) {
|
||||
const msg = result.message ?? 'TOTP token or backup code is required to re-enroll 2FA';
|
||||
return res.status(result.status ?? 400).json({ message: msg });
|
||||
}
|
||||
}
|
||||
|
||||
const secret = generateTOTPSecret();
|
||||
const { plainCodes, codeObjects } = await generateBackupCodes();
|
||||
|
||||
// Encrypt the secret with v3 encryption before saving.
|
||||
const encryptedSecret = encryptV3(secret);
|
||||
|
||||
// Update the user record: store the secret & backup codes and set twoFactorEnabled to false.
|
||||
const user = await updateUser(userId, {
|
||||
totpSecret: encryptedSecret,
|
||||
backupCodes: codeObjects,
|
||||
twoFactorEnabled: false,
|
||||
pendingTotpSecret: encryptedSecret,
|
||||
pendingBackupCodes: codeObjects,
|
||||
});
|
||||
|
||||
const otpauthUrl = `otpauth://totp/${safeAppTitle}:${user.email}?secret=${secret}&issuer=${safeAppTitle}`;
|
||||
const email = user.email || (existingUser && existingUser.email) || '';
|
||||
const otpauthUrl = `otpauth://totp/${safeAppTitle}:${email}?secret=${secret}&issuer=${safeAppTitle}`;
|
||||
|
||||
return res.status(200).json({ otpauthUrl, backupCodes: plainCodes });
|
||||
} catch (err) {
|
||||
|
|
@ -46,13 +65,14 @@ const verify2FA = async (req, res) => {
|
|||
try {
|
||||
const userId = req.user.id;
|
||||
const { token, backupCode } = req.body;
|
||||
const user = await getUserById(userId, '_id totpSecret backupCodes');
|
||||
const user = await getUserById(userId, '+totpSecret +pendingTotpSecret +backupCodes _id');
|
||||
const secretSource = user?.pendingTotpSecret ?? user?.totpSecret;
|
||||
|
||||
if (!user || !user.totpSecret) {
|
||||
if (!user || !secretSource) {
|
||||
return res.status(400).json({ message: '2FA not initiated' });
|
||||
}
|
||||
|
||||
const secret = await getTOTPSecret(user.totpSecret);
|
||||
const secret = await getTOTPSecret(secretSource);
|
||||
let isVerified = false;
|
||||
|
||||
if (token) {
|
||||
|
|
@ -78,15 +98,28 @@ const confirm2FA = async (req, res) => {
|
|||
try {
|
||||
const userId = req.user.id;
|
||||
const { token } = req.body;
|
||||
const user = await getUserById(userId, '_id totpSecret');
|
||||
const user = await getUserById(
|
||||
userId,
|
||||
'+totpSecret +pendingTotpSecret +pendingBackupCodes _id',
|
||||
);
|
||||
const secretSource = user?.pendingTotpSecret ?? user?.totpSecret;
|
||||
|
||||
if (!user || !user.totpSecret) {
|
||||
if (!user || !secretSource) {
|
||||
return res.status(400).json({ message: '2FA not initiated' });
|
||||
}
|
||||
|
||||
const secret = await getTOTPSecret(user.totpSecret);
|
||||
const secret = await getTOTPSecret(secretSource);
|
||||
if (await verifyTOTP(secret, token)) {
|
||||
await updateUser(userId, { twoFactorEnabled: true });
|
||||
const update = {
|
||||
totpSecret: user.pendingTotpSecret ?? user.totpSecret,
|
||||
twoFactorEnabled: true,
|
||||
pendingTotpSecret: null,
|
||||
pendingBackupCodes: [],
|
||||
};
|
||||
if (user.pendingBackupCodes?.length) {
|
||||
update.backupCodes = user.pendingBackupCodes;
|
||||
}
|
||||
await updateUser(userId, update);
|
||||
return res.status(200).json();
|
||||
}
|
||||
return res.status(400).json({ message: 'Invalid token.' });
|
||||
|
|
@ -104,31 +137,27 @@ const disable2FA = async (req, res) => {
|
|||
try {
|
||||
const userId = req.user.id;
|
||||
const { token, backupCode } = req.body;
|
||||
const user = await getUserById(userId, '_id totpSecret backupCodes');
|
||||
const user = await getUserById(userId, '+totpSecret +backupCodes _id twoFactorEnabled');
|
||||
|
||||
if (!user || !user.totpSecret) {
|
||||
return res.status(400).json({ message: '2FA is not setup for this user' });
|
||||
}
|
||||
|
||||
if (user.twoFactorEnabled) {
|
||||
const secret = await getTOTPSecret(user.totpSecret);
|
||||
let isVerified = false;
|
||||
const result = await verifyOTPOrBackupCode({ user, token, backupCode });
|
||||
|
||||
if (token) {
|
||||
isVerified = await verifyTOTP(secret, token);
|
||||
} else if (backupCode) {
|
||||
isVerified = await verifyBackupCode({ user, backupCode });
|
||||
} else {
|
||||
return res
|
||||
.status(400)
|
||||
.json({ message: 'Either token or backup code is required to disable 2FA' });
|
||||
}
|
||||
|
||||
if (!isVerified) {
|
||||
return res.status(401).json({ message: 'Invalid token or backup code' });
|
||||
if (!result.verified) {
|
||||
const msg = result.message ?? 'Either token or backup code is required to disable 2FA';
|
||||
return res.status(result.status ?? 400).json({ message: msg });
|
||||
}
|
||||
}
|
||||
await updateUser(userId, { totpSecret: null, backupCodes: [], twoFactorEnabled: false });
|
||||
await updateUser(userId, {
|
||||
totpSecret: null,
|
||||
backupCodes: [],
|
||||
twoFactorEnabled: false,
|
||||
pendingTotpSecret: null,
|
||||
pendingBackupCodes: [],
|
||||
});
|
||||
return res.status(200).json();
|
||||
} catch (err) {
|
||||
logger.error('[disable2FA]', err);
|
||||
|
|
@ -138,10 +167,28 @@ const disable2FA = async (req, res) => {
|
|||
|
||||
/**
|
||||
* Regenerate backup codes for the user.
|
||||
* Requires OTP or backup code verification if 2FA is already enabled.
|
||||
*/
|
||||
const regenerateBackupCodes = async (req, res) => {
|
||||
try {
|
||||
const userId = req.user.id;
|
||||
const user = await getUserById(userId, '+totpSecret +backupCodes _id twoFactorEnabled');
|
||||
|
||||
if (!user) {
|
||||
return res.status(404).json({ message: 'User not found' });
|
||||
}
|
||||
|
||||
if (user.twoFactorEnabled) {
|
||||
const { token, backupCode } = req.body;
|
||||
const result = await verifyOTPOrBackupCode({ user, token, backupCode });
|
||||
|
||||
if (!result.verified) {
|
||||
const msg =
|
||||
result.message ?? 'TOTP token or backup code is required to regenerate backup codes';
|
||||
return res.status(result.status ?? 400).json({ message: msg });
|
||||
}
|
||||
}
|
||||
|
||||
const { plainCodes, codeObjects } = await generateBackupCodes();
|
||||
await updateUser(userId, { backupCodes: codeObjects });
|
||||
return res.status(200).json({
|
||||
|
|
|
|||
|
|
@ -1,49 +1,29 @@
|
|||
const mongoose = require('mongoose');
|
||||
const { logger, webSearchKeys } = require('@librechat/data-schemas');
|
||||
const { Tools, CacheKeys, Constants, FileSources } = require('librechat-data-provider');
|
||||
const {
|
||||
getNewS3URL,
|
||||
needsRefresh,
|
||||
MCPOAuthHandler,
|
||||
MCPTokenStorage,
|
||||
normalizeHttpError,
|
||||
extractWebSearchEnvVars,
|
||||
} = require('@librechat/api');
|
||||
const {
|
||||
deleteAllUserSessions,
|
||||
deleteAllSharedLinks,
|
||||
updateUserPlugins,
|
||||
deleteUserById,
|
||||
deleteMessages,
|
||||
deletePresets,
|
||||
deleteUserKey,
|
||||
deleteConvos,
|
||||
deleteFiles,
|
||||
updateUser,
|
||||
findToken,
|
||||
getFiles,
|
||||
} = require('~/models');
|
||||
const {
|
||||
ConversationTag,
|
||||
AgentApiKey,
|
||||
Transaction,
|
||||
MemoryEntry,
|
||||
Assistant,
|
||||
AclEntry,
|
||||
Balance,
|
||||
Action,
|
||||
Group,
|
||||
Token,
|
||||
User,
|
||||
} = require('~/db/models');
|
||||
Tools,
|
||||
CacheKeys,
|
||||
Constants,
|
||||
FileSources,
|
||||
ResourceType,
|
||||
} = require('librechat-data-provider');
|
||||
const { updateUserPluginAuth, deleteUserPluginAuth } = require('~/server/services/PluginService');
|
||||
const { verifyOTPOrBackupCode } = require('~/server/services/twoFactorService');
|
||||
const { verifyEmail, resendVerificationEmail } = require('~/server/services/AuthService');
|
||||
const { getMCPManager, getFlowStateManager, getMCPServersRegistry } = require('~/config');
|
||||
const { invalidateCachedTools } = require('~/server/services/Config/getCachedTools');
|
||||
const { needsRefresh, getNewS3URL } = require('~/server/services/Files/S3/crud');
|
||||
const { processDeleteRequest } = require('~/server/services/Files/process');
|
||||
const { getAppConfig } = require('~/server/services/Config');
|
||||
const { deleteToolCalls } = require('~/models/ToolCall');
|
||||
const { deleteUserPrompts } = require('~/models/Prompt');
|
||||
const { deleteUserAgents } = require('~/models/Agent');
|
||||
const { getLogStores } = require('~/cache');
|
||||
const db = require('~/models');
|
||||
|
||||
const getUserController = async (req, res) => {
|
||||
const appConfig = await getAppConfig({ role: req.user?.role });
|
||||
|
|
@ -64,7 +44,7 @@ const getUserController = async (req, res) => {
|
|||
const originalAvatar = userData.avatar;
|
||||
try {
|
||||
userData.avatar = await getNewS3URL(userData.avatar);
|
||||
await updateUser(userData.id, { avatar: userData.avatar });
|
||||
await db.updateUser(userData.id, { avatar: userData.avatar });
|
||||
} catch (error) {
|
||||
userData.avatar = originalAvatar;
|
||||
logger.error('Error getting new S3 URL for avatar:', error);
|
||||
|
|
@ -75,7 +55,7 @@ const getUserController = async (req, res) => {
|
|||
|
||||
const getTermsStatusController = async (req, res) => {
|
||||
try {
|
||||
const user = await User.findById(req.user.id);
|
||||
const user = await db.getUserById(req.user.id, 'termsAccepted');
|
||||
if (!user) {
|
||||
return res.status(404).json({ message: 'User not found' });
|
||||
}
|
||||
|
|
@ -88,7 +68,7 @@ const getTermsStatusController = async (req, res) => {
|
|||
|
||||
const acceptTermsController = async (req, res) => {
|
||||
try {
|
||||
const user = await User.findByIdAndUpdate(req.user.id, { termsAccepted: true }, { new: true });
|
||||
const user = await db.updateUser(req.user.id, { termsAccepted: true });
|
||||
if (!user) {
|
||||
return res.status(404).json({ message: 'User not found' });
|
||||
}
|
||||
|
|
@ -101,7 +81,7 @@ const acceptTermsController = async (req, res) => {
|
|||
|
||||
const deleteUserFiles = async (req) => {
|
||||
try {
|
||||
const userFiles = await getFiles({ user: req.user.id });
|
||||
const userFiles = await db.getFiles({ user: req.user.id });
|
||||
await processDeleteRequest({
|
||||
req,
|
||||
files: userFiles,
|
||||
|
|
@ -111,13 +91,86 @@ const deleteUserFiles = async (req) => {
|
|||
}
|
||||
};
|
||||
|
||||
/**
|
||||
* Deletes MCP servers solely owned by the user and cleans up their ACLs.
|
||||
* Disconnects live sessions for deleted servers before removing DB records.
|
||||
* Servers with other owners are left intact; the caller is responsible for
|
||||
* removing the user's own ACL principal entries separately.
|
||||
*
|
||||
* Also handles legacy (pre-ACL) MCP servers that only have the author field set,
|
||||
* ensuring they are not orphaned if no permission migration has been run.
|
||||
* @param {string} userId - The ID of the user.
|
||||
*/
|
||||
const deleteUserMcpServers = async (userId) => {
|
||||
try {
|
||||
const MCPServer = mongoose.models.MCPServer;
|
||||
const AclEntry = mongoose.models.AclEntry;
|
||||
if (!MCPServer) {
|
||||
return;
|
||||
}
|
||||
|
||||
const userObjectId = new mongoose.Types.ObjectId(userId);
|
||||
const soleOwnedIds = await db.getSoleOwnedResourceIds(userObjectId, ResourceType.MCPSERVER);
|
||||
|
||||
const authoredServers = await MCPServer.find({ author: userObjectId })
|
||||
.select('_id serverName')
|
||||
.lean();
|
||||
|
||||
const migratedEntries =
|
||||
authoredServers.length > 0
|
||||
? await AclEntry.find({
|
||||
resourceType: ResourceType.MCPSERVER,
|
||||
resourceId: { $in: authoredServers.map((s) => s._id) },
|
||||
})
|
||||
.select('resourceId')
|
||||
.lean()
|
||||
: [];
|
||||
const migratedIds = new Set(migratedEntries.map((e) => e.resourceId.toString()));
|
||||
const legacyServers = authoredServers.filter((s) => !migratedIds.has(s._id.toString()));
|
||||
const legacyServerIds = legacyServers.map((s) => s._id);
|
||||
|
||||
const allServerIdsToDelete = [...soleOwnedIds, ...legacyServerIds];
|
||||
|
||||
if (allServerIdsToDelete.length === 0) {
|
||||
return;
|
||||
}
|
||||
|
||||
const aclOwnedServers =
|
||||
soleOwnedIds.length > 0
|
||||
? await MCPServer.find({ _id: { $in: soleOwnedIds } })
|
||||
.select('serverName')
|
||||
.lean()
|
||||
: [];
|
||||
const allServersToDelete = [...aclOwnedServers, ...legacyServers];
|
||||
|
||||
const mcpManager = getMCPManager();
|
||||
if (mcpManager) {
|
||||
await Promise.all(
|
||||
allServersToDelete.map(async (s) => {
|
||||
await mcpManager.disconnectUserConnection(userId, s.serverName);
|
||||
await invalidateCachedTools({ userId, serverName: s.serverName });
|
||||
}),
|
||||
);
|
||||
}
|
||||
|
||||
await AclEntry.deleteMany({
|
||||
resourceType: ResourceType.MCPSERVER,
|
||||
resourceId: { $in: allServerIdsToDelete },
|
||||
});
|
||||
|
||||
await MCPServer.deleteMany({ _id: { $in: allServerIdsToDelete } });
|
||||
} catch (error) {
|
||||
logger.error('[deleteUserMcpServers] General error:', error);
|
||||
}
|
||||
};
|
||||
|
||||
const updateUserPluginsController = async (req, res) => {
|
||||
const appConfig = await getAppConfig({ role: req.user?.role });
|
||||
const { user } = req;
|
||||
const { pluginKey, action, auth, isEntityTool } = req.body;
|
||||
try {
|
||||
if (!isEntityTool) {
|
||||
await updateUserPlugins(user._id, user.plugins, pluginKey, action);
|
||||
await db.updateUserPlugins(user._id, user.plugins, pluginKey, action);
|
||||
}
|
||||
|
||||
if (auth == null) {
|
||||
|
|
@ -241,37 +294,50 @@ const deleteUserController = async (req, res) => {
|
|||
const { user } = req;
|
||||
|
||||
try {
|
||||
await deleteMessages({ user: user.id }); // delete user messages
|
||||
await deleteAllUserSessions({ userId: user.id }); // delete user sessions
|
||||
await Transaction.deleteMany({ user: user.id }); // delete user transactions
|
||||
await deleteUserKey({ userId: user.id, all: true }); // delete user keys
|
||||
await Balance.deleteMany({ user: user._id }); // delete user balances
|
||||
await deletePresets(user.id); // delete user presets
|
||||
const existingUser = await db.getUserById(
|
||||
user.id,
|
||||
'+totpSecret +backupCodes _id twoFactorEnabled',
|
||||
);
|
||||
if (existingUser && existingUser.twoFactorEnabled) {
|
||||
const { token, backupCode } = req.body;
|
||||
const result = await verifyOTPOrBackupCode({ user: existingUser, token, backupCode });
|
||||
|
||||
if (!result.verified) {
|
||||
const msg =
|
||||
result.message ??
|
||||
'TOTP token or backup code is required to delete account with 2FA enabled';
|
||||
return res.status(result.status ?? 400).json({ message: msg });
|
||||
}
|
||||
}
|
||||
|
||||
await db.deleteMessages({ user: user.id });
|
||||
await db.deleteAllUserSessions({ userId: user.id });
|
||||
await db.deleteTransactions({ user: user.id });
|
||||
await db.deleteUserKey({ userId: user.id, all: true });
|
||||
await db.deleteBalances({ user: user._id });
|
||||
await db.deletePresets(user.id);
|
||||
try {
|
||||
await deleteConvos(user.id); // delete user convos
|
||||
await db.deleteConvos(user.id);
|
||||
} catch (error) {
|
||||
logger.error('[deleteUserController] Error deleting user convos, likely no convos', error);
|
||||
}
|
||||
await deleteUserPluginAuth(user.id, null, true); // delete user plugin auth
|
||||
await deleteUserById(user.id); // delete user
|
||||
await deleteAllSharedLinks(user.id); // delete user shared links
|
||||
await deleteUserFiles(req); // delete user files
|
||||
await deleteFiles(null, user.id); // delete database files in case of orphaned files from previous steps
|
||||
await deleteToolCalls(user.id); // delete user tool calls
|
||||
await deleteUserAgents(user.id); // delete user agents
|
||||
await AgentApiKey.deleteMany({ user: user._id }); // delete user agent API keys
|
||||
await Assistant.deleteMany({ user: user.id }); // delete user assistants
|
||||
await ConversationTag.deleteMany({ user: user.id }); // delete user conversation tags
|
||||
await MemoryEntry.deleteMany({ userId: user.id }); // delete user memory entries
|
||||
await deleteUserPrompts(req, user.id); // delete user prompts
|
||||
await Action.deleteMany({ user: user.id }); // delete user actions
|
||||
await Token.deleteMany({ userId: user.id }); // delete user OAuth tokens
|
||||
await Group.updateMany(
|
||||
// remove user from all groups
|
||||
{ memberIds: user.id },
|
||||
{ $pull: { memberIds: user.id } },
|
||||
);
|
||||
await AclEntry.deleteMany({ principalId: user._id }); // delete user ACL entries
|
||||
await deleteUserPluginAuth(user.id, null, true);
|
||||
await db.deleteUserById(user.id);
|
||||
await db.deleteAllSharedLinks(user.id);
|
||||
await deleteUserFiles(req);
|
||||
await db.deleteFiles(null, user.id);
|
||||
await db.deleteToolCalls(user.id);
|
||||
await db.deleteUserAgents(user.id);
|
||||
await db.deleteAllAgentApiKeys(user._id);
|
||||
await db.deleteAssistants({ user: user.id });
|
||||
await db.deleteConversationTags({ user: user.id });
|
||||
await db.deleteAllUserMemories(user.id);
|
||||
await db.deleteUserPrompts(user.id);
|
||||
await deleteUserMcpServers(user.id);
|
||||
await db.deleteActions({ user: user.id });
|
||||
await db.deleteTokens({ userId: user.id });
|
||||
await db.removeUserFromAllGroups(user.id);
|
||||
await db.deleteAclEntries({ principalId: user._id });
|
||||
logger.info(`User deleted account. Email: ${user.email} ID: ${user.id}`);
|
||||
res.status(200).send({ message: 'User deleted' });
|
||||
} catch (err) {
|
||||
|
|
@ -331,7 +397,7 @@ const maybeUninstallOAuthMCP = async (userId, pluginKey, appConfig) => {
|
|||
const clientTokenData = await MCPTokenStorage.getClientInfoAndMetadata({
|
||||
userId,
|
||||
serverName,
|
||||
findToken,
|
||||
findToken: db.findToken,
|
||||
});
|
||||
if (clientTokenData == null) {
|
||||
return;
|
||||
|
|
@ -342,7 +408,7 @@ const maybeUninstallOAuthMCP = async (userId, pluginKey, appConfig) => {
|
|||
const tokens = await MCPTokenStorage.getTokens({
|
||||
userId,
|
||||
serverName,
|
||||
findToken,
|
||||
findToken: db.findToken,
|
||||
});
|
||||
|
||||
// 3. revoke OAuth tokens at the provider
|
||||
|
|
@ -352,6 +418,7 @@ const maybeUninstallOAuthMCP = async (userId, pluginKey, appConfig) => {
|
|||
serverConfig.oauth?.revocation_endpoint_auth_methods_supported ??
|
||||
clientMetadata.revocation_endpoint_auth_methods_supported;
|
||||
const oauthHeaders = serverConfig.oauth_headers ?? {};
|
||||
const allowedDomains = getMCPServersRegistry().getAllowedDomains();
|
||||
|
||||
if (tokens?.access_token) {
|
||||
try {
|
||||
|
|
@ -367,6 +434,7 @@ const maybeUninstallOAuthMCP = async (userId, pluginKey, appConfig) => {
|
|||
revocationEndpointAuthMethodsSupported,
|
||||
},
|
||||
oauthHeaders,
|
||||
allowedDomains,
|
||||
);
|
||||
} catch (error) {
|
||||
logger.error(`Error revoking OAuth access token for ${serverName}:`, error);
|
||||
|
|
@ -387,6 +455,7 @@ const maybeUninstallOAuthMCP = async (userId, pluginKey, appConfig) => {
|
|||
revocationEndpointAuthMethodsSupported,
|
||||
},
|
||||
oauthHeaders,
|
||||
allowedDomains,
|
||||
);
|
||||
} catch (error) {
|
||||
logger.error(`Error revoking OAuth refresh token for ${serverName}:`, error);
|
||||
|
|
@ -398,7 +467,7 @@ const maybeUninstallOAuthMCP = async (userId, pluginKey, appConfig) => {
|
|||
userId,
|
||||
serverName,
|
||||
deleteToken: async (filter) => {
|
||||
await Token.deleteOne(filter);
|
||||
await db.deleteTokens(filter);
|
||||
},
|
||||
});
|
||||
|
||||
|
|
@ -418,4 +487,5 @@ module.exports = {
|
|||
verifyEmailController,
|
||||
updateUserPluginsController,
|
||||
resendVerificationController,
|
||||
deleteUserMcpServers,
|
||||
};
|
||||
|
|
|
|||
225
api/server/controllers/UserController.spec.js
Normal file
225
api/server/controllers/UserController.spec.js
Normal file
|
|
@ -0,0 +1,225 @@
|
|||
const mongoose = require('mongoose');
|
||||
const { MongoMemoryServer } = require('mongodb-memory-server');
|
||||
|
||||
jest.mock('@librechat/data-schemas', () => {
|
||||
const actual = jest.requireActual('@librechat/data-schemas');
|
||||
return {
|
||||
...actual,
|
||||
logger: {
|
||||
debug: jest.fn(),
|
||||
error: jest.fn(),
|
||||
warn: jest.fn(),
|
||||
info: jest.fn(),
|
||||
},
|
||||
};
|
||||
});
|
||||
|
||||
jest.mock('~/models', () => {
|
||||
const _mongoose = require('mongoose');
|
||||
return {
|
||||
deleteAllUserSessions: jest.fn().mockResolvedValue(undefined),
|
||||
deleteAllSharedLinks: jest.fn().mockResolvedValue(undefined),
|
||||
deleteAllAgentApiKeys: jest.fn().mockResolvedValue(undefined),
|
||||
deleteConversationTags: jest.fn().mockResolvedValue(undefined),
|
||||
deleteAllUserMemories: jest.fn().mockResolvedValue(undefined),
|
||||
deleteTransactions: jest.fn().mockResolvedValue(undefined),
|
||||
deleteAclEntries: jest.fn().mockResolvedValue(undefined),
|
||||
updateUserPlugins: jest.fn(),
|
||||
deleteAssistants: jest.fn().mockResolvedValue(undefined),
|
||||
deleteUserById: jest.fn().mockResolvedValue(undefined),
|
||||
deleteUserPrompts: jest.fn().mockResolvedValue(undefined),
|
||||
deleteMessages: jest.fn().mockResolvedValue(undefined),
|
||||
deleteBalances: jest.fn().mockResolvedValue(undefined),
|
||||
deleteActions: jest.fn().mockResolvedValue(undefined),
|
||||
deletePresets: jest.fn().mockResolvedValue(undefined),
|
||||
deleteUserKey: jest.fn().mockResolvedValue(undefined),
|
||||
deleteToolCalls: jest.fn().mockResolvedValue(undefined),
|
||||
deleteUserAgents: jest.fn().mockResolvedValue(undefined),
|
||||
deleteTokens: jest.fn().mockResolvedValue(undefined),
|
||||
deleteConvos: jest.fn().mockResolvedValue(undefined),
|
||||
deleteFiles: jest.fn().mockResolvedValue(undefined),
|
||||
updateUser: jest.fn(),
|
||||
getUserById: jest.fn().mockResolvedValue(null),
|
||||
findToken: jest.fn(),
|
||||
getFiles: jest.fn().mockResolvedValue([]),
|
||||
removeUserFromAllGroups: jest.fn().mockImplementation(async (userId) => {
|
||||
const Group = _mongoose.models.Group;
|
||||
await Group.updateMany({ memberIds: userId }, { $pullAll: { memberIds: [userId] } });
|
||||
}),
|
||||
};
|
||||
});
|
||||
|
||||
jest.mock('~/server/services/PluginService', () => ({
|
||||
updateUserPluginAuth: jest.fn(),
|
||||
deleteUserPluginAuth: jest.fn().mockResolvedValue(undefined),
|
||||
}));
|
||||
|
||||
jest.mock('~/server/services/AuthService', () => ({
|
||||
verifyEmail: jest.fn(),
|
||||
resendVerificationEmail: jest.fn(),
|
||||
}));
|
||||
|
||||
jest.mock('sharp', () =>
|
||||
jest.fn(() => ({
|
||||
metadata: jest.fn().mockResolvedValue({}),
|
||||
toFormat: jest.fn().mockReturnThis(),
|
||||
toBuffer: jest.fn().mockResolvedValue(Buffer.alloc(0)),
|
||||
})),
|
||||
);
|
||||
|
||||
jest.mock('@librechat/api', () => ({
|
||||
...jest.requireActual('@librechat/api'),
|
||||
needsRefresh: jest.fn(),
|
||||
getNewS3URL: jest.fn(),
|
||||
}));
|
||||
|
||||
jest.mock('~/server/services/Files/process', () => ({
|
||||
processDeleteRequest: jest.fn().mockResolvedValue(undefined),
|
||||
}));
|
||||
|
||||
jest.mock('~/server/services/Config', () => ({
|
||||
getAppConfig: jest.fn().mockResolvedValue({}),
|
||||
getMCPManager: jest.fn(),
|
||||
getFlowStateManager: jest.fn(),
|
||||
getMCPServersRegistry: jest.fn(),
|
||||
}));
|
||||
|
||||
jest.mock('~/cache', () => ({
|
||||
getLogStores: jest.fn(),
|
||||
}));
|
||||
|
||||
let mongoServer;
|
||||
|
||||
beforeAll(async () => {
|
||||
mongoServer = await MongoMemoryServer.create();
|
||||
await mongoose.connect(mongoServer.getUri());
|
||||
});
|
||||
|
||||
afterAll(async () => {
|
||||
await mongoose.disconnect();
|
||||
await mongoServer.stop();
|
||||
});
|
||||
|
||||
afterEach(async () => {
|
||||
const collections = mongoose.connection.collections;
|
||||
for (const key in collections) {
|
||||
await collections[key].deleteMany({});
|
||||
}
|
||||
});
|
||||
|
||||
const { deleteUserController } = require('./UserController');
|
||||
const { Group } = require('~/db/models');
|
||||
const { deleteConvos } = require('~/models');
|
||||
|
||||
describe('deleteUserController', () => {
|
||||
const mockRes = {
|
||||
status: jest.fn().mockReturnThis(),
|
||||
send: jest.fn().mockReturnThis(),
|
||||
json: jest.fn().mockReturnThis(),
|
||||
};
|
||||
|
||||
beforeEach(() => {
|
||||
jest.clearAllMocks();
|
||||
});
|
||||
|
||||
it('should return 200 on successful deletion', async () => {
|
||||
const userId = new mongoose.Types.ObjectId();
|
||||
const req = { user: { id: userId.toString(), _id: userId, email: 'test@test.com' } };
|
||||
|
||||
await deleteUserController(req, mockRes);
|
||||
|
||||
expect(mockRes.status).toHaveBeenCalledWith(200);
|
||||
expect(mockRes.send).toHaveBeenCalledWith({ message: 'User deleted' });
|
||||
});
|
||||
|
||||
it('should remove the user from all groups via $pullAll', async () => {
|
||||
const userId = new mongoose.Types.ObjectId();
|
||||
const userIdStr = userId.toString();
|
||||
const otherUser = new mongoose.Types.ObjectId().toString();
|
||||
|
||||
await Group.create([
|
||||
{ name: 'Group A', memberIds: [userIdStr, otherUser], source: 'local' },
|
||||
{ name: 'Group B', memberIds: [userIdStr], source: 'local' },
|
||||
{ name: 'Group C', memberIds: [otherUser], source: 'local' },
|
||||
]);
|
||||
|
||||
const req = { user: { id: userIdStr, _id: userId, email: 'del@test.com' } };
|
||||
await deleteUserController(req, mockRes);
|
||||
|
||||
const groups = await Group.find({}).sort({ name: 1 }).lean();
|
||||
expect(groups[0].memberIds).toEqual([otherUser]);
|
||||
expect(groups[1].memberIds).toEqual([]);
|
||||
expect(groups[2].memberIds).toEqual([otherUser]);
|
||||
});
|
||||
|
||||
it('should handle user that exists in no groups', async () => {
|
||||
const userId = new mongoose.Types.ObjectId();
|
||||
await Group.create({ name: 'Empty', memberIds: ['someone-else'], source: 'local' });
|
||||
|
||||
const req = { user: { id: userId.toString(), _id: userId, email: 'no-groups@test.com' } };
|
||||
await deleteUserController(req, mockRes);
|
||||
|
||||
expect(mockRes.status).toHaveBeenCalledWith(200);
|
||||
const group = await Group.findOne({ name: 'Empty' }).lean();
|
||||
expect(group.memberIds).toEqual(['someone-else']);
|
||||
});
|
||||
|
||||
it('should remove duplicate memberIds if the user appears more than once', async () => {
|
||||
const userId = new mongoose.Types.ObjectId();
|
||||
const userIdStr = userId.toString();
|
||||
|
||||
await Group.create({
|
||||
name: 'Dupes',
|
||||
memberIds: [userIdStr, 'other', userIdStr],
|
||||
source: 'local',
|
||||
});
|
||||
|
||||
const req = { user: { id: userIdStr, _id: userId, email: 'dupe@test.com' } };
|
||||
await deleteUserController(req, mockRes);
|
||||
|
||||
const group = await Group.findOne({ name: 'Dupes' }).lean();
|
||||
expect(group.memberIds).toEqual(['other']);
|
||||
});
|
||||
|
||||
it('should still succeed when deleteConvos throws', async () => {
|
||||
const userId = new mongoose.Types.ObjectId();
|
||||
deleteConvos.mockRejectedValueOnce(new Error('no convos'));
|
||||
|
||||
const req = { user: { id: userId.toString(), _id: userId, email: 'convos@test.com' } };
|
||||
await deleteUserController(req, mockRes);
|
||||
|
||||
expect(mockRes.status).toHaveBeenCalledWith(200);
|
||||
expect(mockRes.send).toHaveBeenCalledWith({ message: 'User deleted' });
|
||||
});
|
||||
|
||||
it('should return 500 when a critical operation fails', async () => {
|
||||
const userId = new mongoose.Types.ObjectId();
|
||||
const { deleteMessages } = require('~/models');
|
||||
deleteMessages.mockRejectedValueOnce(new Error('db down'));
|
||||
|
||||
const req = { user: { id: userId.toString(), _id: userId, email: 'fail@test.com' } };
|
||||
await deleteUserController(req, mockRes);
|
||||
|
||||
expect(mockRes.status).toHaveBeenCalledWith(500);
|
||||
expect(mockRes.json).toHaveBeenCalledWith({ message: 'Something went wrong.' });
|
||||
});
|
||||
|
||||
it('should use string user.id (not ObjectId user._id) for memberIds removal', async () => {
|
||||
const userId = new mongoose.Types.ObjectId();
|
||||
const userIdStr = userId.toString();
|
||||
const otherUser = 'other-user-id';
|
||||
|
||||
await Group.create({
|
||||
name: 'StringCheck',
|
||||
memberIds: [userIdStr, otherUser],
|
||||
source: 'local',
|
||||
});
|
||||
|
||||
const req = { user: { id: userIdStr, _id: userId, email: 'stringcheck@test.com' } };
|
||||
await deleteUserController(req, mockRes);
|
||||
|
||||
const group = await Group.findOne({ name: 'StringCheck' }).lean();
|
||||
expect(group.memberIds).toEqual([otherUser]);
|
||||
expect(group.memberIds).not.toContain(userIdStr);
|
||||
});
|
||||
});
|
||||
242
api/server/controllers/__tests__/PermissionsController.spec.js
Normal file
242
api/server/controllers/__tests__/PermissionsController.spec.js
Normal file
|
|
@ -0,0 +1,242 @@
|
|||
const mongoose = require('mongoose');
|
||||
|
||||
const mockLogger = { error: jest.fn(), warn: jest.fn(), info: jest.fn(), debug: jest.fn() };
|
||||
|
||||
jest.mock('@librechat/data-schemas', () => ({
|
||||
logger: mockLogger,
|
||||
}));
|
||||
|
||||
const { ResourceType, PrincipalType } = jest.requireActual('librechat-data-provider');
|
||||
|
||||
jest.mock('librechat-data-provider', () => ({
|
||||
...jest.requireActual('librechat-data-provider'),
|
||||
}));
|
||||
|
||||
jest.mock('@librechat/api', () => ({
|
||||
enrichRemoteAgentPrincipals: jest.fn(),
|
||||
backfillRemoteAgentPermissions: jest.fn(),
|
||||
}));
|
||||
|
||||
const mockBulkUpdateResourcePermissions = jest.fn();
|
||||
|
||||
jest.mock('~/server/services/PermissionService', () => ({
|
||||
bulkUpdateResourcePermissions: (...args) => mockBulkUpdateResourcePermissions(...args),
|
||||
ensureGroupPrincipalExists: jest.fn(),
|
||||
getEffectivePermissions: jest.fn(),
|
||||
ensurePrincipalExists: jest.fn(),
|
||||
getAvailableRoles: jest.fn(),
|
||||
findAccessibleResources: jest.fn(),
|
||||
getResourcePermissionsMap: jest.fn(),
|
||||
}));
|
||||
|
||||
const mockRemoveAgentFromUserFavorites = jest.fn();
|
||||
|
||||
jest.mock('~/models', () => ({
|
||||
searchPrincipals: jest.fn(),
|
||||
sortPrincipalsByRelevance: jest.fn(),
|
||||
calculateRelevanceScore: jest.fn(),
|
||||
removeAgentFromUserFavorites: (...args) => mockRemoveAgentFromUserFavorites(...args),
|
||||
}));
|
||||
|
||||
jest.mock('~/server/services/GraphApiService', () => ({
|
||||
entraIdPrincipalFeatureEnabled: jest.fn(() => false),
|
||||
searchEntraIdPrincipals: jest.fn(),
|
||||
}));
|
||||
|
||||
const { updateResourcePermissions } = require('../PermissionsController');
|
||||
|
||||
const createMockReq = (overrides = {}) => ({
|
||||
params: { resourceType: ResourceType.AGENT, resourceId: '507f1f77bcf86cd799439011' },
|
||||
body: { updated: [], removed: [], public: false },
|
||||
user: { id: 'user-1', role: 'USER' },
|
||||
headers: { authorization: '' },
|
||||
...overrides,
|
||||
});
|
||||
|
||||
const createMockRes = () => {
|
||||
const res = {};
|
||||
res.status = jest.fn().mockReturnValue(res);
|
||||
res.json = jest.fn().mockReturnValue(res);
|
||||
return res;
|
||||
};
|
||||
|
||||
const flushPromises = () => new Promise((resolve) => setImmediate(resolve));
|
||||
|
||||
describe('PermissionsController', () => {
|
||||
beforeEach(() => {
|
||||
jest.clearAllMocks();
|
||||
});
|
||||
|
||||
describe('updateResourcePermissions — favorites cleanup', () => {
  const agentObjectId = new mongoose.Types.ObjectId().toString();
  const revokedUserId = new mongoose.Types.ObjectId().toString();

  /** Builds a permissions-update request removing the given principals from `resourceType`. */
  const buildReq = (resourceType, removed) =>
    createMockReq({
      params: { resourceType, resourceId: agentObjectId },
      body: { updated: [], removed, public: false },
    });

  beforeEach(() => {
    // Default: the bulk update reports exactly one revoked USER principal.
    mockBulkUpdateResourcePermissions.mockResolvedValue({
      granted: [],
      updated: [],
      revoked: [{ type: PrincipalType.USER, id: revokedUserId, name: 'Revoked User' }],
      errors: [],
    });
    mockRemoveAgentFromUserFavorites.mockResolvedValue(undefined);
  });

  it('removes agent from revoked users favorites on AGENT resource type', async () => {
    const req = buildReq(ResourceType.AGENT, [{ type: PrincipalType.USER, id: revokedUserId }]);
    const res = createMockRes();

    await updateResourcePermissions(req, res);
    await flushPromises();

    expect(res.status).toHaveBeenCalledWith(200);
    expect(mockRemoveAgentFromUserFavorites).toHaveBeenCalledWith(agentObjectId, [revokedUserId]);
  });

  it('removes agent from revoked users favorites on REMOTE_AGENT resource type', async () => {
    const req = buildReq(ResourceType.REMOTE_AGENT, [
      { type: PrincipalType.USER, id: revokedUserId },
    ]);
    const res = createMockRes();

    await updateResourcePermissions(req, res);
    await flushPromises();

    expect(mockRemoveAgentFromUserFavorites).toHaveBeenCalledWith(agentObjectId, [revokedUserId]);
  });

  it('uses results.revoked (validated) not raw request payload', async () => {
    const validId = new mongoose.Types.ObjectId().toString();
    const invalidId = 'not-a-valid-id';

    // The service only reports the valid principal as revoked; the invalid one errors out.
    mockBulkUpdateResourcePermissions.mockResolvedValue({
      granted: [],
      updated: [],
      revoked: [{ type: PrincipalType.USER, id: validId }],
      errors: [{ principal: { type: PrincipalType.USER, id: invalidId }, error: 'Invalid ID' }],
    });

    const req = buildReq(ResourceType.AGENT, [
      { type: PrincipalType.USER, id: validId },
      { type: PrincipalType.USER, id: invalidId },
    ]);
    const res = createMockRes();

    await updateResourcePermissions(req, res);
    await flushPromises();

    // Only the validated id makes it to the favorites cleanup.
    expect(mockRemoveAgentFromUserFavorites).toHaveBeenCalledWith(agentObjectId, [validId]);
  });

  it('skips cleanup when no USER principals are revoked', async () => {
    mockBulkUpdateResourcePermissions.mockResolvedValue({
      granted: [],
      updated: [],
      revoked: [{ type: PrincipalType.GROUP, id: 'group-1' }],
      errors: [],
    });

    const req = buildReq(ResourceType.AGENT, [{ type: PrincipalType.GROUP, id: 'group-1' }]);
    const res = createMockRes();

    await updateResourcePermissions(req, res);
    await flushPromises();

    expect(mockRemoveAgentFromUserFavorites).not.toHaveBeenCalled();
  });

  it('skips cleanup for non-agent resource types', async () => {
    mockBulkUpdateResourcePermissions.mockResolvedValue({
      granted: [],
      updated: [],
      revoked: [{ type: PrincipalType.USER, id: revokedUserId }],
      errors: [],
    });

    const req = buildReq(ResourceType.PROMPTGROUP, [
      { type: PrincipalType.USER, id: revokedUserId },
    ]);
    const res = createMockRes();

    await updateResourcePermissions(req, res);
    await flushPromises();

    expect(res.status).toHaveBeenCalledWith(200);
    expect(mockRemoveAgentFromUserFavorites).not.toHaveBeenCalled();
  });

  it('handles agent not found gracefully', async () => {
    mockRemoveAgentFromUserFavorites.mockResolvedValue(undefined);

    const req = buildReq(ResourceType.AGENT, [{ type: PrincipalType.USER, id: revokedUserId }]);
    const res = createMockRes();

    await updateResourcePermissions(req, res);
    await flushPromises();

    expect(mockRemoveAgentFromUserFavorites).toHaveBeenCalled();
    expect(res.status).toHaveBeenCalledWith(200);
  });

  it('logs error when removeAgentFromUserFavorites fails without blocking response', async () => {
    mockRemoveAgentFromUserFavorites.mockRejectedValue(new Error('DB connection lost'));

    const req = buildReq(ResourceType.AGENT, [{ type: PrincipalType.USER, id: revokedUserId }]);
    const res = createMockRes();

    await updateResourcePermissions(req, res);
    await flushPromises();

    // Cleanup is fire-and-forget: the 200 response is not blocked, the error is logged.
    expect(res.status).toHaveBeenCalledWith(200);
    expect(mockLogger.error).toHaveBeenCalledWith(
      '[removeRevokedAgentFromFavorites] Error cleaning up favorites',
      expect.any(Error),
    );
  });
});
|
||||
});
|
||||
264
api/server/controllers/__tests__/TwoFactorController.spec.js
Normal file
264
api/server/controllers/__tests__/TwoFactorController.spec.js
Normal file
|
|
@ -0,0 +1,264 @@
|
|||
// Hoisted mock functions; the `mock` prefix is what allows jest.mock factories
// (which Jest hoists above these declarations) to reference them.
const mockGetUserById = jest.fn();
const mockUpdateUser = jest.fn();
const mockVerifyOTPOrBackupCode = jest.fn();
const mockGenerateTOTPSecret = jest.fn();
const mockGenerateBackupCodes = jest.fn();
const mockEncryptV3 = jest.fn();

// Stub the schema package: route encryptV3 through the local mock, silence the logger.
jest.mock('@librechat/data-schemas', () => ({
  encryptV3: (...args) => mockEncryptV3(...args),
  logger: { error: jest.fn() },
}));

// Stub the 2FA service; only the three delegating entries are observed by these tests,
// the rest are inert stubs to satisfy the module's surface.
jest.mock('~/server/services/twoFactorService', () => ({
  verifyOTPOrBackupCode: (...args) => mockVerifyOTPOrBackupCode(...args),
  generateBackupCodes: (...args) => mockGenerateBackupCodes(...args),
  generateTOTPSecret: (...args) => mockGenerateTOTPSecret(...args),
  verifyBackupCode: jest.fn(),
  getTOTPSecret: jest.fn(),
  verifyTOTP: jest.fn(),
}));

// Stub the models layer used by the controller under test.
jest.mock('~/models', () => ({
  getUserById: (...args) => mockGetUserById(...args),
  updateUser: (...args) => mockUpdateUser(...args),
}));

// Require after the mocks so the controller picks up the stubbed modules.
const { enable2FA, regenerateBackupCodes } = require('~/server/controllers/TwoFactorController');
|
||||
|
||||
/** Returns a chainable Express-style response stub (supports res.status(...).json(...)). */
function createRes() {
  const res = {};
  for (const method of ['status', 'json']) {
    res[method] = jest.fn().mockReturnValue(res);
  }
  return res;
}
|
||||
|
||||
// Fixture: plaintext backup codes and the stored hash objects they map to.
const PLAIN_CODES = ['code1', 'code2', 'code3'];
const CODE_OBJECTS = ['h1', 'h2', 'h3'].map((codeHash) => ({
  codeHash,
  used: false,
  usedAt: null,
}));

beforeEach(() => {
  // Reset call history, then re-install the default happy-path stubs.
  jest.clearAllMocks();
  mockGenerateTOTPSecret.mockReturnValue('NEWSECRET');
  mockGenerateBackupCodes.mockResolvedValue({ plainCodes: PLAIN_CODES, codeObjects: CODE_OBJECTS });
  mockEncryptV3.mockReturnValue('encrypted-secret');
});
|
||||
|
||||
describe('enable2FA', () => {
  // A user record with live 2FA already enabled; `extra` merges additional fields.
  const enrolledUser = (extra = {}) => ({
    _id: 'user1',
    twoFactorEnabled: true,
    totpSecret: 'enc-secret',
    ...extra,
  });

  it('allows first-time setup without token — writes to pending fields', async () => {
    const req = { user: { id: 'user1' }, body: {} };
    const res = createRes();
    mockGetUserById.mockResolvedValue({ _id: 'user1', twoFactorEnabled: false, email: 'a@b.com' });
    mockUpdateUser.mockResolvedValue({ email: 'a@b.com' });

    await enable2FA(req, res);

    expect(res.status).toHaveBeenCalledWith(200);
    expect(res.json).toHaveBeenCalledWith(
      expect.objectContaining({ otpauthUrl: expect.any(String), backupCodes: PLAIN_CODES }),
    );
    // No verification is required when 2FA was never enabled.
    expect(mockVerifyOTPOrBackupCode).not.toHaveBeenCalled();
    const [, update] = mockUpdateUser.mock.calls[0];
    expect(update).toHaveProperty('pendingTotpSecret', 'encrypted-secret');
    expect(update).toHaveProperty('pendingBackupCodes', CODE_OBJECTS);
    // Live 2FA fields stay untouched until the pending setup is confirmed.
    expect(update).not.toHaveProperty('twoFactorEnabled');
    expect(update).not.toHaveProperty('totpSecret');
    expect(update).not.toHaveProperty('backupCodes');
  });

  it('re-enrollment writes to pending fields, leaving live 2FA intact', async () => {
    const req = { user: { id: 'user1' }, body: { token: '123456' } };
    const res = createRes();
    const user = enrolledUser({ email: 'a@b.com' });
    mockGetUserById.mockResolvedValue(user);
    mockVerifyOTPOrBackupCode.mockResolvedValue({ verified: true });
    mockUpdateUser.mockResolvedValue({ email: 'a@b.com' });

    await enable2FA(req, res);

    // Re-enrollment verifies the current factor without persisting backup-code use.
    expect(mockVerifyOTPOrBackupCode).toHaveBeenCalledWith({
      user,
      token: '123456',
      backupCode: undefined,
      persistBackupUse: false,
    });
    expect(res.status).toHaveBeenCalledWith(200);
    const [, update] = mockUpdateUser.mock.calls[0];
    expect(update).toHaveProperty('pendingTotpSecret', 'encrypted-secret');
    expect(update).toHaveProperty('pendingBackupCodes', CODE_OBJECTS);
    expect(update).not.toHaveProperty('twoFactorEnabled');
    expect(update).not.toHaveProperty('totpSecret');
  });

  it('allows re-enrollment with valid backup code (persistBackupUse: false)', async () => {
    const req = { user: { id: 'user1' }, body: { backupCode: 'backup123' } };
    const res = createRes();
    mockGetUserById.mockResolvedValue(enrolledUser({ email: 'a@b.com' }));
    mockVerifyOTPOrBackupCode.mockResolvedValue({ verified: true });
    mockUpdateUser.mockResolvedValue({ email: 'a@b.com' });

    await enable2FA(req, res);

    expect(mockVerifyOTPOrBackupCode).toHaveBeenCalledWith(
      expect.objectContaining({ persistBackupUse: false }),
    );
    expect(res.status).toHaveBeenCalledWith(200);
  });

  it('returns error when no token provided and 2FA is enabled', async () => {
    const req = { user: { id: 'user1' }, body: {} };
    const res = createRes();
    mockGetUserById.mockResolvedValue(enrolledUser());
    mockVerifyOTPOrBackupCode.mockResolvedValue({ verified: false, status: 400 });

    await enable2FA(req, res);

    expect(res.status).toHaveBeenCalledWith(400);
    expect(mockUpdateUser).not.toHaveBeenCalled();
  });

  it('returns 401 when invalid token provided and 2FA is enabled', async () => {
    const req = { user: { id: 'user1' }, body: { token: 'wrong' } };
    const res = createRes();
    mockGetUserById.mockResolvedValue(enrolledUser());
    mockVerifyOTPOrBackupCode.mockResolvedValue({
      verified: false,
      status: 401,
      message: 'Invalid token or backup code',
    });

    await enable2FA(req, res);

    expect(res.status).toHaveBeenCalledWith(401);
    expect(res.json).toHaveBeenCalledWith({ message: 'Invalid token or backup code' });
    expect(mockUpdateUser).not.toHaveBeenCalled();
  });
});
|
||||
|
||||
describe('regenerateBackupCodes', () => {
  // A user record with live 2FA already enabled.
  const enrolledUser = () => ({ _id: 'user1', twoFactorEnabled: true, totpSecret: 'enc-secret' });

  it('returns 404 when user not found', async () => {
    const req = { user: { id: 'user1' }, body: {} };
    const res = createRes();
    mockGetUserById.mockResolvedValue(null);

    await regenerateBackupCodes(req, res);

    expect(res.status).toHaveBeenCalledWith(404);
    expect(res.json).toHaveBeenCalledWith({ message: 'User not found' });
  });

  it('requires OTP when 2FA is enabled', async () => {
    const req = { user: { id: 'user1' }, body: { token: '123456' } };
    const res = createRes();
    mockGetUserById.mockResolvedValue(enrolledUser());
    mockVerifyOTPOrBackupCode.mockResolvedValue({ verified: true });
    mockUpdateUser.mockResolvedValue({});

    await regenerateBackupCodes(req, res);

    expect(mockVerifyOTPOrBackupCode).toHaveBeenCalled();
    expect(res.status).toHaveBeenCalledWith(200);
    expect(res.json).toHaveBeenCalledWith({
      backupCodes: PLAIN_CODES,
      backupCodesHash: CODE_OBJECTS,
    });
  });

  it('returns error when no token provided and 2FA is enabled', async () => {
    const req = { user: { id: 'user1' }, body: {} };
    const res = createRes();
    mockGetUserById.mockResolvedValue(enrolledUser());
    mockVerifyOTPOrBackupCode.mockResolvedValue({ verified: false, status: 400 });

    await regenerateBackupCodes(req, res);

    expect(res.status).toHaveBeenCalledWith(400);
  });

  it('returns 401 when invalid token provided and 2FA is enabled', async () => {
    const req = { user: { id: 'user1' }, body: { token: 'wrong' } };
    const res = createRes();
    mockGetUserById.mockResolvedValue(enrolledUser());
    mockVerifyOTPOrBackupCode.mockResolvedValue({
      verified: false,
      status: 401,
      message: 'Invalid token or backup code',
    });

    await regenerateBackupCodes(req, res);

    expect(res.status).toHaveBeenCalledWith(401);
    expect(res.json).toHaveBeenCalledWith({ message: 'Invalid token or backup code' });
  });

  it('includes backupCodesHash in response', async () => {
    const req = { user: { id: 'user1' }, body: { token: '123456' } };
    const res = createRes();
    mockGetUserById.mockResolvedValue(enrolledUser());
    mockVerifyOTPOrBackupCode.mockResolvedValue({ verified: true });
    mockUpdateUser.mockResolvedValue({});

    await regenerateBackupCodes(req, res);

    const responseBody = res.json.mock.calls[0][0];
    expect(responseBody).toHaveProperty('backupCodesHash', CODE_OBJECTS);
    expect(responseBody).toHaveProperty('backupCodes', PLAIN_CODES);
  });

  it('allows regeneration without token when 2FA is not enabled', async () => {
    const req = { user: { id: 'user1' }, body: {} };
    const res = createRes();
    mockGetUserById.mockResolvedValue({ _id: 'user1', twoFactorEnabled: false });
    mockUpdateUser.mockResolvedValue({});

    await regenerateBackupCodes(req, res);

    // No second factor exists yet, so no verification round-trip happens.
    expect(mockVerifyOTPOrBackupCode).not.toHaveBeenCalled();
    expect(res.status).toHaveBeenCalledWith(200);
    expect(res.json).toHaveBeenCalledWith({
      backupCodes: PLAIN_CODES,
      backupCodesHash: CODE_OBJECTS,
    });
  });
});
|
||||
287
api/server/controllers/__tests__/deleteUser.spec.js
Normal file
287
api/server/controllers/__tests__/deleteUser.spec.js
Normal file
|
|
@ -0,0 +1,287 @@
|
|||
// Hoisted mock functions; the `mock` prefix is what allows jest.mock factories
// (which Jest hoists above these declarations) to reference them.
const mockGetUserById = jest.fn();
const mockDeleteMessages = jest.fn();
const mockDeleteAllUserSessions = jest.fn();
const mockDeleteUserById = jest.fn();
const mockDeleteAllSharedLinks = jest.fn();
const mockDeletePresets = jest.fn();
const mockDeleteUserKey = jest.fn();
const mockDeleteConvos = jest.fn();
const mockDeleteFiles = jest.fn();
const mockGetFiles = jest.fn();
const mockUpdateUserPlugins = jest.fn();
const mockUpdateUser = jest.fn();
const mockFindToken = jest.fn();
const mockVerifyOTPOrBackupCode = jest.fn();
const mockDeleteUserPluginAuth = jest.fn();
const mockProcessDeleteRequest = jest.fn();
const mockDeleteToolCalls = jest.fn();
const mockDeleteUserAgents = jest.fn();
const mockDeleteUserPrompts = jest.fn();

// Silence the schema package's logger; webSearchKeys is read at module load.
jest.mock('@librechat/data-schemas', () => ({
  logger: { error: jest.fn(), info: jest.fn() },
  webSearchKeys: [],
}));

// Minimal data-provider surface needed at require time by the controller module.
jest.mock('librechat-data-provider', () => ({
  Tools: {},
  CacheKeys: {},
  Constants: { mcp_delimiter: '::', mcp_prefix: 'mcp_' },
  FileSources: {},
}));

// Inert stubs for the @librechat/api surface the controller module touches.
jest.mock('@librechat/api', () => ({
  MCPOAuthHandler: {},
  MCPTokenStorage: {},
  normalizeHttpError: jest.fn(),
  extractWebSearchEnvVars: jest.fn(),
  needsRefresh: jest.fn(),
  getNewS3URL: jest.fn(),
}));

// Models layer: entries these tests observe delegate to the hoisted mocks;
// the remainder are inert jest.fn() stubs so deletion can complete.
jest.mock('~/models', () => ({
  deleteAllUserSessions: (...args) => mockDeleteAllUserSessions(...args),
  deleteAllSharedLinks: (...args) => mockDeleteAllSharedLinks(...args),
  updateUserPlugins: (...args) => mockUpdateUserPlugins(...args),
  deleteUserById: (...args) => mockDeleteUserById(...args),
  deleteMessages: (...args) => mockDeleteMessages(...args),
  deletePresets: (...args) => mockDeletePresets(...args),
  deleteUserKey: (...args) => mockDeleteUserKey(...args),
  getUserById: (...args) => mockGetUserById(...args),
  deleteConvos: (...args) => mockDeleteConvos(...args),
  deleteFiles: (...args) => mockDeleteFiles(...args),
  updateUser: (...args) => mockUpdateUser(...args),
  findToken: (...args) => mockFindToken(...args),
  getFiles: (...args) => mockGetFiles(...args),
  deleteToolCalls: (...args) => mockDeleteToolCalls(...args),
  deleteUserAgents: (...args) => mockDeleteUserAgents(...args),
  deleteUserPrompts: (...args) => mockDeleteUserPrompts(...args),
  deleteTransactions: jest.fn(),
  deleteBalances: jest.fn(),
  deleteAllAgentApiKeys: jest.fn(),
  deleteAssistants: jest.fn(),
  deleteConversationTags: jest.fn(),
  deleteAllUserMemories: jest.fn(),
  deleteActions: jest.fn(),
  deleteTokens: jest.fn(),
  removeUserFromAllGroups: jest.fn(),
  deleteAclEntries: jest.fn(),
  getSoleOwnedResourceIds: jest.fn().mockResolvedValue([]),
}));

jest.mock('~/server/services/PluginService', () => ({
  updateUserPluginAuth: jest.fn(),
  deleteUserPluginAuth: (...args) => mockDeleteUserPluginAuth(...args),
}));

// 2FA verification is the behavior under test — route it through the hoisted mock.
jest.mock('~/server/services/twoFactorService', () => ({
  verifyOTPOrBackupCode: (...args) => mockVerifyOTPOrBackupCode(...args),
}));

jest.mock('~/server/services/AuthService', () => ({
  verifyEmail: jest.fn(),
  resendVerificationEmail: jest.fn(),
}));

jest.mock('~/config', () => ({
  getMCPManager: jest.fn(),
  getFlowStateManager: jest.fn(),
  getMCPServersRegistry: jest.fn(),
}));

jest.mock('~/server/services/Config/getCachedTools', () => ({
  invalidateCachedTools: jest.fn(),
}));

jest.mock('~/server/services/Files/process', () => ({
  processDeleteRequest: (...args) => mockProcessDeleteRequest(...args),
}));

jest.mock('~/server/services/Config', () => ({
  getAppConfig: jest.fn(),
}));

jest.mock('~/cache', () => ({
  getLogStores: jest.fn(),
}));
|
||||
|
||||
const { deleteUserController } = require('~/server/controllers/UserController');
|
||||
|
||||
/** Returns a chainable Express-style response stub (status/json/send all return res). */
function createRes() {
  const res = {};
  for (const method of ['status', 'json', 'send']) {
    res[method] = jest.fn().mockReturnValue(res);
  }
  return res;
}
|
||||
|
||||
/** Installs the default happy-path resolutions for every deletion-related mock. */
function stubDeletionMocks() {
  const resolveEmpty = [
    mockDeleteMessages,
    mockDeleteAllUserSessions,
    mockDeleteUserKey,
    mockDeletePresets,
    mockDeleteConvos,
    mockDeleteUserPluginAuth,
    mockDeleteUserById,
    mockDeleteAllSharedLinks,
    mockProcessDeleteRequest,
    mockDeleteFiles,
    mockDeleteToolCalls,
    mockDeleteUserAgents,
    mockDeleteUserPrompts,
  ];
  for (const fn of resolveEmpty) {
    fn.mockResolvedValue();
  }
  // The controller enumerates the user's files; default to "none".
  mockGetFiles.mockResolvedValue([]);
}
|
||||
|
||||
beforeEach(() => {
  // Reset all mock call history, then re-apply the default happy-path stubs.
  jest.clearAllMocks();
  stubDeletionMocks();
});
|
||||
|
||||
describe('deleteUserController - 2FA enforcement', () => {
  // A user record with live 2FA enabled; `extra` merges additional fields.
  const twoFAUser = (extra = {}) => ({
    _id: 'user1',
    twoFactorEnabled: true,
    totpSecret: 'enc-secret',
    ...extra,
  });

  it('proceeds with deletion when 2FA is not enabled', async () => {
    const req = { user: { id: 'user1', _id: 'user1', email: 'a@b.com' }, body: {} };
    const res = createRes();
    mockGetUserById.mockResolvedValue({ _id: 'user1', twoFactorEnabled: false });

    await deleteUserController(req, res);

    expect(res.status).toHaveBeenCalledWith(200);
    expect(res.send).toHaveBeenCalledWith({ message: 'User deleted' });
    expect(mockDeleteMessages).toHaveBeenCalled();
    expect(mockVerifyOTPOrBackupCode).not.toHaveBeenCalled();
  });

  it('proceeds with deletion when user has no 2FA record', async () => {
    const req = { user: { id: 'user1', _id: 'user1', email: 'a@b.com' }, body: {} };
    const res = createRes();
    mockGetUserById.mockResolvedValue(null);

    await deleteUserController(req, res);

    expect(res.status).toHaveBeenCalledWith(200);
    expect(res.send).toHaveBeenCalledWith({ message: 'User deleted' });
  });

  it('returns error when 2FA is enabled and verification fails with 400', async () => {
    const req = { user: { id: 'user1', _id: 'user1' }, body: {} };
    const res = createRes();
    mockGetUserById.mockResolvedValue(twoFAUser());
    mockVerifyOTPOrBackupCode.mockResolvedValue({ verified: false, status: 400 });

    await deleteUserController(req, res);

    expect(res.status).toHaveBeenCalledWith(400);
    // Deletion must not begin when verification fails.
    expect(mockDeleteMessages).not.toHaveBeenCalled();
  });

  it('returns 401 when 2FA is enabled and invalid TOTP token provided', async () => {
    const user = twoFAUser();
    const req = { user: { id: 'user1', _id: 'user1' }, body: { token: 'wrong' } };
    const res = createRes();
    mockGetUserById.mockResolvedValue(user);
    mockVerifyOTPOrBackupCode.mockResolvedValue({
      verified: false,
      status: 401,
      message: 'Invalid token or backup code',
    });

    await deleteUserController(req, res);

    expect(mockVerifyOTPOrBackupCode).toHaveBeenCalledWith({
      user,
      token: 'wrong',
      backupCode: undefined,
    });
    expect(res.status).toHaveBeenCalledWith(401);
    expect(res.json).toHaveBeenCalledWith({ message: 'Invalid token or backup code' });
    expect(mockDeleteMessages).not.toHaveBeenCalled();
  });

  it('returns 401 when 2FA is enabled and invalid backup code provided', async () => {
    const user = twoFAUser({ backupCodes: [] });
    const req = { user: { id: 'user1', _id: 'user1' }, body: { backupCode: 'bad-code' } };
    const res = createRes();
    mockGetUserById.mockResolvedValue(user);
    mockVerifyOTPOrBackupCode.mockResolvedValue({
      verified: false,
      status: 401,
      message: 'Invalid token or backup code',
    });

    await deleteUserController(req, res);

    expect(mockVerifyOTPOrBackupCode).toHaveBeenCalledWith({
      user,
      token: undefined,
      backupCode: 'bad-code',
    });
    expect(res.status).toHaveBeenCalledWith(401);
    expect(mockDeleteMessages).not.toHaveBeenCalled();
  });

  it('deletes account when valid TOTP token provided with 2FA enabled', async () => {
    const user = twoFAUser();
    const req = {
      user: { id: 'user1', _id: 'user1', email: 'a@b.com' },
      body: { token: '123456' },
    };
    const res = createRes();
    mockGetUserById.mockResolvedValue(user);
    mockVerifyOTPOrBackupCode.mockResolvedValue({ verified: true });

    await deleteUserController(req, res);

    expect(mockVerifyOTPOrBackupCode).toHaveBeenCalledWith({
      user,
      token: '123456',
      backupCode: undefined,
    });
    expect(res.status).toHaveBeenCalledWith(200);
    expect(res.send).toHaveBeenCalledWith({ message: 'User deleted' });
    expect(mockDeleteMessages).toHaveBeenCalled();
  });

  it('deletes account when valid backup code provided with 2FA enabled', async () => {
    const user = twoFAUser({ backupCodes: [{ codeHash: 'h1', used: false }] });
    const req = {
      user: { id: 'user1', _id: 'user1', email: 'a@b.com' },
      body: { backupCode: 'valid-code' },
    };
    const res = createRes();
    mockGetUserById.mockResolvedValue(user);
    mockVerifyOTPOrBackupCode.mockResolvedValue({ verified: true });

    await deleteUserController(req, res);

    expect(mockVerifyOTPOrBackupCode).toHaveBeenCalledWith({
      user,
      token: undefined,
      backupCode: 'valid-code',
    });
    expect(res.status).toHaveBeenCalledWith(200);
    expect(res.send).toHaveBeenCalledWith({ message: 'User deleted' });
    expect(mockDeleteMessages).toHaveBeenCalled();
  });
});
|
||||
319
api/server/controllers/__tests__/deleteUserMcpServers.spec.js
Normal file
319
api/server/controllers/__tests__/deleteUserMcpServers.spec.js
Normal file
|
|
@ -0,0 +1,319 @@
|
|||
// Hoisted mocks for the MCP manager and tool-cache invalidation (the `mock`
// prefix lets the hoisted jest.mock factories reference them).
const mockGetMCPManager = jest.fn();
const mockInvalidateCachedTools = jest.fn();

jest.mock('~/config', () => ({
  getMCPManager: (...args) => mockGetMCPManager(...args),
  getFlowStateManager: jest.fn(),
  getMCPServersRegistry: jest.fn(),
}));

jest.mock('~/server/services/Config/getCachedTools', () => ({
  invalidateCachedTools: (...args) => mockInvalidateCachedTools(...args),
}));

jest.mock('~/server/services/Config', () => ({
  getAppConfig: jest.fn(),
  getMCPServerTools: jest.fn(),
}));

// Real (non-mocked) modules — these tests run against an in-memory MongoDB.
const mongoose = require('mongoose');
const { mcpServerSchema } = require('@librechat/data-schemas');
const { MongoMemoryServer } = require('mongodb-memory-server');
const {
  ResourceType,
  AccessRoleIds,
  PrincipalType,
  PermissionBits,
} = require('librechat-data-provider');
const permissionService = require('~/server/services/PermissionService');
const { deleteUserMcpServers } = require('~/server/controllers/UserController');
const { AclEntry, AccessRole } = require('~/db/models');

// Registered lazily in beforeAll, once the in-memory server is up.
let MCPServer;
|
||||
|
||||
describe('deleteUserMcpServers', () => {
|
||||
let mongoServer;
|
||||
|
||||
beforeAll(async () => {
  // Spin up an in-memory MongoDB and register the MCPServer model against it.
  mongoServer = await MongoMemoryServer.create();
  const mongoUri = mongoServer.getUri();
  MCPServer = mongoose.models.MCPServer || mongoose.model('MCPServer', mcpServerSchema);
  await mongoose.connect(mongoUri);

  // Seed the access roles the permission service resolves by accessRoleId.
  // Owner carries full permissions (view/edit/delete/share).
  await AccessRole.create({
    accessRoleId: AccessRoleIds.MCPSERVER_OWNER,
    name: 'MCP Server Owner',
    resourceType: ResourceType.MCPSERVER,
    permBits:
      PermissionBits.VIEW | PermissionBits.EDIT | PermissionBits.DELETE | PermissionBits.SHARE,
  });

  await AccessRole.create({
    accessRoleId: AccessRoleIds.MCPSERVER_VIEWER,
    name: 'MCP Server Viewer',
    resourceType: ResourceType.MCPSERVER,
    permBits: PermissionBits.VIEW,
  });
}, 20000); // generous timeout: the first run may download the MongoDB binary
|
||||
|
||||
afterAll(async () => {
  // Close the Mongoose connection before stopping the in-memory server.
  await mongoose.disconnect();
  await mongoServer.stop();
});
|
||||
|
||||
beforeEach(async () => {
  // Start each test from empty collections and clean mock state.
  await MCPServer.deleteMany({});
  await AclEntry.deleteMany({});
  jest.clearAllMocks();
});
|
||||
|
||||
test('should delete solely-owned MCP servers and their ACL entries', async () => {
  const userId = new mongoose.Types.ObjectId();

  const server = await MCPServer.create({
    serverName: 'sole-owned-server',
    config: { title: 'Test Server' },
    author: userId,
  });

  // The user is the server's only owner.
  await permissionService.grantPermission({
    principalType: PrincipalType.USER,
    principalId: userId,
    resourceType: ResourceType.MCPSERVER,
    resourceId: server._id,
    accessRoleId: AccessRoleIds.MCPSERVER_OWNER,
    grantedBy: userId,
  });

  mockGetMCPManager.mockReturnValue({
    disconnectUserConnection: jest.fn().mockResolvedValue(undefined),
  });

  await deleteUserMcpServers(userId.toString());

  // The server document is deleted...
  expect(await MCPServer.findById(server._id)).toBeNull();

  // ...and no ACL entry pointing at it survives.
  const aclEntries = await AclEntry.find({
    resourceType: ResourceType.MCPSERVER,
    resourceId: server._id,
  });
  expect(aclEntries).toHaveLength(0);
});
|
||||
|
||||
test('should disconnect MCP sessions and invalidate tool cache before deletion', async () => {
  const userId = new mongoose.Types.ObjectId();
  const mockDisconnect = jest.fn().mockResolvedValue(undefined);

  const server = await MCPServer.create({
    serverName: 'session-server',
    config: { title: 'Session Server' },
    author: userId,
  });

  await permissionService.grantPermission({
    principalType: PrincipalType.USER,
    principalId: userId,
    resourceType: ResourceType.MCPSERVER,
    resourceId: server._id,
    accessRoleId: AccessRoleIds.MCPSERVER_OWNER,
    grantedBy: userId,
  });

  mockGetMCPManager.mockReturnValue({ disconnectUserConnection: mockDisconnect });

  await deleteUserMcpServers(userId.toString());

  // The user's live connection is dropped and their per-server tool cache invalidated.
  expect(mockDisconnect).toHaveBeenCalledWith(userId.toString(), 'session-server');
  expect(mockInvalidateCachedTools).toHaveBeenCalledWith({
    userId: userId.toString(),
    serverName: 'session-server',
  });
});
|
||||
|
||||
test('should preserve multi-owned MCP servers', async () => {
  const deletingUserId = new mongoose.Types.ObjectId();
  const otherOwnerId = new mongoose.Types.ObjectId();

  const soleServer = await MCPServer.create({
    serverName: 'sole-server',
    config: { title: 'Sole Server' },
    author: deletingUserId,
  });

  const multiServer = await MCPServer.create({
    serverName: 'multi-server',
    config: { title: 'Multi Server' },
    author: deletingUserId,
  });

  // soleServer: owned only by the deleting user.
  await permissionService.grantPermission({
    principalType: PrincipalType.USER,
    principalId: deletingUserId,
    resourceType: ResourceType.MCPSERVER,
    resourceId: soleServer._id,
    accessRoleId: AccessRoleIds.MCPSERVER_OWNER,
    grantedBy: deletingUserId,
  });

  // multiServer: owned by both the deleting user and another owner.
  await permissionService.grantPermission({
    principalType: PrincipalType.USER,
    principalId: deletingUserId,
    resourceType: ResourceType.MCPSERVER,
    resourceId: multiServer._id,
    accessRoleId: AccessRoleIds.MCPSERVER_OWNER,
    grantedBy: deletingUserId,
  });
  await permissionService.grantPermission({
    principalType: PrincipalType.USER,
    principalId: otherOwnerId,
    resourceType: ResourceType.MCPSERVER,
    resourceId: multiServer._id,
    accessRoleId: AccessRoleIds.MCPSERVER_OWNER,
    grantedBy: otherOwnerId,
  });

  mockGetMCPManager.mockReturnValue({
    disconnectUserConnection: jest.fn().mockResolvedValue(undefined),
  });

  await deleteUserMcpServers(deletingUserId.toString());

  // Only the solely-owned server is removed; the shared one survives.
  expect(await MCPServer.findById(soleServer._id)).toBeNull();
  expect(await MCPServer.findById(multiServer._id)).not.toBeNull();

  const soleAcl = await AclEntry.find({
    resourceType: ResourceType.MCPSERVER,
    resourceId: soleServer._id,
  });
  expect(soleAcl).toHaveLength(0);

  // The other owner keeps an entry with DELETE still set on the surviving server.
  const multiAclOther = await AclEntry.find({
    resourceType: ResourceType.MCPSERVER,
    resourceId: multiServer._id,
    principalId: otherOwnerId,
  });
  expect(multiAclOther).toHaveLength(1);
  expect(multiAclOther[0].permBits & PermissionBits.DELETE).toBeTruthy();

  // NOTE(review): the deleting user's own ACL entry on the shared server is kept —
  // this assertion pins the current behavior; confirm it is intentional.
  const multiAclDeleting = await AclEntry.find({
    resourceType: ResourceType.MCPSERVER,
    resourceId: multiServer._id,
    principalId: deletingUserId,
  });
  expect(multiAclDeleting).toHaveLength(1);
});
|
||||
|
||||
test('should be a no-op when user has no owned MCP servers', async () => {
|
||||
const userId = new mongoose.Types.ObjectId();
|
||||
|
||||
const otherUserId = new mongoose.Types.ObjectId();
|
||||
const server = await MCPServer.create({
|
||||
serverName: 'other-server',
|
||||
config: { title: 'Other Server' },
|
||||
author: otherUserId,
|
||||
});
|
||||
|
||||
await permissionService.grantPermission({
|
||||
principalType: PrincipalType.USER,
|
||||
principalId: otherUserId,
|
||||
resourceType: ResourceType.MCPSERVER,
|
||||
resourceId: server._id,
|
||||
accessRoleId: AccessRoleIds.MCPSERVER_OWNER,
|
||||
grantedBy: otherUserId,
|
||||
});
|
||||
|
||||
await deleteUserMcpServers(userId.toString());
|
||||
|
||||
expect(await MCPServer.findById(server._id)).not.toBeNull();
|
||||
expect(mockGetMCPManager).not.toHaveBeenCalled();
|
||||
});
|
||||
|
||||
test('should handle gracefully when MCPServer model is not registered', async () => {
|
||||
const originalModel = mongoose.models.MCPServer;
|
||||
delete mongoose.models.MCPServer;
|
||||
|
||||
try {
|
||||
const userId = new mongoose.Types.ObjectId();
|
||||
await expect(deleteUserMcpServers(userId.toString())).resolves.toBeUndefined();
|
||||
} finally {
|
||||
mongoose.models.MCPServer = originalModel;
|
||||
}
|
||||
});
|
||||
|
||||
test('should handle gracefully when MCPManager is not available', async () => {
|
||||
const userId = new mongoose.Types.ObjectId();
|
||||
|
||||
const server = await MCPServer.create({
|
||||
serverName: 'no-manager-server',
|
||||
config: { title: 'No Manager Server' },
|
||||
author: userId,
|
||||
});
|
||||
|
||||
await permissionService.grantPermission({
|
||||
principalType: PrincipalType.USER,
|
||||
principalId: userId,
|
||||
resourceType: ResourceType.MCPSERVER,
|
||||
resourceId: server._id,
|
||||
accessRoleId: AccessRoleIds.MCPSERVER_OWNER,
|
||||
grantedBy: userId,
|
||||
});
|
||||
|
||||
mockGetMCPManager.mockReturnValue(null);
|
||||
|
||||
await deleteUserMcpServers(userId.toString());
|
||||
|
||||
expect(await MCPServer.findById(server._id)).toBeNull();
|
||||
});
|
||||
|
||||
test('should delete legacy MCP servers that have author but no ACL entries', async () => {
|
||||
const legacyUserId = new mongoose.Types.ObjectId();
|
||||
|
||||
const legacyServer = await MCPServer.create({
|
||||
serverName: 'legacy-server',
|
||||
config: { title: 'Legacy Server' },
|
||||
author: legacyUserId,
|
||||
});
|
||||
|
||||
mockGetMCPManager.mockReturnValue({
|
||||
disconnectUserConnection: jest.fn().mockResolvedValue(undefined),
|
||||
});
|
||||
|
||||
await deleteUserMcpServers(legacyUserId.toString());
|
||||
|
||||
expect(await MCPServer.findById(legacyServer._id)).toBeNull();
|
||||
});
|
||||
|
||||
test('should delete both ACL-owned and legacy servers in one call', async () => {
|
||||
const userId = new mongoose.Types.ObjectId();
|
||||
|
||||
const aclServer = await MCPServer.create({
|
||||
serverName: 'acl-server',
|
||||
config: { title: 'ACL Server' },
|
||||
author: userId,
|
||||
});
|
||||
|
||||
await permissionService.grantPermission({
|
||||
principalType: PrincipalType.USER,
|
||||
principalId: userId,
|
||||
resourceType: ResourceType.MCPSERVER,
|
||||
resourceId: aclServer._id,
|
||||
accessRoleId: AccessRoleIds.MCPSERVER_OWNER,
|
||||
grantedBy: userId,
|
||||
});
|
||||
|
||||
const legacyServer = await MCPServer.create({
|
||||
serverName: 'legacy-mixed-server',
|
||||
config: { title: 'Legacy Mixed' },
|
||||
author: userId,
|
||||
});
|
||||
|
||||
mockGetMCPManager.mockReturnValue({
|
||||
disconnectUserConnection: jest.fn().mockResolvedValue(undefined),
|
||||
});
|
||||
|
||||
await deleteUserMcpServers(userId.toString());
|
||||
|
||||
expect(await MCPServer.findById(aclServer._id)).toBeNull();
|
||||
expect(await MCPServer.findById(legacyServer._id)).toBeNull();
|
||||
});
|
||||
});
|
||||
|
|
@ -0,0 +1,53 @@
|
|||
const fs = require('fs');
|
||||
const path = require('path');
|
||||
const { ResourceType } = require('librechat-data-provider');
|
||||
|
||||
/**
|
||||
* Maps each ResourceType to the cleanup function name that must appear in
|
||||
* deleteUserController's source to prove it is handled during user deletion.
|
||||
*
|
||||
* When a new ResourceType is added, this test will fail until a corresponding
|
||||
* entry is added here (or to NO_USER_CLEANUP_NEEDED) AND the actual cleanup
|
||||
* logic is implemented.
|
||||
*/
|
||||
const HANDLED_RESOURCE_TYPES = {
|
||||
[ResourceType.AGENT]: 'deleteUserAgents',
|
||||
[ResourceType.REMOTE_AGENT]: 'deleteUserAgents',
|
||||
[ResourceType.PROMPTGROUP]: 'deleteUserPrompts',
|
||||
[ResourceType.MCPSERVER]: 'deleteUserMcpServers',
|
||||
};
|
||||
|
||||
/**
|
||||
* ResourceTypes that are ACL-tracked but have no per-user deletion semantics
|
||||
* (e.g., system resources, public-only). Must be explicitly listed here with
|
||||
* a justification to prevent silent omissions.
|
||||
*/
|
||||
const NO_USER_CLEANUP_NEEDED = new Set([
|
||||
// Example: ResourceType.SYSTEM_TEMPLATE — public/system; not user-owned
|
||||
]);
|
||||
|
||||
describe('deleteUserController - resource type coverage guard', () => {
|
||||
let controllerSource;
|
||||
|
||||
beforeAll(() => {
|
||||
controllerSource = fs.readFileSync(path.resolve(__dirname, '../UserController.js'), 'utf-8');
|
||||
});
|
||||
|
||||
test('every ResourceType must have a documented cleanup handler or explicit exclusion', () => {
|
||||
const allTypes = Object.values(ResourceType);
|
||||
const handledTypes = Object.keys(HANDLED_RESOURCE_TYPES);
|
||||
const unhandledTypes = allTypes.filter(
|
||||
(t) => !handledTypes.includes(t) && !NO_USER_CLEANUP_NEEDED.has(t),
|
||||
);
|
||||
|
||||
expect(unhandledTypes).toEqual([]);
|
||||
});
|
||||
|
||||
test('every cleanup handler referenced in HANDLED_RESOURCE_TYPES must appear in the controller source', () => {
|
||||
const uniqueHandlers = [...new Set(Object.values(HANDLED_RESOURCE_TYPES))];
|
||||
|
||||
for (const handler of uniqueHandlers) {
|
||||
expect(controllerSource).toContain(handler);
|
||||
}
|
||||
});
|
||||
});
|
||||
|
|
@ -77,42 +77,25 @@ jest.mock('~/server/services/ToolService', () => ({
|
|||
loadToolsForExecution: jest.fn().mockResolvedValue([]),
|
||||
}));
|
||||
|
||||
jest.mock('~/models/spendTokens', () => ({
|
||||
spendTokens: mockSpendTokens,
|
||||
spendStructuredTokens: mockSpendStructuredTokens,
|
||||
}));
|
||||
|
||||
const mockGetMultiplier = jest.fn().mockReturnValue(1);
|
||||
const mockGetCacheMultiplier = jest.fn().mockReturnValue(null);
|
||||
jest.mock('~/models/tx', () => ({
|
||||
getMultiplier: mockGetMultiplier,
|
||||
getCacheMultiplier: mockGetCacheMultiplier,
|
||||
}));
|
||||
|
||||
jest.mock('~/server/controllers/agents/callbacks', () => ({
|
||||
createToolEndCallback: jest.fn().mockReturnValue(jest.fn()),
|
||||
buildSummarizationHandlers: jest.fn().mockReturnValue({}),
|
||||
markSummarizationUsage: jest.fn().mockImplementation((usage) => usage),
|
||||
agentLogHandlerObj: { handle: jest.fn() },
|
||||
}));
|
||||
|
||||
jest.mock('~/server/services/PermissionService', () => ({
|
||||
findAccessibleResources: jest.fn().mockResolvedValue([]),
|
||||
}));
|
||||
|
||||
jest.mock('~/models/Conversation', () => ({
|
||||
getConvoFiles: jest.fn().mockResolvedValue([]),
|
||||
}));
|
||||
|
||||
jest.mock('~/models/Agent', () => ({
|
||||
getAgent: jest.fn().mockResolvedValue({
|
||||
id: 'agent-123',
|
||||
provider: 'openAI',
|
||||
model_parameters: { model: 'gpt-4' },
|
||||
}),
|
||||
getAgents: jest.fn().mockResolvedValue([]),
|
||||
}));
|
||||
|
||||
const mockUpdateBalance = jest.fn().mockResolvedValue({});
|
||||
const mockBulkInsertTransactions = jest.fn().mockResolvedValue(undefined);
|
||||
|
||||
jest.mock('~/models', () => ({
|
||||
getAgent: jest.fn().mockResolvedValue({ id: 'agent-123', name: 'Test Agent' }),
|
||||
getFiles: jest.fn(),
|
||||
getUserKey: jest.fn(),
|
||||
getMessages: jest.fn(),
|
||||
|
|
@ -123,6 +106,12 @@ jest.mock('~/models', () => ({
|
|||
getCodeGeneratedFiles: jest.fn(),
|
||||
updateBalance: mockUpdateBalance,
|
||||
bulkInsertTransactions: mockBulkInsertTransactions,
|
||||
spendTokens: mockSpendTokens,
|
||||
spendStructuredTokens: mockSpendStructuredTokens,
|
||||
getMultiplier: mockGetMultiplier,
|
||||
getCacheMultiplier: mockGetCacheMultiplier,
|
||||
getConvoFiles: jest.fn().mockResolvedValue([]),
|
||||
getConvo: jest.fn().mockResolvedValue(null),
|
||||
}));
|
||||
|
||||
describe('OpenAIChatCompletionController', () => {
|
||||
|
|
@ -160,6 +149,77 @@ describe('OpenAIChatCompletionController', () => {
|
|||
};
|
||||
});
|
||||
|
||||
describe('conversation ownership validation', () => {
|
||||
it('should skip ownership check when conversation_id is not provided', async () => {
|
||||
const { getConvo } = require('~/models');
|
||||
await OpenAIChatCompletionController(req, res);
|
||||
expect(getConvo).not.toHaveBeenCalled();
|
||||
});
|
||||
|
||||
it('should return 400 when conversation_id is not a string', async () => {
|
||||
const { validateRequest } = require('@librechat/api');
|
||||
validateRequest.mockReturnValueOnce({
|
||||
request: { model: 'agent-123', messages: [], stream: false, conversation_id: { $gt: '' } },
|
||||
});
|
||||
|
||||
await OpenAIChatCompletionController(req, res);
|
||||
expect(res.status).toHaveBeenCalledWith(400);
|
||||
});
|
||||
|
||||
it('should return 404 when conversation is not owned by user', async () => {
|
||||
const { validateRequest } = require('@librechat/api');
|
||||
const { getConvo } = require('~/models');
|
||||
validateRequest.mockReturnValueOnce({
|
||||
request: {
|
||||
model: 'agent-123',
|
||||
messages: [],
|
||||
stream: false,
|
||||
conversation_id: 'convo-abc',
|
||||
},
|
||||
});
|
||||
getConvo.mockResolvedValueOnce(null);
|
||||
|
||||
await OpenAIChatCompletionController(req, res);
|
||||
expect(getConvo).toHaveBeenCalledWith('user-123', 'convo-abc');
|
||||
expect(res.status).toHaveBeenCalledWith(404);
|
||||
});
|
||||
|
||||
it('should proceed when conversation is owned by user', async () => {
|
||||
const { validateRequest } = require('@librechat/api');
|
||||
const { getConvo } = require('~/models');
|
||||
validateRequest.mockReturnValueOnce({
|
||||
request: {
|
||||
model: 'agent-123',
|
||||
messages: [],
|
||||
stream: false,
|
||||
conversation_id: 'convo-abc',
|
||||
},
|
||||
});
|
||||
getConvo.mockResolvedValueOnce({ conversationId: 'convo-abc', user: 'user-123' });
|
||||
|
||||
await OpenAIChatCompletionController(req, res);
|
||||
expect(getConvo).toHaveBeenCalledWith('user-123', 'convo-abc');
|
||||
expect(res.status).not.toHaveBeenCalledWith(404);
|
||||
});
|
||||
|
||||
it('should return 500 when getConvo throws a DB error', async () => {
|
||||
const { validateRequest } = require('@librechat/api');
|
||||
const { getConvo } = require('~/models');
|
||||
validateRequest.mockReturnValueOnce({
|
||||
request: {
|
||||
model: 'agent-123',
|
||||
messages: [],
|
||||
stream: false,
|
||||
conversation_id: 'convo-abc',
|
||||
},
|
||||
});
|
||||
getConvo.mockRejectedValueOnce(new Error('DB connection failed'));
|
||||
|
||||
await OpenAIChatCompletionController(req, res);
|
||||
expect(res.status).toHaveBeenCalledWith(500);
|
||||
});
|
||||
});
|
||||
|
||||
describe('token usage recording', () => {
|
||||
it('should call recordCollectedUsage after successful non-streaming completion', async () => {
|
||||
await OpenAIChatCompletionController(req, res);
|
||||
|
|
|
|||
|
|
@ -101,46 +101,33 @@ jest.mock('~/server/services/ToolService', () => ({
|
|||
loadToolsForExecution: jest.fn().mockResolvedValue([]),
|
||||
}));
|
||||
|
||||
jest.mock('~/models/spendTokens', () => ({
|
||||
spendTokens: mockSpendTokens,
|
||||
spendStructuredTokens: mockSpendStructuredTokens,
|
||||
}));
|
||||
|
||||
const mockGetMultiplier = jest.fn().mockReturnValue(1);
|
||||
const mockGetCacheMultiplier = jest.fn().mockReturnValue(null);
|
||||
jest.mock('~/models/tx', () => ({
|
||||
getMultiplier: mockGetMultiplier,
|
||||
getCacheMultiplier: mockGetCacheMultiplier,
|
||||
}));
|
||||
|
||||
jest.mock('~/server/controllers/agents/callbacks', () => ({
|
||||
createToolEndCallback: jest.fn().mockReturnValue(jest.fn()),
|
||||
createResponsesToolEndCallback: jest.fn().mockReturnValue(jest.fn()),
|
||||
}));
|
||||
jest.mock('~/server/controllers/agents/callbacks', () => {
|
||||
const noop = { handle: jest.fn() };
|
||||
return {
|
||||
createToolEndCallback: jest.fn().mockReturnValue(jest.fn()),
|
||||
createResponsesToolEndCallback: jest.fn().mockReturnValue(jest.fn()),
|
||||
markSummarizationUsage: jest.fn().mockImplementation((usage) => usage),
|
||||
agentLogHandlerObj: noop,
|
||||
buildSummarizationHandlers: jest.fn().mockReturnValue({
|
||||
on_summarize_start: noop,
|
||||
on_summarize_delta: noop,
|
||||
on_summarize_complete: noop,
|
||||
}),
|
||||
};
|
||||
});
|
||||
|
||||
jest.mock('~/server/services/PermissionService', () => ({
|
||||
findAccessibleResources: jest.fn().mockResolvedValue([]),
|
||||
}));
|
||||
|
||||
jest.mock('~/models/Conversation', () => ({
|
||||
getConvoFiles: jest.fn().mockResolvedValue([]),
|
||||
saveConvo: jest.fn().mockResolvedValue({}),
|
||||
getConvo: jest.fn().mockResolvedValue(null),
|
||||
}));
|
||||
|
||||
jest.mock('~/models/Agent', () => ({
|
||||
getAgent: jest.fn().mockResolvedValue({
|
||||
id: 'agent-123',
|
||||
name: 'Test Agent',
|
||||
provider: 'anthropic',
|
||||
model_parameters: { model: 'claude-3' },
|
||||
}),
|
||||
getAgents: jest.fn().mockResolvedValue([]),
|
||||
}));
|
||||
|
||||
const mockUpdateBalance = jest.fn().mockResolvedValue({});
|
||||
const mockBulkInsertTransactions = jest.fn().mockResolvedValue(undefined);
|
||||
|
||||
jest.mock('~/models', () => ({
|
||||
getAgent: jest.fn().mockResolvedValue({ id: 'agent-123', name: 'Test Agent' }),
|
||||
getFiles: jest.fn(),
|
||||
getUserKey: jest.fn(),
|
||||
getMessages: jest.fn().mockResolvedValue([]),
|
||||
|
|
@ -152,6 +139,13 @@ jest.mock('~/models', () => ({
|
|||
getCodeGeneratedFiles: jest.fn(),
|
||||
updateBalance: mockUpdateBalance,
|
||||
bulkInsertTransactions: mockBulkInsertTransactions,
|
||||
spendTokens: mockSpendTokens,
|
||||
spendStructuredTokens: mockSpendStructuredTokens,
|
||||
getMultiplier: mockGetMultiplier,
|
||||
getCacheMultiplier: mockGetCacheMultiplier,
|
||||
getConvoFiles: jest.fn().mockResolvedValue([]),
|
||||
saveConvo: jest.fn().mockResolvedValue({}),
|
||||
getConvo: jest.fn().mockResolvedValue(null),
|
||||
}));
|
||||
|
||||
describe('createResponse controller', () => {
|
||||
|
|
@ -189,6 +183,102 @@ describe('createResponse controller', () => {
|
|||
};
|
||||
});
|
||||
|
||||
describe('conversation ownership validation', () => {
|
||||
it('should skip ownership check when previous_response_id is not provided', async () => {
|
||||
const { getConvo } = require('~/models');
|
||||
await createResponse(req, res);
|
||||
expect(getConvo).not.toHaveBeenCalled();
|
||||
});
|
||||
|
||||
it('should return 400 when previous_response_id is not a string', async () => {
|
||||
const { validateResponseRequest, sendResponsesErrorResponse } = require('@librechat/api');
|
||||
validateResponseRequest.mockReturnValueOnce({
|
||||
request: {
|
||||
model: 'agent-123',
|
||||
input: 'Hello',
|
||||
stream: false,
|
||||
previous_response_id: { $gt: '' },
|
||||
},
|
||||
});
|
||||
|
||||
await createResponse(req, res);
|
||||
expect(sendResponsesErrorResponse).toHaveBeenCalledWith(
|
||||
res,
|
||||
400,
|
||||
'previous_response_id must be a string',
|
||||
'invalid_request',
|
||||
);
|
||||
});
|
||||
|
||||
it('should return 404 when conversation is not owned by user', async () => {
|
||||
const { validateResponseRequest, sendResponsesErrorResponse } = require('@librechat/api');
|
||||
const { getConvo } = require('~/models');
|
||||
validateResponseRequest.mockReturnValueOnce({
|
||||
request: {
|
||||
model: 'agent-123',
|
||||
input: 'Hello',
|
||||
stream: false,
|
||||
previous_response_id: 'resp_abc',
|
||||
},
|
||||
});
|
||||
getConvo.mockResolvedValueOnce(null);
|
||||
|
||||
await createResponse(req, res);
|
||||
expect(getConvo).toHaveBeenCalledWith('user-123', 'resp_abc');
|
||||
expect(sendResponsesErrorResponse).toHaveBeenCalledWith(
|
||||
res,
|
||||
404,
|
||||
'Conversation not found',
|
||||
'not_found',
|
||||
);
|
||||
});
|
||||
|
||||
it('should proceed when conversation is owned by user', async () => {
|
||||
const { validateResponseRequest, sendResponsesErrorResponse } = require('@librechat/api');
|
||||
const { getConvo } = require('~/models');
|
||||
validateResponseRequest.mockReturnValueOnce({
|
||||
request: {
|
||||
model: 'agent-123',
|
||||
input: 'Hello',
|
||||
stream: false,
|
||||
previous_response_id: 'resp_abc',
|
||||
},
|
||||
});
|
||||
getConvo.mockResolvedValueOnce({ conversationId: 'resp_abc', user: 'user-123' });
|
||||
|
||||
await createResponse(req, res);
|
||||
expect(getConvo).toHaveBeenCalledWith('user-123', 'resp_abc');
|
||||
expect(sendResponsesErrorResponse).not.toHaveBeenCalledWith(
|
||||
res,
|
||||
404,
|
||||
expect.any(String),
|
||||
expect.any(String),
|
||||
);
|
||||
});
|
||||
|
||||
it('should return 500 when getConvo throws a DB error', async () => {
|
||||
const { validateResponseRequest, sendResponsesErrorResponse } = require('@librechat/api');
|
||||
const { getConvo } = require('~/models');
|
||||
validateResponseRequest.mockReturnValueOnce({
|
||||
request: {
|
||||
model: 'agent-123',
|
||||
input: 'Hello',
|
||||
stream: false,
|
||||
previous_response_id: 'resp_abc',
|
||||
},
|
||||
});
|
||||
getConvo.mockRejectedValueOnce(new Error('DB connection failed'));
|
||||
|
||||
await createResponse(req, res);
|
||||
expect(sendResponsesErrorResponse).toHaveBeenCalledWith(
|
||||
res,
|
||||
500,
|
||||
expect.any(String),
|
||||
expect.any(String),
|
||||
);
|
||||
});
|
||||
});
|
||||
|
||||
describe('token usage recording - non-streaming', () => {
|
||||
it('should call recordCollectedUsage after successful non-streaming completion', async () => {
|
||||
await createResponse(req, res);
|
||||
|
|
@ -290,28 +380,7 @@ describe('createResponse controller', () => {
|
|||
it('should collect usage from on_chat_model_end events', async () => {
|
||||
const api = require('@librechat/api');
|
||||
|
||||
let capturedOnChatModelEnd;
|
||||
api.createAggregatorEventHandlers.mockImplementation(() => {
|
||||
return {
|
||||
on_message_delta: { handle: jest.fn() },
|
||||
on_reasoning_delta: { handle: jest.fn() },
|
||||
on_run_step: { handle: jest.fn() },
|
||||
on_run_step_delta: { handle: jest.fn() },
|
||||
on_chat_model_end: {
|
||||
handle: jest.fn((event, data) => {
|
||||
if (capturedOnChatModelEnd) {
|
||||
capturedOnChatModelEnd(event, data);
|
||||
}
|
||||
}),
|
||||
},
|
||||
};
|
||||
});
|
||||
|
||||
api.createRun.mockImplementation(async ({ customHandlers }) => {
|
||||
capturedOnChatModelEnd = (event, data) => {
|
||||
customHandlers.on_chat_model_end.handle(event, data);
|
||||
};
|
||||
|
||||
return {
|
||||
processStream: jest.fn().mockImplementation(async () => {
|
||||
customHandlers.on_chat_model_end.handle('on_chat_model_end', {
|
||||
|
|
@ -328,7 +397,6 @@ describe('createResponse controller', () => {
|
|||
});
|
||||
|
||||
await createResponse(req, res);
|
||||
|
||||
expect(mockRecordCollectedUsage).toHaveBeenCalledWith(
|
||||
expect.any(Object),
|
||||
expect.objectContaining({
|
||||
|
|
|
|||
|
|
@ -0,0 +1,159 @@
|
|||
jest.mock('~/server/services/PermissionService', () => ({
|
||||
findPubliclyAccessibleResources: jest.fn(),
|
||||
findAccessibleResources: jest.fn(),
|
||||
hasPublicPermission: jest.fn(),
|
||||
grantPermission: jest.fn().mockResolvedValue({}),
|
||||
}));
|
||||
|
||||
jest.mock('~/server/services/Config', () => ({
|
||||
getCachedTools: jest.fn(),
|
||||
getMCPServerTools: jest.fn(),
|
||||
}));
|
||||
|
||||
const mongoose = require('mongoose');
|
||||
const { actionDelimiter } = require('librechat-data-provider');
|
||||
const { agentSchema, actionSchema } = require('@librechat/data-schemas');
|
||||
const { MongoMemoryServer } = require('mongodb-memory-server');
|
||||
const { duplicateAgent } = require('../v1');
|
||||
|
||||
let mongoServer;
|
||||
|
||||
beforeAll(async () => {
|
||||
mongoServer = await MongoMemoryServer.create();
|
||||
const mongoUri = mongoServer.getUri();
|
||||
if (!mongoose.models.Agent) {
|
||||
mongoose.model('Agent', agentSchema);
|
||||
}
|
||||
if (!mongoose.models.Action) {
|
||||
mongoose.model('Action', actionSchema);
|
||||
}
|
||||
await mongoose.connect(mongoUri);
|
||||
}, 20000);
|
||||
|
||||
afterAll(async () => {
|
||||
await mongoose.disconnect();
|
||||
await mongoServer.stop();
|
||||
});
|
||||
|
||||
beforeEach(async () => {
|
||||
await mongoose.models.Agent.deleteMany({});
|
||||
await mongoose.models.Action.deleteMany({});
|
||||
});
|
||||
|
||||
describe('duplicateAgentHandler — action domain extraction', () => {
|
||||
it('builds duplicated action entries using metadata.domain, not action_id', async () => {
|
||||
const userId = new mongoose.Types.ObjectId();
|
||||
const originalAgentId = `agent_original`;
|
||||
|
||||
const agent = await mongoose.models.Agent.create({
|
||||
id: originalAgentId,
|
||||
name: 'Test Agent',
|
||||
author: userId.toString(),
|
||||
provider: 'openai',
|
||||
model: 'gpt-4',
|
||||
tools: [],
|
||||
actions: [`api.example.com${actionDelimiter}act_original`],
|
||||
versions: [{ name: 'Test Agent', createdAt: new Date(), updatedAt: new Date() }],
|
||||
});
|
||||
|
||||
await mongoose.models.Action.create({
|
||||
user: userId,
|
||||
action_id: 'act_original',
|
||||
agent_id: originalAgentId,
|
||||
metadata: { domain: 'api.example.com' },
|
||||
});
|
||||
|
||||
const req = {
|
||||
params: { id: agent.id },
|
||||
user: { id: userId.toString() },
|
||||
};
|
||||
const res = {
|
||||
status: jest.fn().mockReturnThis(),
|
||||
json: jest.fn(),
|
||||
};
|
||||
|
||||
await duplicateAgent(req, res);
|
||||
|
||||
expect(res.status).toHaveBeenCalledWith(201);
|
||||
|
||||
const { agent: newAgent, actions: newActions } = res.json.mock.calls[0][0];
|
||||
|
||||
expect(newAgent.id).not.toBe(originalAgentId);
|
||||
expect(String(newAgent.author)).toBe(userId.toString());
|
||||
expect(newActions).toHaveLength(1);
|
||||
expect(newActions[0].metadata.domain).toBe('api.example.com');
|
||||
expect(newActions[0].agent_id).toBe(newAgent.id);
|
||||
|
||||
for (const actionEntry of newAgent.actions) {
|
||||
const [domain, actionId] = actionEntry.split(actionDelimiter);
|
||||
expect(domain).toBe('api.example.com');
|
||||
expect(actionId).toBeTruthy();
|
||||
expect(actionId).not.toBe('act_original');
|
||||
}
|
||||
|
||||
const allActions = await mongoose.models.Action.find({}).lean();
|
||||
expect(allActions).toHaveLength(2);
|
||||
|
||||
const originalAction = allActions.find((a) => a.action_id === 'act_original');
|
||||
expect(originalAction.agent_id).toBe(originalAgentId);
|
||||
|
||||
const duplicatedAction = allActions.find((a) => a.action_id !== 'act_original');
|
||||
expect(duplicatedAction.agent_id).toBe(newAgent.id);
|
||||
expect(duplicatedAction.metadata.domain).toBe('api.example.com');
|
||||
});
|
||||
|
||||
it('strips sensitive metadata fields from duplicated actions', async () => {
|
||||
const userId = new mongoose.Types.ObjectId();
|
||||
const originalAgentId = 'agent_sensitive';
|
||||
|
||||
await mongoose.models.Agent.create({
|
||||
id: originalAgentId,
|
||||
name: 'Sensitive Agent',
|
||||
author: userId.toString(),
|
||||
provider: 'openai',
|
||||
model: 'gpt-4',
|
||||
tools: [],
|
||||
actions: [`secure.api.com${actionDelimiter}act_secret`],
|
||||
versions: [{ name: 'Sensitive Agent', createdAt: new Date(), updatedAt: new Date() }],
|
||||
});
|
||||
|
||||
await mongoose.models.Action.create({
|
||||
user: userId,
|
||||
action_id: 'act_secret',
|
||||
agent_id: originalAgentId,
|
||||
metadata: {
|
||||
domain: 'secure.api.com',
|
||||
api_key: 'sk-secret-key-12345',
|
||||
oauth_client_id: 'client_id_xyz',
|
||||
oauth_client_secret: 'client_secret_xyz',
|
||||
},
|
||||
});
|
||||
|
||||
const req = {
|
||||
params: { id: originalAgentId },
|
||||
user: { id: userId.toString() },
|
||||
};
|
||||
const res = {
|
||||
status: jest.fn().mockReturnThis(),
|
||||
json: jest.fn(),
|
||||
};
|
||||
|
||||
await duplicateAgent(req, res);
|
||||
|
||||
expect(res.status).toHaveBeenCalledWith(201);
|
||||
|
||||
const duplicatedAction = await mongoose.models.Action.findOne({
|
||||
agent_id: { $ne: originalAgentId },
|
||||
}).lean();
|
||||
|
||||
expect(duplicatedAction.metadata.domain).toBe('secure.api.com');
|
||||
expect(duplicatedAction.metadata.api_key).toBeUndefined();
|
||||
expect(duplicatedAction.metadata.oauth_client_id).toBeUndefined();
|
||||
expect(duplicatedAction.metadata.oauth_client_secret).toBeUndefined();
|
||||
|
||||
const originalAction = await mongoose.models.Action.findOne({
|
||||
action_id: 'act_secret',
|
||||
}).lean();
|
||||
expect(originalAction.metadata.api_key).toBe('sk-secret-key-12345');
|
||||
});
|
||||
});
|
||||
|
|
@ -1,10 +1,8 @@
|
|||
const { duplicateAgent } = require('../v1');
|
||||
const { getAgent, createAgent } = require('~/models/Agent');
|
||||
const { getActions } = require('~/models/Action');
|
||||
const { getAgent, createAgent, getActions } = require('~/models');
|
||||
const { nanoid } = require('nanoid');
|
||||
|
||||
jest.mock('~/models/Agent');
|
||||
jest.mock('~/models/Action');
|
||||
jest.mock('~/models');
|
||||
jest.mock('nanoid');
|
||||
|
||||
describe('duplicateAgent', () => {
|
||||
|
|
|
|||
|
|
@ -1,7 +1,13 @@
|
|||
const { nanoid } = require('nanoid');
|
||||
const { logger } = require('@librechat/data-schemas');
|
||||
const { Constants, EnvVar, GraphEvents, ToolEndHandler } = require('@librechat/agents');
|
||||
const { Tools, StepTypes, FileContext, ErrorTypes } = require('librechat-data-provider');
|
||||
const {
|
||||
EnvVar,
|
||||
Constants,
|
||||
GraphEvents,
|
||||
GraphNodeKeys,
|
||||
ToolEndHandler,
|
||||
} = require('@librechat/agents');
|
||||
const {
|
||||
sendEvent,
|
||||
GenerationJobManager,
|
||||
|
|
@ -71,7 +77,9 @@ class ModelEndHandler {
|
|||
usage.model = modelName;
|
||||
}
|
||||
|
||||
this.collectedUsage.push(usage);
|
||||
const taggedUsage = markSummarizationUsage(usage, metadata);
|
||||
|
||||
this.collectedUsage.push(taggedUsage);
|
||||
} catch (error) {
|
||||
logger.error('Error handling model end event:', error);
|
||||
return this.finalize(errorMessage);
|
||||
|
|
@ -133,6 +141,7 @@ function getDefaultHandlers({
|
|||
collectedUsage,
|
||||
streamId = null,
|
||||
toolExecuteOptions = null,
|
||||
summarizationOptions = null,
|
||||
}) {
|
||||
if (!res || !aggregateContent) {
|
||||
throw new Error(
|
||||
|
|
@ -245,6 +254,37 @@ function getDefaultHandlers({
|
|||
handlers[GraphEvents.ON_TOOL_EXECUTE] = createToolExecuteHandler(toolExecuteOptions);
|
||||
}
|
||||
|
||||
if (summarizationOptions?.enabled !== false) {
|
||||
handlers[GraphEvents.ON_SUMMARIZE_START] = {
|
||||
handle: async (_event, data) => {
|
||||
await emitEvent(res, streamId, {
|
||||
event: GraphEvents.ON_SUMMARIZE_START,
|
||||
data,
|
||||
});
|
||||
},
|
||||
};
|
||||
handlers[GraphEvents.ON_SUMMARIZE_DELTA] = {
|
||||
handle: async (_event, data) => {
|
||||
aggregateContent({ event: GraphEvents.ON_SUMMARIZE_DELTA, data });
|
||||
await emitEvent(res, streamId, {
|
||||
event: GraphEvents.ON_SUMMARIZE_DELTA,
|
||||
data,
|
||||
});
|
||||
},
|
||||
};
|
||||
handlers[GraphEvents.ON_SUMMARIZE_COMPLETE] = {
|
||||
handle: async (_event, data) => {
|
||||
aggregateContent({ event: GraphEvents.ON_SUMMARIZE_COMPLETE, data });
|
||||
await emitEvent(res, streamId, {
|
||||
event: GraphEvents.ON_SUMMARIZE_COMPLETE,
|
||||
data,
|
||||
});
|
||||
},
|
||||
};
|
||||
}
|
||||
|
||||
handlers[GraphEvents.ON_AGENT_LOG] = { handle: agentLogHandler };
|
||||
|
||||
return handlers;
|
||||
}
|
||||
|
||||
|
|
@ -668,8 +708,62 @@ function createResponsesToolEndCallback({ req, res, tracker, artifactPromises })
|
|||
};
|
||||
}
|
||||
|
||||
const ALLOWED_LOG_LEVELS = new Set(['debug', 'info', 'warn', 'error']);
|
||||
|
||||
function agentLogHandler(_event, data) {
|
||||
if (!data) {
|
||||
return;
|
||||
}
|
||||
const logFn = ALLOWED_LOG_LEVELS.has(data.level) ? logger[data.level] : logger.debug;
|
||||
const meta = typeof data.data === 'object' && data.data != null ? data.data : {};
|
||||
logFn(`[agents:${data.scope ?? 'unknown'}] ${data.message ?? ''}`, {
|
||||
...meta,
|
||||
runId: data.runId,
|
||||
agentId: data.agentId,
|
||||
});
|
||||
}
|
||||
|
||||
function markSummarizationUsage(usage, metadata) {
|
||||
const node = metadata?.langgraph_node;
|
||||
if (typeof node === 'string' && node.startsWith(GraphNodeKeys.SUMMARIZE)) {
|
||||
return { ...usage, usage_type: 'summarization' };
|
||||
}
|
||||
return usage;
|
||||
}
|
||||
|
||||
const agentLogHandlerObj = { handle: agentLogHandler };
|
||||
|
||||
/**
|
||||
* Builds the three summarization SSE event handlers.
|
||||
* In streaming mode, each event is forwarded to the client via `res.write`.
|
||||
* In non-streaming mode, the handlers are no-ops.
|
||||
* @param {{ isStreaming: boolean, res: import('express').Response }} opts
|
||||
*/
|
||||
function buildSummarizationHandlers({ isStreaming, res }) {
|
||||
if (!isStreaming) {
|
||||
const noop = { handle: () => {} };
|
||||
return { on_summarize_start: noop, on_summarize_delta: noop, on_summarize_complete: noop };
|
||||
}
|
||||
const writeEvent = (name) => ({
|
||||
handle: async (_event, data) => {
|
||||
if (!res.writableEnded) {
|
||||
res.write(`event: ${name}\ndata: ${JSON.stringify(data)}\n\n`);
|
||||
}
|
||||
},
|
||||
});
|
||||
return {
|
||||
on_summarize_start: writeEvent('on_summarize_start'),
|
||||
on_summarize_delta: writeEvent('on_summarize_delta'),
|
||||
on_summarize_complete: writeEvent('on_summarize_complete'),
|
||||
};
|
||||
}
|
||||
|
||||
module.exports = {
|
||||
agentLogHandler,
|
||||
agentLogHandlerObj,
|
||||
getDefaultHandlers,
|
||||
createToolEndCallback,
|
||||
markSummarizationUsage,
|
||||
buildSummarizationHandlers,
|
||||
createResponsesToolEndCallback,
|
||||
};
|
||||
|
|
|
|||
|
|
@ -3,11 +3,11 @@ const { logger } = require('@librechat/data-schemas');
|
|||
const { getBufferString, HumanMessage } = require('@langchain/core/messages');
|
||||
const {
|
||||
createRun,
|
||||
Tokenizer,
|
||||
isEnabled,
|
||||
checkAccess,
|
||||
buildToolSet,
|
||||
sanitizeTitle,
|
||||
logToolError,
|
||||
sanitizeTitle,
|
||||
payloadParser,
|
||||
resolveHeaders,
|
||||
createSafeUser,
|
||||
|
|
@ -22,8 +22,11 @@ const {
|
|||
GenerationJobManager,
|
||||
getTransactionsConfig,
|
||||
createMemoryProcessor,
|
||||
loadAgent: loadAgentFn,
|
||||
createMultiAgentMapper,
|
||||
filterMalformedContentParts,
|
||||
countFormattedMessageTokens,
|
||||
hydrateMissingIndexTokenCounts,
|
||||
} = require('@librechat/api');
|
||||
const {
|
||||
Callback,
|
||||
|
|
@ -44,18 +47,16 @@ const {
|
|||
isEphemeralAgentId,
|
||||
removeNullishValues,
|
||||
} = require('librechat-data-provider');
|
||||
const { spendTokens, spendStructuredTokens } = require('~/models/spendTokens');
|
||||
const { filterFilesByAgentAccess } = require('~/server/services/Files/permissions');
|
||||
const { encodeAndFormat } = require('~/server/services/Files/images/encode');
|
||||
const { updateBalance, bulkInsertTransactions } = require('~/models');
|
||||
const { getMultiplier, getCacheMultiplier } = require('~/models/tx');
|
||||
const { createContextHandlers } = require('~/app/clients/prompts');
|
||||
const { getConvoFiles } = require('~/models/Conversation');
|
||||
const { getMCPServerTools } = require('~/server/services/Config');
|
||||
const BaseClient = require('~/app/clients/BaseClient');
|
||||
const { getRoleByName } = require('~/models/Role');
|
||||
const { loadAgent } = require('~/models/Agent');
|
||||
const { getMCPManager } = require('~/config');
|
||||
const db = require('~/models');
|
||||
|
||||
const loadAgent = (params) => loadAgentFn(params, { getAgent: db.getAgent, getMCPServerTools });
|
||||
|
||||
class AgentClient extends BaseClient {
|
||||
constructor(options = {}) {
|
||||
super(null, options);
|
||||
|
|
@ -63,9 +64,6 @@ class AgentClient extends BaseClient {
|
|||
* @type {string} */
|
||||
this.clientName = EModelEndpoint.agents;
|
||||
|
||||
/** @type {'discard' | 'summarize'} */
|
||||
this.contextStrategy = 'discard';
|
||||
|
||||
/** @deprecated @type {true} - Is a Chat Completion Request */
|
||||
this.isChatCompletion = true;
|
||||
|
||||
|
|
@ -217,7 +215,6 @@ class AgentClient extends BaseClient {
|
|||
}))
|
||||
: []),
|
||||
];
|
||||
|
||||
if (this.options.attachments) {
|
||||
const attachments = await this.options.attachments;
|
||||
const latestMessage = orderedMessages[orderedMessages.length - 1];
|
||||
|
|
@ -244,6 +241,11 @@ class AgentClient extends BaseClient {
|
|||
);
|
||||
}
|
||||
|
||||
/** @type {Record<number, number>} */
|
||||
const canonicalTokenCountMap = {};
|
||||
/** @type {Record<string, number>} */
|
||||
const tokenCountMap = {};
|
||||
let promptTokenTotal = 0;
|
||||
const formattedMessages = orderedMessages.map((message, i) => {
|
||||
const formattedMessage = formatMessage({
|
||||
message,
|
||||
|
|
@ -263,12 +265,14 @@ class AgentClient extends BaseClient {
|
|||
}
|
||||
}
|
||||
|
||||
const needsTokenCount =
|
||||
(this.contextStrategy && !orderedMessages[i].tokenCount) || message.fileContext;
|
||||
const dbTokenCount = orderedMessages[i].tokenCount;
|
||||
const needsTokenCount = !dbTokenCount || message.fileContext;
|
||||
|
||||
/* If tokens were never counted, or, is a Vision request and the message has files, count again */
|
||||
if (needsTokenCount || (this.isVisionModel && (message.image_urls || message.files))) {
|
||||
orderedMessages[i].tokenCount = this.getTokenCountForMessage(formattedMessage);
|
||||
orderedMessages[i].tokenCount = countFormattedMessageTokens(
|
||||
formattedMessage,
|
||||
this.getEncoding(),
|
||||
);
|
||||
}
|
||||
|
||||
/* If message has files, calculate image token cost */
|
||||
|
|
@ -282,17 +286,37 @@ class AgentClient extends BaseClient {
|
|||
if (file.metadata?.fileIdentifier) {
|
||||
continue;
|
||||
}
|
||||
// orderedMessages[i].tokenCount += this.calculateImageTokenCost({
|
||||
// width: file.width,
|
||||
// height: file.height,
|
||||
// detail: this.options.imageDetail ?? ImageDetail.auto,
|
||||
// });
|
||||
}
|
||||
}
|
||||
|
||||
const tokenCount = Number(orderedMessages[i].tokenCount);
|
||||
const normalizedTokenCount = Number.isFinite(tokenCount) && tokenCount > 0 ? tokenCount : 0;
|
||||
canonicalTokenCountMap[i] = normalizedTokenCount;
|
||||
promptTokenTotal += normalizedTokenCount;
|
||||
|
||||
if (message.messageId) {
|
||||
tokenCountMap[message.messageId] = normalizedTokenCount;
|
||||
}
|
||||
|
||||
if (isEnabled(process.env.AGENT_DEBUG_LOGGING)) {
|
||||
const role = message.isCreatedByUser ? 'user' : 'assistant';
|
||||
const hasSummary =
|
||||
Array.isArray(message.content) && message.content.some((p) => p && p.type === 'summary');
|
||||
const suffix = hasSummary ? '[S]' : '';
|
||||
const id = (message.messageId ?? message.id ?? '').slice(-8);
|
||||
const recalced = needsTokenCount ? orderedMessages[i].tokenCount : null;
|
||||
logger.debug(
|
||||
`[AgentClient] msg[${i}] ${role}${suffix} id=…${id} db=${dbTokenCount} needsRecount=${needsTokenCount} recalced=${recalced} tokens=${normalizedTokenCount}`,
|
||||
);
|
||||
}
|
||||
|
||||
return formattedMessage;
|
||||
});
|
||||
|
||||
payload = formattedMessages;
|
||||
messages = orderedMessages;
|
||||
promptTokens = promptTokenTotal;
|
||||
|
||||
/**
|
||||
* Build shared run context - applies to ALL agents in the run.
|
||||
* This includes: file context (latest message), augmented prompt (RAG), memory context.
|
||||
|
|
@ -322,23 +346,20 @@ class AgentClient extends BaseClient {
|
|||
|
||||
const sharedRunContext = sharedRunContextParts.join('\n\n');
|
||||
|
||||
/** @type {Record<string, number> | undefined} */
|
||||
let tokenCountMap;
|
||||
/** Preserve canonical pre-format token counts for all history entering graph formatting */
|
||||
this.indexTokenCountMap = canonicalTokenCountMap;
|
||||
|
||||
if (this.contextStrategy) {
|
||||
({ payload, promptTokens, tokenCountMap, messages } = await this.handleContextStrategy({
|
||||
orderedMessages,
|
||||
formattedMessages,
|
||||
}));
|
||||
}
|
||||
|
||||
for (let i = 0; i < messages.length; i++) {
|
||||
this.indexTokenCountMap[i] = messages[i].tokenCount;
|
||||
/** Extract contextMeta from the parent response (second-to-last in ordered chain;
|
||||
* last is the current user message). Seeds the pruner's calibration EMA for this run. */
|
||||
const parentResponse =
|
||||
orderedMessages.length >= 2 ? orderedMessages[orderedMessages.length - 2] : undefined;
|
||||
if (parentResponse?.contextMeta && !parentResponse.isCreatedByUser) {
|
||||
this.contextMeta = parentResponse.contextMeta;
|
||||
}
|
||||
|
||||
const result = {
|
||||
tokenCountMap,
|
||||
prompt: payload,
|
||||
tokenCountMap,
|
||||
promptTokens,
|
||||
messages,
|
||||
};
|
||||
|
|
@ -412,7 +433,7 @@ class AgentClient extends BaseClient {
|
|||
user,
|
||||
permissionType: PermissionTypes.MEMORIES,
|
||||
permissions: [Permissions.USE],
|
||||
getRoleByName,
|
||||
getRoleByName: db.getRoleByName,
|
||||
});
|
||||
|
||||
if (!hasAccess) {
|
||||
|
|
@ -472,13 +493,14 @@ class AgentClient extends BaseClient {
|
|||
},
|
||||
},
|
||||
{
|
||||
getConvoFiles,
|
||||
getFiles: db.getFiles,
|
||||
getUserKey: db.getUserKey,
|
||||
getConvoFiles: db.getConvoFiles,
|
||||
updateFilesUsage: db.updateFilesUsage,
|
||||
getUserKeyValues: db.getUserKeyValues,
|
||||
getToolFilesByIds: db.getToolFilesByIds,
|
||||
getCodeGeneratedFiles: db.getCodeGeneratedFiles,
|
||||
filterFilesByAgentAccess,
|
||||
},
|
||||
);
|
||||
|
||||
|
|
@ -629,10 +651,10 @@ class AgentClient extends BaseClient {
|
|||
}) {
|
||||
const result = await recordCollectedUsage(
|
||||
{
|
||||
spendTokens,
|
||||
spendStructuredTokens,
|
||||
pricing: { getMultiplier, getCacheMultiplier },
|
||||
bulkWriteOps: { insertMany: bulkInsertTransactions, updateBalance },
|
||||
spendTokens: db.spendTokens,
|
||||
spendStructuredTokens: db.spendStructuredTokens,
|
||||
pricing: { getMultiplier: db.getMultiplier, getCacheMultiplier: db.getCacheMultiplier },
|
||||
bulkWriteOps: { insertMany: db.bulkInsertTransactions, updateBalance: db.updateBalance },
|
||||
},
|
||||
{
|
||||
user: this.user ?? this.options.req.user?.id,
|
||||
|
|
@ -665,39 +687,7 @@ class AgentClient extends BaseClient {
|
|||
* @returns {number}
|
||||
*/
|
||||
getTokenCountForResponse({ content }) {
|
||||
return this.getTokenCountForMessage({
|
||||
role: 'assistant',
|
||||
content,
|
||||
});
|
||||
}
|
||||
|
||||
/**
|
||||
* Calculates the correct token count for the current user message based on the token count map and API usage.
|
||||
* Edge case: If the calculation results in a negative value, it returns the original estimate.
|
||||
* If revisiting a conversation with a chat history entirely composed of token estimates,
|
||||
* the cumulative token count going forward should become more accurate as the conversation progresses.
|
||||
* @param {Object} params - The parameters for the calculation.
|
||||
* @param {Record<string, number>} params.tokenCountMap - A map of message IDs to their token counts.
|
||||
* @param {string} params.currentMessageId - The ID of the current message to calculate.
|
||||
* @param {OpenAIUsageMetadata} params.usage - The usage object returned by the API.
|
||||
* @returns {number} The correct token count for the current user message.
|
||||
*/
|
||||
calculateCurrentTokenCount({ tokenCountMap, currentMessageId, usage }) {
|
||||
const originalEstimate = tokenCountMap[currentMessageId] || 0;
|
||||
|
||||
if (!usage || typeof usage[this.inputTokensKey] !== 'number') {
|
||||
return originalEstimate;
|
||||
}
|
||||
|
||||
tokenCountMap[currentMessageId] = 0;
|
||||
const totalTokensFromMap = Object.values(tokenCountMap).reduce((sum, count) => {
|
||||
const numCount = Number(count);
|
||||
return sum + (isNaN(numCount) ? 0 : numCount);
|
||||
}, 0);
|
||||
const totalInputTokens = usage[this.inputTokensKey] ?? 0;
|
||||
|
||||
const currentMessageTokens = totalInputTokens - totalTokensFromMap;
|
||||
return currentMessageTokens > 0 ? currentMessageTokens : originalEstimate;
|
||||
return countFormattedMessageTokens({ role: 'assistant', content }, this.getEncoding());
|
||||
}
|
||||
|
||||
/**
|
||||
|
|
@ -745,11 +735,34 @@ class AgentClient extends BaseClient {
|
|||
};
|
||||
|
||||
const toolSet = buildToolSet(this.options.agent);
|
||||
let { messages: initialMessages, indexTokenCountMap } = formatAgentMessages(
|
||||
payload,
|
||||
this.indexTokenCountMap,
|
||||
toolSet,
|
||||
);
|
||||
const tokenCounter = createTokenCounter(this.getEncoding());
|
||||
let {
|
||||
messages: initialMessages,
|
||||
indexTokenCountMap,
|
||||
summary: initialSummary,
|
||||
boundaryTokenAdjustment,
|
||||
} = formatAgentMessages(payload, this.indexTokenCountMap, toolSet);
|
||||
if (boundaryTokenAdjustment) {
|
||||
logger.debug(
|
||||
`[AgentClient] Boundary token adjustment: ${boundaryTokenAdjustment.original} → ${boundaryTokenAdjustment.adjusted} (${boundaryTokenAdjustment.remainingChars}/${boundaryTokenAdjustment.totalChars} chars)`,
|
||||
);
|
||||
}
|
||||
if (indexTokenCountMap && isEnabled(process.env.AGENT_DEBUG_LOGGING)) {
|
||||
const entries = Object.entries(indexTokenCountMap);
|
||||
const perMsg = entries.map(([idx, count]) => {
|
||||
const msg = initialMessages[Number(idx)];
|
||||
const type = msg ? msg._getType() : '?';
|
||||
return `${idx}:${type}=${count}`;
|
||||
});
|
||||
logger.debug(
|
||||
`[AgentClient] Token map after format: [${perMsg.join(', ')}] (payload=${payload.length}, formatted=${initialMessages.length})`,
|
||||
);
|
||||
}
|
||||
indexTokenCountMap = hydrateMissingIndexTokenCounts({
|
||||
messages: initialMessages,
|
||||
indexTokenCountMap,
|
||||
tokenCounter,
|
||||
});
|
||||
|
||||
/**
|
||||
* @param {BaseMessage[]} messages
|
||||
|
|
@ -803,16 +816,32 @@ class AgentClient extends BaseClient {
|
|||
|
||||
memoryPromise = this.runMemory(messages);
|
||||
|
||||
/** Seed calibration state from previous run if encoding matches */
|
||||
const currentEncoding = this.getEncoding();
|
||||
const prevMeta = this.contextMeta;
|
||||
const encodingMatch = prevMeta?.encoding === currentEncoding;
|
||||
const calibrationRatio =
|
||||
encodingMatch && prevMeta?.calibrationRatio > 0 ? prevMeta.calibrationRatio : undefined;
|
||||
|
||||
if (prevMeta) {
|
||||
logger.debug(
|
||||
`[AgentClient] contextMeta from parent: ratio=${prevMeta.calibrationRatio}, encoding=${prevMeta.encoding}, current=${currentEncoding}, seeded=${calibrationRatio ?? 'none'}`,
|
||||
);
|
||||
}
|
||||
|
||||
run = await createRun({
|
||||
agents,
|
||||
messages,
|
||||
indexTokenCountMap,
|
||||
initialSummary,
|
||||
calibrationRatio,
|
||||
runId: this.responseMessageId,
|
||||
signal: abortController.signal,
|
||||
customHandlers: this.options.eventHandlers,
|
||||
requestBody: config.configurable.requestBody,
|
||||
user: createSafeUser(this.options.req?.user),
|
||||
tokenCounter: createTokenCounter(this.getEncoding()),
|
||||
summarizationConfig: appConfig?.summarization,
|
||||
tokenCounter,
|
||||
});
|
||||
|
||||
if (!run) {
|
||||
|
|
@ -843,6 +872,7 @@ class AgentClient extends BaseClient {
|
|||
|
||||
const hideSequentialOutputs = config.configurable.hide_sequential_outputs;
|
||||
await runAgents(initialMessages);
|
||||
|
||||
/** @deprecated Agent Chain */
|
||||
if (hideSequentialOutputs) {
|
||||
this.contentParts = this.contentParts.filter((part, index) => {
|
||||
|
|
@ -873,6 +903,18 @@ class AgentClient extends BaseClient {
|
|||
});
|
||||
}
|
||||
} finally {
|
||||
/** Capture calibration state from the run for persistence on the response message.
|
||||
* Runs in finally so values are captured even on abort. */
|
||||
const ratio = this.run?.getCalibrationRatio() ?? 0;
|
||||
if (ratio > 0 && ratio !== 1) {
|
||||
this.contextMeta = {
|
||||
calibrationRatio: Math.round(ratio * 1000) / 1000,
|
||||
encoding: this.getEncoding(),
|
||||
};
|
||||
} else {
|
||||
this.contextMeta = undefined;
|
||||
}
|
||||
|
||||
try {
|
||||
const attachments = await this.awaitMemoryWithTimeout(memoryPromise);
|
||||
if (attachments && attachments.length > 0) {
|
||||
|
|
@ -1058,6 +1100,7 @@ class AgentClient extends BaseClient {
|
|||
titlePrompt: endpointConfig?.titlePrompt,
|
||||
titlePromptTemplate: endpointConfig?.titlePromptTemplate,
|
||||
chainOptions: {
|
||||
runName: 'TitleRun',
|
||||
signal: abortController.signal,
|
||||
callbacks: [
|
||||
{
|
||||
|
|
@ -1132,7 +1175,7 @@ class AgentClient extends BaseClient {
|
|||
context = 'message',
|
||||
}) {
|
||||
try {
|
||||
await spendTokens(
|
||||
await db.spendTokens(
|
||||
{
|
||||
model,
|
||||
context,
|
||||
|
|
@ -1151,7 +1194,7 @@ class AgentClient extends BaseClient {
|
|||
'reasoning_tokens' in usage &&
|
||||
typeof usage.reasoning_tokens === 'number'
|
||||
) {
|
||||
await spendTokens(
|
||||
await db.spendTokens(
|
||||
{
|
||||
model,
|
||||
balance,
|
||||
|
|
@ -1172,19 +1215,13 @@ class AgentClient extends BaseClient {
|
|||
}
|
||||
}
|
||||
|
||||
/** Anthropic Claude models use a distinct BPE tokenizer; all others default to o200k_base. */
|
||||
getEncoding() {
|
||||
if (this.model && this.model.toLowerCase().includes('claude')) {
|
||||
return 'claude';
|
||||
}
|
||||
return 'o200k_base';
|
||||
}
|
||||
|
||||
/**
|
||||
* Returns the token count of a given text. It also checks and resets the tokenizers if necessary.
|
||||
* @param {string} text - The text to get the token count for.
|
||||
* @returns {number} The token count of the given text.
|
||||
*/
|
||||
getTokenCount(text) {
|
||||
const encoding = this.getEncoding();
|
||||
return Tokenizer.getTokenCount(text, encoding);
|
||||
}
|
||||
}
|
||||
|
||||
module.exports = AgentClient;
|
||||
|
|
|
|||
|
|
@ -15,13 +15,15 @@ jest.mock('@librechat/api', () => ({
|
|||
checkAccess: jest.fn(),
|
||||
initializeAgent: jest.fn(),
|
||||
createMemoryProcessor: jest.fn(),
|
||||
}));
|
||||
|
||||
jest.mock('~/models/Agent', () => ({
|
||||
loadAgent: jest.fn(),
|
||||
}));
|
||||
|
||||
jest.mock('~/models/Role', () => ({
|
||||
jest.mock('~/server/services/Config', () => ({
|
||||
getMCPServerTools: jest.fn(),
|
||||
}));
|
||||
|
||||
jest.mock('~/models', () => ({
|
||||
getAgent: jest.fn(),
|
||||
getRoleByName: jest.fn(),
|
||||
}));
|
||||
|
||||
|
|
@ -1816,7 +1818,7 @@ describe('AgentClient - titleConvo', () => {
|
|||
|
||||
/** Traversal stops at msg-2 (has summary), so we get msg-4 -> msg-3 -> msg-2 */
|
||||
expect(result).toHaveLength(3);
|
||||
expect(result[0].text).toBe('Summary of conversation');
|
||||
expect(result[0].content).toEqual([{ type: 'text', text: 'Summary of conversation' }]);
|
||||
expect(result[0].role).toBe('system');
|
||||
expect(result[0].mapped).toBe(true);
|
||||
expect(result[1].mapped).toBe(true);
|
||||
|
|
@ -2138,7 +2140,7 @@ describe('AgentClient - titleConvo', () => {
|
|||
};
|
||||
|
||||
mockCheckAccess = require('@librechat/api').checkAccess;
|
||||
mockLoadAgent = require('~/models/Agent').loadAgent;
|
||||
mockLoadAgent = require('@librechat/api').loadAgent;
|
||||
mockInitializeAgent = require('@librechat/api').initializeAgent;
|
||||
mockCreateMemoryProcessor = require('@librechat/api').createMemoryProcessor;
|
||||
});
|
||||
|
|
@ -2195,6 +2197,7 @@ describe('AgentClient - titleConvo', () => {
|
|||
expect.objectContaining({
|
||||
agent_id: differentAgentId,
|
||||
}),
|
||||
expect.any(Object),
|
||||
);
|
||||
expect(mockInitializeAgent).toHaveBeenCalledWith(
|
||||
expect.objectContaining({
|
||||
|
|
|
|||
|
|
@ -3,8 +3,8 @@ const { logger } = require('@librechat/data-schemas');
|
|||
const { CacheKeys, ViolationTypes } = require('librechat-data-provider');
|
||||
const { sendResponse } = require('~/server/middleware/error');
|
||||
const { recordUsage } = require('~/server/services/Threads');
|
||||
const { getConvo } = require('~/models/Conversation');
|
||||
const getLogStores = require('~/cache/getLogStores');
|
||||
const { getConvo } = require('~/models');
|
||||
|
||||
/**
|
||||
* @typedef {Object} ErrorHandlerContext
|
||||
|
|
|
|||
668
api/server/controllers/agents/filterAuthorizedTools.spec.js
Normal file
668
api/server/controllers/agents/filterAuthorizedTools.spec.js
Normal file
|
|
@ -0,0 +1,668 @@
|
|||
const mongoose = require('mongoose');
|
||||
const { v4: uuidv4 } = require('uuid');
|
||||
const { Constants } = require('librechat-data-provider');
|
||||
const { agentSchema } = require('@librechat/data-schemas');
|
||||
const { MongoMemoryServer } = require('mongodb-memory-server');
|
||||
|
||||
const d = Constants.mcp_delimiter;
|
||||
|
||||
const mockGetAllServerConfigs = jest.fn();
|
||||
|
||||
jest.mock('~/server/services/Config', () => ({
|
||||
getCachedTools: jest.fn().mockResolvedValue({
|
||||
web_search: true,
|
||||
execute_code: true,
|
||||
file_search: true,
|
||||
}),
|
||||
}));
|
||||
|
||||
jest.mock('~/config', () => ({
|
||||
getMCPServersRegistry: jest.fn(() => ({
|
||||
getAllServerConfigs: mockGetAllServerConfigs,
|
||||
})),
|
||||
}));
|
||||
|
||||
jest.mock('~/server/services/Files/strategies', () => ({
|
||||
getStrategyFunctions: jest.fn(),
|
||||
}));
|
||||
|
||||
jest.mock('~/server/services/Files/images/avatar', () => ({
|
||||
resizeAvatar: jest.fn(),
|
||||
}));
|
||||
|
||||
jest.mock('~/server/services/Files/process', () => ({
|
||||
filterFile: jest.fn(),
|
||||
}));
|
||||
|
||||
jest.mock('~/server/services/PermissionService', () => ({
|
||||
findAccessibleResources: jest.fn().mockResolvedValue([]),
|
||||
findPubliclyAccessibleResources: jest.fn().mockResolvedValue([]),
|
||||
grantPermission: jest.fn(),
|
||||
hasPublicPermission: jest.fn().mockResolvedValue(false),
|
||||
checkPermission: jest.fn().mockResolvedValue(true),
|
||||
}));
|
||||
|
||||
jest.mock('~/models', () => {
|
||||
const mongoose = require('mongoose');
|
||||
const { createModels, createMethods } = require('@librechat/data-schemas');
|
||||
createModels(mongoose);
|
||||
const methods = createMethods(mongoose);
|
||||
return {
|
||||
...methods,
|
||||
getCategoriesWithCounts: jest.fn(),
|
||||
deleteFileByFilter: jest.fn(),
|
||||
};
|
||||
});
|
||||
|
||||
jest.mock('~/cache', () => ({
|
||||
getLogStores: jest.fn(() => ({
|
||||
get: jest.fn(),
|
||||
set: jest.fn(),
|
||||
delete: jest.fn(),
|
||||
})),
|
||||
}));
|
||||
|
||||
const {
|
||||
filterAuthorizedTools,
|
||||
createAgent: createAgentHandler,
|
||||
updateAgent: updateAgentHandler,
|
||||
duplicateAgent: duplicateAgentHandler,
|
||||
revertAgentVersion: revertAgentVersionHandler,
|
||||
} = require('./v1');
|
||||
|
||||
const { getMCPServersRegistry } = require('~/config');
|
||||
|
||||
let Agent;
|
||||
|
||||
describe('MCP Tool Authorization', () => {
|
||||
let mongoServer;
|
||||
let mockReq;
|
||||
let mockRes;
|
||||
|
||||
beforeAll(async () => {
|
||||
mongoServer = await MongoMemoryServer.create();
|
||||
const mongoUri = mongoServer.getUri();
|
||||
await mongoose.connect(mongoUri);
|
||||
Agent = mongoose.models.Agent || mongoose.model('Agent', agentSchema);
|
||||
}, 20000);
|
||||
|
||||
afterAll(async () => {
|
||||
await mongoose.disconnect();
|
||||
await mongoServer.stop();
|
||||
});
|
||||
|
||||
beforeEach(async () => {
|
||||
await Agent.deleteMany({});
|
||||
jest.clearAllMocks();
|
||||
|
||||
getMCPServersRegistry.mockImplementation(() => ({
|
||||
getAllServerConfigs: mockGetAllServerConfigs,
|
||||
}));
|
||||
mockGetAllServerConfigs.mockResolvedValue({
|
||||
authorizedServer: { type: 'sse', url: 'https://authorized.example.com' },
|
||||
anotherServer: { type: 'sse', url: 'https://another.example.com' },
|
||||
});
|
||||
|
||||
mockReq = {
|
||||
user: {
|
||||
id: new mongoose.Types.ObjectId().toString(),
|
||||
role: 'USER',
|
||||
},
|
||||
body: {},
|
||||
params: {},
|
||||
query: {},
|
||||
app: { locals: { fileStrategy: 'local' } },
|
||||
};
|
||||
|
||||
mockRes = {
|
||||
status: jest.fn().mockReturnThis(),
|
||||
json: jest.fn().mockReturnThis(),
|
||||
};
|
||||
});
|
||||
|
||||
describe('filterAuthorizedTools', () => {
|
||||
const availableTools = { web_search: true, custom_tool: true };
|
||||
const userId = 'test-user-123';
|
||||
|
||||
test('should keep authorized MCP tools and strip unauthorized ones', async () => {
|
||||
const result = await filterAuthorizedTools({
|
||||
tools: [`toolA${d}authorizedServer`, `toolB${d}forbiddenServer`, 'web_search'],
|
||||
userId,
|
||||
availableTools,
|
||||
});
|
||||
|
||||
expect(result).toContain(`toolA${d}authorizedServer`);
|
||||
expect(result).toContain('web_search');
|
||||
expect(result).not.toContain(`toolB${d}forbiddenServer`);
|
||||
});
|
||||
|
||||
test('should keep system tools without querying MCP registry', async () => {
|
||||
const result = await filterAuthorizedTools({
|
||||
tools: ['execute_code', 'file_search', 'web_search'],
|
||||
userId,
|
||||
availableTools: {},
|
||||
});
|
||||
|
||||
expect(result).toEqual(['execute_code', 'file_search', 'web_search']);
|
||||
expect(mockGetAllServerConfigs).not.toHaveBeenCalled();
|
||||
});
|
||||
|
||||
test('should not query MCP registry when no MCP tools are present', async () => {
|
||||
const result = await filterAuthorizedTools({
|
||||
tools: ['web_search', 'custom_tool'],
|
||||
userId,
|
||||
availableTools,
|
||||
});
|
||||
|
||||
expect(result).toEqual(['web_search', 'custom_tool']);
|
||||
expect(mockGetAllServerConfigs).not.toHaveBeenCalled();
|
||||
});
|
||||
|
||||
test('should filter all MCP tools when registry is uninitialized', async () => {
|
||||
getMCPServersRegistry.mockImplementation(() => {
|
||||
throw new Error('MCPServersRegistry has not been initialized.');
|
||||
});
|
||||
|
||||
const result = await filterAuthorizedTools({
|
||||
tools: [`toolA${d}someServer`, 'web_search'],
|
||||
userId,
|
||||
availableTools,
|
||||
});
|
||||
|
||||
expect(result).toEqual(['web_search']);
|
||||
expect(result).not.toContain(`toolA${d}someServer`);
|
||||
});
|
||||
|
||||
test('should handle mixed authorized and unauthorized MCP tools', async () => {
|
||||
const result = await filterAuthorizedTools({
|
||||
tools: [
|
||||
'web_search',
|
||||
`search${d}authorizedServer`,
|
||||
`attack${d}victimServer`,
|
||||
'execute_code',
|
||||
`list${d}anotherServer`,
|
||||
`steal${d}nonexistent`,
|
||||
],
|
||||
userId,
|
||||
availableTools,
|
||||
});
|
||||
|
||||
expect(result).toEqual([
|
||||
'web_search',
|
||||
`search${d}authorizedServer`,
|
||||
'execute_code',
|
||||
`list${d}anotherServer`,
|
||||
]);
|
||||
});
|
||||
|
||||
test('should handle empty tools array', async () => {
|
||||
const result = await filterAuthorizedTools({
|
||||
tools: [],
|
||||
userId,
|
||||
availableTools,
|
||||
});
|
||||
|
||||
expect(result).toEqual([]);
|
||||
expect(mockGetAllServerConfigs).not.toHaveBeenCalled();
|
||||
});
|
||||
|
||||
test('should handle null/undefined tool entries gracefully', async () => {
|
||||
const result = await filterAuthorizedTools({
|
||||
tools: [null, undefined, '', 'web_search'],
|
||||
userId,
|
||||
availableTools,
|
||||
});
|
||||
|
||||
expect(result).toEqual(['web_search']);
|
||||
});
|
||||
|
||||
test('should call getAllServerConfigs with the correct userId', async () => {
|
||||
await filterAuthorizedTools({
|
||||
tools: [`tool${d}authorizedServer`],
|
||||
userId: 'specific-user-id',
|
||||
availableTools,
|
||||
});
|
||||
|
||||
expect(mockGetAllServerConfigs).toHaveBeenCalledWith('specific-user-id');
|
||||
});
|
||||
|
||||
test('should only call getAllServerConfigs once even with multiple MCP tools', async () => {
|
||||
await filterAuthorizedTools({
|
||||
tools: [`tool1${d}authorizedServer`, `tool2${d}anotherServer`, `tool3${d}unknownServer`],
|
||||
userId,
|
||||
availableTools,
|
||||
});
|
||||
|
||||
expect(mockGetAllServerConfigs).toHaveBeenCalledTimes(1);
|
||||
});
|
||||
|
||||
test('should preserve existing MCP tools when registry is unavailable', async () => {
|
||||
getMCPServersRegistry.mockImplementation(() => {
|
||||
throw new Error('MCPServersRegistry has not been initialized.');
|
||||
});
|
||||
|
||||
const existingTools = [`toolA${d}serverA`, `toolB${d}serverB`];
|
||||
|
||||
const result = await filterAuthorizedTools({
|
||||
tools: [...existingTools, `newTool${d}unknownServer`, 'web_search'],
|
||||
userId,
|
||||
availableTools,
|
||||
existingTools,
|
||||
});
|
||||
|
||||
expect(result).toContain(`toolA${d}serverA`);
|
||||
expect(result).toContain(`toolB${d}serverB`);
|
||||
expect(result).toContain('web_search');
|
||||
expect(result).not.toContain(`newTool${d}unknownServer`);
|
||||
});
|
||||
|
||||
test('should still reject all MCP tools when registry is unavailable and no existingTools', async () => {
|
||||
getMCPServersRegistry.mockImplementation(() => {
|
||||
throw new Error('MCPServersRegistry has not been initialized.');
|
||||
});
|
||||
|
||||
const result = await filterAuthorizedTools({
|
||||
tools: [`toolA${d}serverA`, 'web_search'],
|
||||
userId,
|
||||
availableTools,
|
||||
});
|
||||
|
||||
expect(result).toEqual(['web_search']);
|
||||
});
|
||||
|
||||
test('should not preserve malformed existing tools when registry is unavailable', async () => {
|
||||
getMCPServersRegistry.mockImplementation(() => {
|
||||
throw new Error('MCPServersRegistry has not been initialized.');
|
||||
});
|
||||
|
||||
const malformedTool = `a${d}b${d}c`;
|
||||
const result = await filterAuthorizedTools({
|
||||
tools: [malformedTool, `legit${d}serverA`, 'web_search'],
|
||||
userId,
|
||||
availableTools,
|
||||
existingTools: [malformedTool, `legit${d}serverA`],
|
||||
});
|
||||
|
||||
expect(result).toContain(`legit${d}serverA`);
|
||||
expect(result).toContain('web_search');
|
||||
expect(result).not.toContain(malformedTool);
|
||||
});
|
||||
|
||||
test('should reject malformed MCP tool keys with multiple delimiters', async () => {
|
||||
const result = await filterAuthorizedTools({
|
||||
tools: [
|
||||
`attack${d}victimServer${d}authorizedServer`,
|
||||
`legit${d}authorizedServer`,
|
||||
`a${d}b${d}c${d}d`,
|
||||
'web_search',
|
||||
],
|
||||
userId,
|
||||
availableTools,
|
||||
});
|
||||
|
||||
expect(result).toEqual([`legit${d}authorizedServer`, 'web_search']);
|
||||
expect(result).not.toContainEqual(expect.stringContaining('victimServer'));
|
||||
expect(result).not.toContainEqual(expect.stringContaining(`a${d}b`));
|
||||
});
|
||||
});
|
||||
|
||||
describe('createAgentHandler - MCP tool authorization', () => {
|
||||
test('should strip unauthorized MCP tools on create', async () => {
|
||||
mockReq.body = {
|
||||
provider: 'openai',
|
||||
model: 'gpt-4',
|
||||
name: 'MCP Test Agent',
|
||||
tools: ['web_search', `validTool${d}authorizedServer`, `attack${d}forbiddenServer`],
|
||||
};
|
||||
|
||||
await createAgentHandler(mockReq, mockRes);
|
||||
|
||||
expect(mockRes.status).toHaveBeenCalledWith(201);
|
||||
const agent = mockRes.json.mock.calls[0][0];
|
||||
expect(agent.tools).toContain('web_search');
|
||||
expect(agent.tools).toContain(`validTool${d}authorizedServer`);
|
||||
expect(agent.tools).not.toContain(`attack${d}forbiddenServer`);
|
||||
});
|
||||
|
||||
test('should not 500 when MCP registry is uninitialized', async () => {
|
||||
getMCPServersRegistry.mockImplementation(() => {
|
||||
throw new Error('MCPServersRegistry has not been initialized.');
|
||||
});
|
||||
|
||||
mockReq.body = {
|
||||
provider: 'openai',
|
||||
model: 'gpt-4',
|
||||
name: 'MCP Uninitialized Test',
|
||||
tools: [`tool${d}someServer`, 'web_search'],
|
||||
};
|
||||
|
||||
await createAgentHandler(mockReq, mockRes);
|
||||
|
||||
expect(mockRes.status).toHaveBeenCalledWith(201);
|
||||
const agent = mockRes.json.mock.calls[0][0];
|
||||
expect(agent.tools).toEqual(['web_search']);
|
||||
});
|
||||
|
||||
test('should store mcpServerNames only for authorized servers', async () => {
|
||||
mockReq.body = {
|
||||
provider: 'openai',
|
||||
model: 'gpt-4',
|
||||
name: 'MCP Names Test',
|
||||
tools: [`toolA${d}authorizedServer`, `toolB${d}forbiddenServer`],
|
||||
};
|
||||
|
||||
await createAgentHandler(mockReq, mockRes);
|
||||
|
||||
expect(mockRes.status).toHaveBeenCalledWith(201);
|
||||
const agent = mockRes.json.mock.calls[0][0];
|
||||
const agentInDb = await Agent.findOne({ id: agent.id });
|
||||
expect(agentInDb.mcpServerNames).toContain('authorizedServer');
|
||||
expect(agentInDb.mcpServerNames).not.toContain('forbiddenServer');
|
||||
});
|
||||
});
|
||||
|
||||
describe('updateAgentHandler - MCP tool authorization', () => {
|
||||
let existingAgentId;
|
||||
let existingAgentAuthorId;
|
||||
|
||||
beforeEach(async () => {
|
||||
existingAgentAuthorId = new mongoose.Types.ObjectId();
|
||||
const agent = await Agent.create({
|
||||
id: `agent_${uuidv4()}`,
|
||||
name: 'Original Agent',
|
||||
provider: 'openai',
|
||||
model: 'gpt-4',
|
||||
author: existingAgentAuthorId,
|
||||
tools: ['web_search', `existingTool${d}authorizedServer`],
|
||||
mcpServerNames: ['authorizedServer'],
|
||||
versions: [
|
||||
{
|
||||
name: 'Original Agent',
|
||||
provider: 'openai',
|
||||
model: 'gpt-4',
|
||||
tools: ['web_search', `existingTool${d}authorizedServer`],
|
||||
createdAt: new Date(),
|
||||
updatedAt: new Date(),
|
||||
},
|
||||
],
|
||||
});
|
||||
existingAgentId = agent.id;
|
||||
});
|
||||
|
||||
test('should preserve existing MCP tools even if editor lacks access', async () => {
|
||||
mockGetAllServerConfigs.mockResolvedValue({});
|
||||
|
||||
mockReq.user.id = existingAgentAuthorId.toString();
|
||||
mockReq.params.id = existingAgentId;
|
||||
mockReq.body = {
|
||||
tools: ['web_search', `existingTool${d}authorizedServer`],
|
||||
};
|
||||
|
||||
await updateAgentHandler(mockReq, mockRes);
|
||||
|
||||
expect(mockRes.json).toHaveBeenCalled();
|
||||
const updatedAgent = mockRes.json.mock.calls[0][0];
|
||||
expect(updatedAgent.tools).toContain(`existingTool${d}authorizedServer`);
|
||||
expect(updatedAgent.tools).toContain('web_search');
|
||||
});
|
||||
|
||||
test('should reject newly added unauthorized MCP tools', async () => {
|
||||
mockReq.user.id = existingAgentAuthorId.toString();
|
||||
mockReq.params.id = existingAgentId;
|
||||
mockReq.body = {
|
||||
tools: ['web_search', `existingTool${d}authorizedServer`, `attack${d}forbiddenServer`],
|
||||
};
|
||||
|
||||
await updateAgentHandler(mockReq, mockRes);
|
||||
|
||||
expect(mockRes.json).toHaveBeenCalled();
|
||||
const updatedAgent = mockRes.json.mock.calls[0][0];
|
||||
expect(updatedAgent.tools).toContain('web_search');
|
||||
expect(updatedAgent.tools).toContain(`existingTool${d}authorizedServer`);
|
||||
expect(updatedAgent.tools).not.toContain(`attack${d}forbiddenServer`);
|
||||
});
|
||||
|
||||
test('should allow adding authorized MCP tools', async () => {
|
||||
mockReq.user.id = existingAgentAuthorId.toString();
|
||||
mockReq.params.id = existingAgentId;
|
||||
mockReq.body = {
|
||||
tools: ['web_search', `existingTool${d}authorizedServer`, `newTool${d}anotherServer`],
|
||||
};
|
||||
|
||||
await updateAgentHandler(mockReq, mockRes);
|
||||
|
||||
expect(mockRes.json).toHaveBeenCalled();
|
||||
const updatedAgent = mockRes.json.mock.calls[0][0];
|
||||
expect(updatedAgent.tools).toContain(`newTool${d}anotherServer`);
|
||||
});
|
||||
|
||||
test('should not query MCP registry when no new MCP tools added', async () => {
|
||||
mockReq.user.id = existingAgentAuthorId.toString();
|
||||
mockReq.params.id = existingAgentId;
|
||||
mockReq.body = {
|
||||
tools: ['web_search', `existingTool${d}authorizedServer`],
|
||||
};
|
||||
|
||||
await updateAgentHandler(mockReq, mockRes);
|
||||
|
||||
expect(mockGetAllServerConfigs).not.toHaveBeenCalled();
|
||||
});
|
||||
|
||||
test('should preserve existing MCP tools when registry unavailable and user edits agent', async () => {
|
||||
getMCPServersRegistry.mockImplementation(() => {
|
||||
throw new Error('MCPServersRegistry has not been initialized.');
|
||||
});
|
||||
|
||||
mockReq.user.id = existingAgentAuthorId.toString();
|
||||
mockReq.params.id = existingAgentId;
|
||||
mockReq.body = {
|
||||
name: 'Renamed After Restart',
|
||||
tools: ['web_search', `existingTool${d}authorizedServer`],
|
||||
};
|
||||
|
||||
await updateAgentHandler(mockReq, mockRes);
|
||||
|
||||
expect(mockRes.json).toHaveBeenCalled();
|
||||
const updatedAgent = mockRes.json.mock.calls[0][0];
|
||||
expect(updatedAgent.tools).toContain(`existingTool${d}authorizedServer`);
|
||||
expect(updatedAgent.tools).toContain('web_search');
|
||||
expect(updatedAgent.name).toBe('Renamed After Restart');
|
||||
});
|
||||
|
||||
test('should preserve existing MCP tools when server not in configs (disconnected)', async () => {
|
||||
mockGetAllServerConfigs.mockResolvedValue({});
|
||||
|
||||
mockReq.user.id = existingAgentAuthorId.toString();
|
||||
mockReq.params.id = existingAgentId;
|
||||
mockReq.body = {
|
||||
name: 'Edited While Disconnected',
|
||||
tools: ['web_search', `existingTool${d}authorizedServer`],
|
||||
};
|
||||
|
||||
await updateAgentHandler(mockReq, mockRes);
|
||||
|
||||
expect(mockRes.json).toHaveBeenCalled();
|
||||
const updatedAgent = mockRes.json.mock.calls[0][0];
|
||||
expect(updatedAgent.tools).toContain(`existingTool${d}authorizedServer`);
|
||||
expect(updatedAgent.name).toBe('Edited While Disconnected');
|
||||
});
|
||||
});
|
||||
|
||||
describe('duplicateAgentHandler - MCP tool authorization', () => {
|
||||
let sourceAgentId;
|
||||
let sourceAgentAuthorId;
|
||||
|
||||
beforeEach(async () => {
|
||||
sourceAgentAuthorId = new mongoose.Types.ObjectId();
|
||||
const agent = await Agent.create({
|
||||
id: `agent_${uuidv4()}`,
|
||||
name: 'Source Agent',
|
||||
provider: 'openai',
|
||||
model: 'gpt-4',
|
||||
author: sourceAgentAuthorId,
|
||||
tools: ['web_search', `tool${d}authorizedServer`, `tool${d}forbiddenServer`],
|
||||
mcpServerNames: ['authorizedServer', 'forbiddenServer'],
|
||||
versions: [
|
||||
{
|
||||
name: 'Source Agent',
|
||||
provider: 'openai',
|
||||
model: 'gpt-4',
|
||||
tools: ['web_search', `tool${d}authorizedServer`, `tool${d}forbiddenServer`],
|
||||
createdAt: new Date(),
|
||||
updatedAt: new Date(),
|
||||
},
|
||||
],
|
||||
});
|
||||
sourceAgentId = agent.id;
|
||||
});
|
||||
|
||||
test('should strip unauthorized MCP tools from duplicated agent', async () => {
|
||||
mockGetAllServerConfigs.mockResolvedValue({
|
||||
authorizedServer: { type: 'sse' },
|
||||
});
|
||||
|
||||
mockReq.user.id = sourceAgentAuthorId.toString();
|
||||
mockReq.params.id = sourceAgentId;
|
||||
|
||||
await duplicateAgentHandler(mockReq, mockRes);
|
||||
|
||||
expect(mockRes.status).toHaveBeenCalledWith(201);
|
||||
const { agent: newAgent } = mockRes.json.mock.calls[0][0];
|
||||
expect(newAgent.id).not.toBe(sourceAgentId);
|
||||
expect(newAgent.tools).toContain('web_search');
|
||||
expect(newAgent.tools).toContain(`tool${d}authorizedServer`);
|
||||
expect(newAgent.tools).not.toContain(`tool${d}forbiddenServer`);
|
||||
|
||||
const agentInDb = await Agent.findOne({ id: newAgent.id });
|
||||
expect(agentInDb.mcpServerNames).toContain('authorizedServer');
|
||||
expect(agentInDb.mcpServerNames).not.toContain('forbiddenServer');
|
||||
});
|
||||
|
||||
test('should preserve source agent MCP tools when registry is unavailable', async () => {
|
||||
getMCPServersRegistry.mockImplementation(() => {
|
||||
throw new Error('MCPServersRegistry has not been initialized.');
|
||||
});
|
||||
|
||||
mockReq.user.id = sourceAgentAuthorId.toString();
|
||||
mockReq.params.id = sourceAgentId;
|
||||
|
||||
await duplicateAgentHandler(mockReq, mockRes);
|
||||
|
||||
expect(mockRes.status).toHaveBeenCalledWith(201);
|
||||
const { agent: newAgent } = mockRes.json.mock.calls[0][0];
|
||||
expect(newAgent.tools).toContain('web_search');
|
||||
expect(newAgent.tools).toContain(`tool${d}authorizedServer`);
|
||||
expect(newAgent.tools).toContain(`tool${d}forbiddenServer`);
|
||||
});
|
||||
});
|
||||
|
||||
describe('revertAgentVersionHandler - MCP tool authorization', () => {
|
||||
let existingAgentId;
|
||||
let existingAgentAuthorId;
|
||||
|
||||
beforeEach(async () => {
|
||||
existingAgentAuthorId = new mongoose.Types.ObjectId();
|
||||
const agent = await Agent.create({
|
||||
id: `agent_${uuidv4()}`,
|
||||
name: 'Reverted Agent V2',
|
||||
provider: 'openai',
|
||||
model: 'gpt-4',
|
||||
author: existingAgentAuthorId,
|
||||
tools: ['web_search'],
|
||||
versions: [
|
||||
{
|
||||
name: 'Reverted Agent V1',
|
||||
provider: 'openai',
|
||||
model: 'gpt-4',
|
||||
tools: ['web_search', `oldTool${d}revokedServer`],
|
||||
createdAt: new Date(Date.now() - 10000),
|
||||
updatedAt: new Date(Date.now() - 10000),
|
||||
},
|
||||
{
|
||||
name: 'Reverted Agent V2',
|
||||
provider: 'openai',
|
||||
model: 'gpt-4',
|
||||
tools: ['web_search'],
|
||||
createdAt: new Date(),
|
||||
updatedAt: new Date(),
|
||||
},
|
||||
],
|
||||
});
|
||||
existingAgentId = agent.id;
|
||||
});
|
||||
|
||||
test('should strip unauthorized MCP tools after reverting to a previous version', async () => {
|
||||
mockGetAllServerConfigs.mockResolvedValue({
|
||||
authorizedServer: { type: 'sse' },
|
||||
});
|
||||
|
||||
mockReq.user.id = existingAgentAuthorId.toString();
|
||||
mockReq.params.id = existingAgentId;
|
||||
mockReq.body = { version_index: 0 };
|
||||
|
||||
await revertAgentVersionHandler(mockReq, mockRes);
|
||||
|
||||
expect(mockRes.json).toHaveBeenCalled();
|
||||
const result = mockRes.json.mock.calls[0][0];
|
||||
expect(result.tools).toContain('web_search');
|
||||
expect(result.tools).not.toContain(`oldTool${d}revokedServer`);
|
||||
|
||||
const agentInDb = await Agent.findOne({ id: existingAgentId });
|
||||
expect(agentInDb.tools).toContain('web_search');
|
||||
expect(agentInDb.tools).not.toContain(`oldTool${d}revokedServer`);
|
||||
});
|
||||
|
||||
test('should keep authorized MCP tools after revert', async () => {
|
||||
await Agent.updateOne(
|
||||
{ id: existingAgentId },
|
||||
{ $set: { 'versions.0.tools': ['web_search', `tool${d}authorizedServer`] } },
|
||||
);
|
||||
|
||||
mockReq.user.id = existingAgentAuthorId.toString();
|
||||
mockReq.params.id = existingAgentId;
|
||||
mockReq.body = { version_index: 0 };
|
||||
|
||||
await revertAgentVersionHandler(mockReq, mockRes);
|
||||
|
||||
expect(mockRes.json).toHaveBeenCalled();
|
||||
const result = mockRes.json.mock.calls[0][0];
|
||||
expect(result.tools).toContain('web_search');
|
||||
expect(result.tools).toContain(`tool${d}authorizedServer`);
|
||||
});
|
||||
|
||||
test('should preserve version MCP tools when registry is unavailable on revert', async () => {
|
||||
await Agent.updateOne(
|
||||
{ id: existingAgentId },
|
||||
{
|
||||
$set: {
|
||||
'versions.0.tools': [
|
||||
'web_search',
|
||||
`validTool${d}authorizedServer`,
|
||||
`otherTool${d}anotherServer`,
|
||||
],
|
||||
},
|
||||
},
|
||||
);
|
||||
|
||||
getMCPServersRegistry.mockImplementation(() => {
|
||||
throw new Error('MCPServersRegistry has not been initialized.');
|
||||
});
|
||||
|
||||
mockReq.user.id = existingAgentAuthorId.toString();
|
||||
mockReq.params.id = existingAgentId;
|
||||
mockReq.body = { version_index: 0 };
|
||||
|
||||
await revertAgentVersionHandler(mockReq, mockRes);
|
||||
|
||||
expect(mockRes.json).toHaveBeenCalled();
|
||||
const result = mockRes.json.mock.calls[0][0];
|
||||
expect(result.tools).toContain('web_search');
|
||||
expect(result.tools).toContain(`validTool${d}authorizedServer`);
|
||||
expect(result.tools).toContain(`otherTool${d}anotherServer`);
|
||||
|
||||
const agentInDb = await Agent.findOne({ id: existingAgentId });
|
||||
expect(agentInDb.tools).toContain(`validTool${d}authorizedServer`);
|
||||
expect(agentInDb.tools).toContain(`otherTool${d}anotherServer`);
|
||||
});
|
||||
});
|
||||
});
|
||||
|
|
@ -21,13 +21,14 @@ const {
|
|||
createOpenAIContentAggregator,
|
||||
isChatCompletionValidationFailure,
|
||||
} = require('@librechat/api');
|
||||
const {
|
||||
buildSummarizationHandlers,
|
||||
markSummarizationUsage,
|
||||
createToolEndCallback,
|
||||
agentLogHandlerObj,
|
||||
} = require('~/server/controllers/agents/callbacks');
|
||||
const { loadAgentTools, loadToolsForExecution } = require('~/server/services/ToolService');
|
||||
const { createToolEndCallback } = require('~/server/controllers/agents/callbacks');
|
||||
const { findAccessibleResources } = require('~/server/services/PermissionService');
|
||||
const { spendTokens, spendStructuredTokens } = require('~/models/spendTokens');
|
||||
const { getMultiplier, getCacheMultiplier } = require('~/models/tx');
|
||||
const { getConvoFiles } = require('~/models/Conversation');
|
||||
const { getAgent, getAgents } = require('~/models/Agent');
|
||||
const db = require('~/models');
|
||||
|
||||
/**
|
||||
|
|
@ -139,7 +140,7 @@ const OpenAIChatCompletionController = async (req, res) => {
|
|||
const agentId = request.model;
|
||||
|
||||
// Look up the agent
|
||||
const agent = await getAgent({ id: agentId });
|
||||
const agent = await db.getAgent({ id: agentId });
|
||||
if (!agent) {
|
||||
return sendErrorResponse(
|
||||
res,
|
||||
|
|
@ -151,8 +152,6 @@ const OpenAIChatCompletionController = async (req, res) => {
|
|||
}
|
||||
|
||||
const responseId = `chatcmpl-${nanoid()}`;
|
||||
const conversationId = request.conversation_id ?? nanoid();
|
||||
const parentMessageId = request.parent_message_id ?? null;
|
||||
const created = Math.floor(Date.now() / 1000);
|
||||
|
||||
/** @type {import('@librechat/api').OpenAIResponseContext} — key must be `requestId` to match the type used by createChunk/buildNonStreamingResponse */
|
||||
|
|
@ -178,6 +177,23 @@ const OpenAIChatCompletionController = async (req, res) => {
|
|||
});
|
||||
|
||||
try {
|
||||
if (request.conversation_id != null) {
|
||||
if (typeof request.conversation_id !== 'string') {
|
||||
return sendErrorResponse(
|
||||
res,
|
||||
400,
|
||||
'conversation_id must be a string',
|
||||
'invalid_request_error',
|
||||
);
|
||||
}
|
||||
if (!(await db.getConvo(req.user?.id, request.conversation_id))) {
|
||||
return sendErrorResponse(res, 404, 'Conversation not found', 'invalid_request_error');
|
||||
}
|
||||
}
|
||||
|
||||
const conversationId = request.conversation_id ?? nanoid();
|
||||
const parentMessageId = request.parent_message_id ?? null;
|
||||
|
||||
// Build allowed providers set
|
||||
const allowedProviders = new Set(
|
||||
appConfig?.endpoints?.[EModelEndpoint.agents]?.allowedProviders,
|
||||
|
|
@ -206,7 +222,7 @@ const OpenAIChatCompletionController = async (req, res) => {
|
|||
isInitialAgent: true,
|
||||
},
|
||||
{
|
||||
getConvoFiles,
|
||||
getConvoFiles: db.getConvoFiles,
|
||||
getFiles: db.getFiles,
|
||||
getUserKey: db.getUserKey,
|
||||
getMessages: db.getMessages,
|
||||
|
|
@ -265,19 +281,22 @@ const OpenAIChatCompletionController = async (req, res) => {
|
|||
toolRegistry: primaryConfig.toolRegistry,
|
||||
userMCPAuthMap: primaryConfig.userMCPAuthMap,
|
||||
tool_resources: primaryConfig.tool_resources,
|
||||
actionsEnabled: primaryConfig.actionsEnabled,
|
||||
});
|
||||
},
|
||||
toolEndCallback,
|
||||
};
|
||||
|
||||
const summarizationConfig = appConfig?.summarization;
|
||||
|
||||
const openaiMessages = convertMessages(request.messages);
|
||||
|
||||
const toolSet = buildToolSet(primaryConfig);
|
||||
const { messages: formattedMessages, indexTokenCountMap } = formatAgentMessages(
|
||||
openaiMessages,
|
||||
{},
|
||||
toolSet,
|
||||
);
|
||||
const {
|
||||
messages: formattedMessages,
|
||||
indexTokenCountMap,
|
||||
summary: initialSummary,
|
||||
} = formatAgentMessages(openaiMessages, {}, toolSet);
|
||||
|
||||
/**
|
||||
* Create a simple handler that processes data
|
||||
|
|
@ -420,24 +439,30 @@ const OpenAIChatCompletionController = async (req, res) => {
|
|||
}),
|
||||
|
||||
// Usage tracking
|
||||
on_chat_model_end: createHandler((data) => {
|
||||
const usage = data?.output?.usage_metadata;
|
||||
if (usage) {
|
||||
collectedUsage.push(usage);
|
||||
const target = isStreaming ? tracker : aggregator;
|
||||
target.usage.promptTokens += usage.input_tokens ?? 0;
|
||||
target.usage.completionTokens += usage.output_tokens ?? 0;
|
||||
}
|
||||
}),
|
||||
on_chat_model_end: {
|
||||
handle: (_event, data, metadata) => {
|
||||
const usage = data?.output?.usage_metadata;
|
||||
if (usage) {
|
||||
const taggedUsage = markSummarizationUsage(usage, metadata);
|
||||
collectedUsage.push(taggedUsage);
|
||||
const target = isStreaming ? tracker : aggregator;
|
||||
target.usage.promptTokens += taggedUsage.input_tokens ?? 0;
|
||||
target.usage.completionTokens += taggedUsage.output_tokens ?? 0;
|
||||
}
|
||||
},
|
||||
},
|
||||
on_run_step_completed: createHandler(),
|
||||
// Use proper ToolEndHandler for processing artifacts (images, file citations, code output)
|
||||
on_tool_end: new ToolEndHandler(toolEndCallback, logger),
|
||||
on_chain_stream: createHandler(),
|
||||
on_chain_end: createHandler(),
|
||||
on_agent_update: createHandler(),
|
||||
on_agent_log: agentLogHandlerObj,
|
||||
on_custom_event: createHandler(),
|
||||
// Event-driven tool execution handler
|
||||
on_tool_execute: createToolExecuteHandler(toolExecuteOptions),
|
||||
...(summarizationConfig?.enabled !== false
|
||||
? buildSummarizationHandlers({ isStreaming, res })
|
||||
: {}),
|
||||
};
|
||||
|
||||
// Create and run the agent
|
||||
|
|
@ -450,7 +475,9 @@ const OpenAIChatCompletionController = async (req, res) => {
|
|||
agents: [primaryConfig],
|
||||
messages: formattedMessages,
|
||||
indexTokenCountMap,
|
||||
initialSummary,
|
||||
runId: responseId,
|
||||
summarizationConfig,
|
||||
signal: abortController.signal,
|
||||
customHandlers: handlers,
|
||||
requestBody: {
|
||||
|
|
@ -495,9 +522,9 @@ const OpenAIChatCompletionController = async (req, res) => {
|
|||
const transactionsConfig = getTransactionsConfig(appConfig);
|
||||
recordCollectedUsage(
|
||||
{
|
||||
spendTokens,
|
||||
spendStructuredTokens,
|
||||
pricing: { getMultiplier, getCacheMultiplier },
|
||||
spendTokens: db.spendTokens,
|
||||
spendStructuredTokens: db.spendStructuredTokens,
|
||||
pricing: { getMultiplier: db.getMultiplier, getCacheMultiplier: db.getCacheMultiplier },
|
||||
bulkWriteOps: { insertMany: db.bulkInsertTransactions, updateBalance: db.updateBalance },
|
||||
},
|
||||
{
|
||||
|
|
@ -611,7 +638,7 @@ const ListModelsController = async (req, res) => {
|
|||
// Get the accessible agents
|
||||
let agents = [];
|
||||
if (accessibleAgentIds.length > 0) {
|
||||
agents = await getAgents({ _id: { $in: accessibleAgentIds } });
|
||||
agents = await db.getAgents({ _id: { $in: accessibleAgentIds } });
|
||||
}
|
||||
|
||||
const models = agents.map((agent) => ({
|
||||
|
|
@ -654,7 +681,7 @@ const GetModelController = async (req, res) => {
|
|||
return sendErrorResponse(res, 401, 'Authentication required', 'auth_error');
|
||||
}
|
||||
|
||||
const agent = await getAgent({ id: model });
|
||||
const agent = await db.getAgent({ id: model });
|
||||
|
||||
if (!agent) {
|
||||
return sendErrorResponse(
|
||||
|
|
|
|||
|
|
@ -18,17 +18,11 @@ const mockRecordCollectedUsage = jest
|
|||
.fn()
|
||||
.mockResolvedValue({ input_tokens: 100, output_tokens: 50 });
|
||||
|
||||
jest.mock('~/models/spendTokens', () => ({
|
||||
jest.mock('~/models', () => ({
|
||||
spendTokens: (...args) => mockSpendTokens(...args),
|
||||
spendStructuredTokens: (...args) => mockSpendStructuredTokens(...args),
|
||||
}));
|
||||
|
||||
jest.mock('~/models/tx', () => ({
|
||||
getMultiplier: mockGetMultiplier,
|
||||
getCacheMultiplier: mockGetCacheMultiplier,
|
||||
}));
|
||||
|
||||
jest.mock('~/models', () => ({
|
||||
updateBalance: mockUpdateBalance,
|
||||
bulkInsertTransactions: mockBulkInsertTransactions,
|
||||
}));
|
||||
|
|
|
|||
|
|
@ -131,9 +131,15 @@ const ResumableAgentController = async (req, res, next, initializeClient, addTit
|
|||
partialMessage.agent_id = req.body.agent_id;
|
||||
}
|
||||
|
||||
await saveMessage(req, partialMessage, {
|
||||
context: 'api/server/controllers/agents/request.js - partial response on disconnect',
|
||||
});
|
||||
await saveMessage(
|
||||
{
|
||||
userId: req?.user?.id,
|
||||
isTemporary: req?.body?.isTemporary,
|
||||
interfaceConfig: req?.config?.interfaceConfig,
|
||||
},
|
||||
partialMessage,
|
||||
{ context: 'api/server/controllers/agents/request.js - partial response on disconnect' },
|
||||
);
|
||||
|
||||
logger.debug(
|
||||
`[ResumableAgentController] Saved partial response for ${streamId}, content parts: ${aggregatedContent.length}`,
|
||||
|
|
@ -271,8 +277,14 @@ const ResumableAgentController = async (req, res, next, initializeClient, addTit
|
|||
|
||||
// Save user message BEFORE sending final event to avoid race condition
|
||||
// where client refetch happens before database is updated
|
||||
const reqCtx = {
|
||||
userId: req?.user?.id,
|
||||
isTemporary: req?.body?.isTemporary,
|
||||
interfaceConfig: req?.config?.interfaceConfig,
|
||||
};
|
||||
|
||||
if (!client.skipSaveUserMessage && userMessage) {
|
||||
await saveMessage(req, userMessage, {
|
||||
await saveMessage(reqCtx, userMessage, {
|
||||
context: 'api/server/controllers/agents/request.js - resumable user message',
|
||||
});
|
||||
}
|
||||
|
|
@ -282,7 +294,7 @@ const ResumableAgentController = async (req, res, next, initializeClient, addTit
|
|||
// before the response is saved to the database, causing orphaned parentMessageIds.
|
||||
if (client.savedMessageIds && !client.savedMessageIds.has(messageId)) {
|
||||
await saveMessage(
|
||||
req,
|
||||
reqCtx,
|
||||
{ ...response, user: userId, unfinished: wasAbortedBeforeComplete },
|
||||
{ context: 'api/server/controllers/agents/request.js - resumable response end' },
|
||||
);
|
||||
|
|
@ -661,7 +673,11 @@ const _LegacyAgentController = async (req, res, next, initializeClient, addTitle
|
|||
// Save the message if needed
|
||||
if (client.savedMessageIds && !client.savedMessageIds.has(messageId)) {
|
||||
await saveMessage(
|
||||
req,
|
||||
{
|
||||
userId: req?.user?.id,
|
||||
isTemporary: req?.body?.isTemporary,
|
||||
interfaceConfig: req?.config?.interfaceConfig,
|
||||
},
|
||||
{ ...finalResponse, user: userId },
|
||||
{ context: 'api/server/controllers/agents/request.js - response end' },
|
||||
);
|
||||
|
|
@ -690,9 +706,15 @@ const _LegacyAgentController = async (req, res, next, initializeClient, addTitle
|
|||
|
||||
// Save user message if needed
|
||||
if (!client.skipSaveUserMessage) {
|
||||
await saveMessage(req, userMessage, {
|
||||
context: "api/server/controllers/agents/request.js - don't skip saving user message",
|
||||
});
|
||||
await saveMessage(
|
||||
{
|
||||
userId: req?.user?.id,
|
||||
isTemporary: req?.body?.isTemporary,
|
||||
interfaceConfig: req?.config?.interfaceConfig,
|
||||
},
|
||||
userMessage,
|
||||
{ context: "api/server/controllers/agents/request.js - don't skip saving user message" },
|
||||
);
|
||||
}
|
||||
|
||||
// Add title if needed - extract minimal data
|
||||
|
|
|
|||
|
|
@ -32,14 +32,13 @@ const {
|
|||
} = require('@librechat/api');
|
||||
const {
|
||||
createResponsesToolEndCallback,
|
||||
buildSummarizationHandlers,
|
||||
markSummarizationUsage,
|
||||
createToolEndCallback,
|
||||
agentLogHandlerObj,
|
||||
} = require('~/server/controllers/agents/callbacks');
|
||||
const { loadAgentTools, loadToolsForExecution } = require('~/server/services/ToolService');
|
||||
const { findAccessibleResources } = require('~/server/services/PermissionService');
|
||||
const { getConvoFiles, saveConvo, getConvo } = require('~/models/Conversation');
|
||||
const { spendTokens, spendStructuredTokens } = require('~/models/spendTokens');
|
||||
const { getMultiplier, getCacheMultiplier } = require('~/models/tx');
|
||||
const { getAgent, getAgents } = require('~/models/Agent');
|
||||
const db = require('~/models');
|
||||
|
||||
/** @type {import('@librechat/api').AppConfig | null} */
|
||||
|
|
@ -214,8 +213,12 @@ async function saveResponseOutput(req, conversationId, responseId, response, age
|
|||
* @returns {Promise<void>}
|
||||
*/
|
||||
async function saveConversation(req, conversationId, agentId, agent) {
|
||||
await saveConvo(
|
||||
req,
|
||||
await db.saveConvo(
|
||||
{
|
||||
userId: req?.user?.id,
|
||||
isTemporary: req?.body?.isTemporary,
|
||||
interfaceConfig: req?.config?.interfaceConfig,
|
||||
},
|
||||
{
|
||||
conversationId,
|
||||
endpoint: EModelEndpoint.agents,
|
||||
|
|
@ -277,9 +280,10 @@ const createResponse = async (req, res) => {
|
|||
const request = validation.request;
|
||||
const agentId = request.model;
|
||||
const isStreaming = request.stream === true;
|
||||
const summarizationConfig = req.config?.summarization;
|
||||
|
||||
// Look up the agent
|
||||
const agent = await getAgent({ id: agentId });
|
||||
const agent = await db.getAgent({ id: agentId });
|
||||
if (!agent) {
|
||||
return sendResponsesErrorResponse(
|
||||
res,
|
||||
|
|
@ -292,10 +296,6 @@ const createResponse = async (req, res) => {
|
|||
|
||||
// Generate IDs
|
||||
const responseId = generateResponseId();
|
||||
const conversationId = request.previous_response_id ?? uuidv4();
|
||||
const parentMessageId = null;
|
||||
|
||||
// Create response context
|
||||
const context = createResponseContext(request, responseId);
|
||||
|
||||
logger.debug(
|
||||
|
|
@ -314,6 +314,23 @@ const createResponse = async (req, res) => {
|
|||
});
|
||||
|
||||
try {
|
||||
if (request.previous_response_id != null) {
|
||||
if (typeof request.previous_response_id !== 'string') {
|
||||
return sendResponsesErrorResponse(
|
||||
res,
|
||||
400,
|
||||
'previous_response_id must be a string',
|
||||
'invalid_request',
|
||||
);
|
||||
}
|
||||
if (!(await db.getConvo(req.user?.id, request.previous_response_id))) {
|
||||
return sendResponsesErrorResponse(res, 404, 'Conversation not found', 'not_found');
|
||||
}
|
||||
}
|
||||
|
||||
const conversationId = request.previous_response_id ?? uuidv4();
|
||||
const parentMessageId = null;
|
||||
|
||||
// Build allowed providers set
|
||||
const allowedProviders = new Set(
|
||||
appConfig?.endpoints?.[EModelEndpoint.agents]?.allowedProviders,
|
||||
|
|
@ -342,7 +359,7 @@ const createResponse = async (req, res) => {
|
|||
isInitialAgent: true,
|
||||
},
|
||||
{
|
||||
getConvoFiles,
|
||||
getConvoFiles: db.getConvoFiles,
|
||||
getFiles: db.getFiles,
|
||||
getUserKey: db.getUserKey,
|
||||
getMessages: db.getMessages,
|
||||
|
|
@ -374,11 +391,11 @@ const createResponse = async (req, res) => {
|
|||
const allMessages = [...previousMessages, ...inputMessages];
|
||||
|
||||
const toolSet = buildToolSet(primaryConfig);
|
||||
const { messages: formattedMessages, indexTokenCountMap } = formatAgentMessages(
|
||||
allMessages,
|
||||
{},
|
||||
toolSet,
|
||||
);
|
||||
const {
|
||||
messages: formattedMessages,
|
||||
indexTokenCountMap,
|
||||
summary: initialSummary,
|
||||
} = formatAgentMessages(allMessages, {}, toolSet);
|
||||
|
||||
// Create tracker for streaming or aggregator for non-streaming
|
||||
const tracker = actuallyStreaming ? createResponseTracker() : null;
|
||||
|
|
@ -429,6 +446,7 @@ const createResponse = async (req, res) => {
|
|||
toolRegistry: primaryConfig.toolRegistry,
|
||||
userMCPAuthMap: primaryConfig.userMCPAuthMap,
|
||||
tool_resources: primaryConfig.tool_resources,
|
||||
actionsEnabled: primaryConfig.actionsEnabled,
|
||||
});
|
||||
},
|
||||
toolEndCallback,
|
||||
|
|
@ -441,11 +459,12 @@ const createResponse = async (req, res) => {
|
|||
on_run_step: responsesHandlers.on_run_step,
|
||||
on_run_step_delta: responsesHandlers.on_run_step_delta,
|
||||
on_chat_model_end: {
|
||||
handle: (event, data) => {
|
||||
handle: (event, data, metadata) => {
|
||||
responsesHandlers.on_chat_model_end.handle(event, data);
|
||||
const usage = data?.output?.usage_metadata;
|
||||
if (usage) {
|
||||
collectedUsage.push(usage);
|
||||
const taggedUsage = markSummarizationUsage(usage, metadata);
|
||||
collectedUsage.push(taggedUsage);
|
||||
}
|
||||
},
|
||||
},
|
||||
|
|
@ -456,6 +475,10 @@ const createResponse = async (req, res) => {
|
|||
on_agent_update: { handle: () => {} },
|
||||
on_custom_event: { handle: () => {} },
|
||||
on_tool_execute: createToolExecuteHandler(toolExecuteOptions),
|
||||
on_agent_log: agentLogHandlerObj,
|
||||
...(summarizationConfig?.enabled !== false
|
||||
? buildSummarizationHandlers({ isStreaming: actuallyStreaming, res })
|
||||
: {}),
|
||||
};
|
||||
|
||||
// Create and run the agent
|
||||
|
|
@ -466,7 +489,9 @@ const createResponse = async (req, res) => {
|
|||
agents: [primaryConfig],
|
||||
messages: formattedMessages,
|
||||
indexTokenCountMap,
|
||||
initialSummary,
|
||||
runId: responseId,
|
||||
summarizationConfig,
|
||||
signal: abortController.signal,
|
||||
customHandlers: handlers,
|
||||
requestBody: {
|
||||
|
|
@ -511,9 +536,9 @@ const createResponse = async (req, res) => {
|
|||
const transactionsConfig = getTransactionsConfig(req.config);
|
||||
recordCollectedUsage(
|
||||
{
|
||||
spendTokens,
|
||||
spendStructuredTokens,
|
||||
pricing: { getMultiplier, getCacheMultiplier },
|
||||
spendTokens: db.spendTokens,
|
||||
spendStructuredTokens: db.spendStructuredTokens,
|
||||
pricing: { getMultiplier: db.getMultiplier, getCacheMultiplier: db.getCacheMultiplier },
|
||||
bulkWriteOps: { insertMany: db.bulkInsertTransactions, updateBalance: db.updateBalance },
|
||||
},
|
||||
{
|
||||
|
|
@ -586,6 +611,7 @@ const createResponse = async (req, res) => {
|
|||
toolRegistry: primaryConfig.toolRegistry,
|
||||
userMCPAuthMap: primaryConfig.userMCPAuthMap,
|
||||
tool_resources: primaryConfig.tool_resources,
|
||||
actionsEnabled: primaryConfig.actionsEnabled,
|
||||
});
|
||||
},
|
||||
toolEndCallback,
|
||||
|
|
@ -597,11 +623,12 @@ const createResponse = async (req, res) => {
|
|||
on_run_step: aggregatorHandlers.on_run_step,
|
||||
on_run_step_delta: aggregatorHandlers.on_run_step_delta,
|
||||
on_chat_model_end: {
|
||||
handle: (event, data) => {
|
||||
handle: (event, data, metadata) => {
|
||||
aggregatorHandlers.on_chat_model_end.handle(event, data);
|
||||
const usage = data?.output?.usage_metadata;
|
||||
if (usage) {
|
||||
collectedUsage.push(usage);
|
||||
const taggedUsage = markSummarizationUsage(usage, metadata);
|
||||
collectedUsage.push(taggedUsage);
|
||||
}
|
||||
},
|
||||
},
|
||||
|
|
@ -612,6 +639,10 @@ const createResponse = async (req, res) => {
|
|||
on_agent_update: { handle: () => {} },
|
||||
on_custom_event: { handle: () => {} },
|
||||
on_tool_execute: createToolExecuteHandler(toolExecuteOptions),
|
||||
on_agent_log: agentLogHandlerObj,
|
||||
...(summarizationConfig?.enabled !== false
|
||||
? buildSummarizationHandlers({ isStreaming: false, res })
|
||||
: {}),
|
||||
};
|
||||
|
||||
const userId = req.user?.id ?? 'api-user';
|
||||
|
|
@ -621,7 +652,9 @@ const createResponse = async (req, res) => {
|
|||
agents: [primaryConfig],
|
||||
messages: formattedMessages,
|
||||
indexTokenCountMap,
|
||||
initialSummary,
|
||||
runId: responseId,
|
||||
summarizationConfig,
|
||||
signal: abortController.signal,
|
||||
customHandlers: handlers,
|
||||
requestBody: {
|
||||
|
|
@ -665,9 +698,9 @@ const createResponse = async (req, res) => {
|
|||
const transactionsConfig = getTransactionsConfig(req.config);
|
||||
recordCollectedUsage(
|
||||
{
|
||||
spendTokens,
|
||||
spendStructuredTokens,
|
||||
pricing: { getMultiplier, getCacheMultiplier },
|
||||
spendTokens: db.spendTokens,
|
||||
spendStructuredTokens: db.spendStructuredTokens,
|
||||
pricing: { getMultiplier: db.getMultiplier, getCacheMultiplier: db.getCacheMultiplier },
|
||||
bulkWriteOps: { insertMany: db.bulkInsertTransactions, updateBalance: db.updateBalance },
|
||||
},
|
||||
{
|
||||
|
|
@ -767,7 +800,7 @@ const listModels = async (req, res) => {
|
|||
// Get the accessible agents
|
||||
let agents = [];
|
||||
if (accessibleAgentIds.length > 0) {
|
||||
agents = await getAgents({ _id: { $in: accessibleAgentIds } });
|
||||
agents = await db.getAgents({ _id: { $in: accessibleAgentIds } });
|
||||
}
|
||||
|
||||
// Convert to models format
|
||||
|
|
@ -817,7 +850,7 @@ const getResponse = async (req, res) => {
|
|||
|
||||
// The responseId could be either the response ID or the conversation ID
|
||||
// Try to find a conversation with this ID
|
||||
const conversation = await getConvo(userId, responseId);
|
||||
const conversation = await db.getConvo(userId, responseId);
|
||||
|
||||
if (!conversation) {
|
||||
return sendResponsesErrorResponse(
|
||||
|
|
|
|||
|
|
@ -3,9 +3,11 @@ const fs = require('fs').promises;
|
|||
const { nanoid } = require('nanoid');
|
||||
const { logger } = require('@librechat/data-schemas');
|
||||
const {
|
||||
refreshS3Url,
|
||||
agentCreateSchema,
|
||||
agentUpdateSchema,
|
||||
refreshListAvatars,
|
||||
collectEdgeAgentIds,
|
||||
mergeAgentOcrConversion,
|
||||
MAX_AVATAR_REFRESH_AGENTS,
|
||||
convertOcrToContextInPlace,
|
||||
|
|
@ -24,30 +26,21 @@ const {
|
|||
actionDelimiter,
|
||||
removeNullishValues,
|
||||
} = require('librechat-data-provider');
|
||||
const {
|
||||
getListAgentsByAccess,
|
||||
countPromotedAgents,
|
||||
revertAgentVersion,
|
||||
createAgent,
|
||||
updateAgent,
|
||||
deleteAgent,
|
||||
getAgent,
|
||||
} = require('~/models/Agent');
|
||||
const {
|
||||
findPubliclyAccessibleResources,
|
||||
getResourcePermissionsMap,
|
||||
findAccessibleResources,
|
||||
hasPublicPermission,
|
||||
grantPermission,
|
||||
} = require('~/server/services/PermissionService');
|
||||
const { getStrategyFunctions } = require('~/server/services/Files/strategies');
|
||||
const { getCategoriesWithCounts, deleteFileByFilter } = require('~/models');
|
||||
const { resizeAvatar } = require('~/server/services/Files/images/avatar');
|
||||
const { getFileStrategy } = require('~/server/utils/getFileStrategy');
|
||||
const { refreshS3Url } = require('~/server/services/Files/S3/crud');
|
||||
const { filterFile } = require('~/server/services/Files/process');
|
||||
const { updateAction, getActions } = require('~/models/Action');
|
||||
const { getCachedTools } = require('~/server/services/Config');
|
||||
const { getMCPServersRegistry } = require('~/config');
|
||||
const { getLogStores } = require('~/cache');
|
||||
const db = require('~/models');
|
||||
|
||||
const systemTools = {
|
||||
[Tools.execute_code]: true,
|
||||
|
|
@ -58,6 +51,114 @@ const systemTools = {
|
|||
const MAX_SEARCH_LEN = 100;
|
||||
const escapeRegex = (str = '') => str.replace(/[.*+?^${}()|[\]\\]/g, '\\$&');
|
||||
|
||||
/**
|
||||
* Validates that the requesting user has VIEW access to every agent referenced in edges.
|
||||
* Agents that do not exist in the database are skipped — at create time, the `from` field
|
||||
* often references the agent being built, which has no DB record yet.
|
||||
* @param {import('librechat-data-provider').GraphEdge[]} edges
|
||||
* @param {string} userId
|
||||
* @param {string} userRole - Used for group/role principal resolution
|
||||
* @returns {Promise<string[]>} Agent IDs the user cannot VIEW (empty if all accessible)
|
||||
*/
|
||||
const validateEdgeAgentAccess = async (edges, userId, userRole) => {
|
||||
const edgeAgentIds = collectEdgeAgentIds(edges);
|
||||
if (edgeAgentIds.size === 0) {
|
||||
return [];
|
||||
}
|
||||
|
||||
const agents = await db.getAgents({ id: { $in: [...edgeAgentIds] } });
|
||||
|
||||
if (agents.length === 0) {
|
||||
return [];
|
||||
}
|
||||
|
||||
const permissionsMap = await getResourcePermissionsMap({
|
||||
userId,
|
||||
role: userRole,
|
||||
resourceType: ResourceType.AGENT,
|
||||
resourceIds: agents.map((a) => a._id),
|
||||
});
|
||||
|
||||
return agents
|
||||
.filter((a) => {
|
||||
const bits = permissionsMap.get(a._id.toString()) ?? 0;
|
||||
return (bits & PermissionBits.VIEW) === 0;
|
||||
})
|
||||
.map((a) => a.id);
|
||||
};
|
||||
|
||||
/**
|
||||
* Filters tools to only include those the user is authorized to use.
|
||||
* MCP tools must match the exact format `{toolName}_mcp_{serverName}` (exactly 2 segments).
|
||||
* Multi-delimiter keys are rejected to prevent authorization/execution mismatch.
|
||||
* Non-MCP tools must appear in availableTools (global tool cache) or systemTools.
|
||||
*
|
||||
* When `existingTools` is provided and the MCP registry is unavailable (e.g. server restart),
|
||||
* tools already present on the agent are preserved rather than stripped — they were validated
|
||||
* when originally added, and we cannot re-verify them without the registry.
|
||||
* @param {object} params
|
||||
* @param {string[]} params.tools - Raw tool strings from the request
|
||||
* @param {string} params.userId - Requesting user ID for MCP server access check
|
||||
* @param {Record<string, unknown>} params.availableTools - Global non-MCP tool cache
|
||||
* @param {string[]} [params.existingTools] - Tools already persisted on the agent document
|
||||
* @returns {Promise<string[]>} Only the authorized subset of tools
|
||||
*/
|
||||
const filterAuthorizedTools = async ({ tools, userId, availableTools, existingTools }) => {
|
||||
const filteredTools = [];
|
||||
let mcpServerConfigs;
|
||||
let registryUnavailable = false;
|
||||
const existingToolSet = existingTools?.length ? new Set(existingTools) : null;
|
||||
|
||||
for (const tool of tools) {
|
||||
if (availableTools[tool] || systemTools[tool]) {
|
||||
filteredTools.push(tool);
|
||||
continue;
|
||||
}
|
||||
|
||||
if (!tool?.includes(Constants.mcp_delimiter)) {
|
||||
continue;
|
||||
}
|
||||
|
||||
if (mcpServerConfigs === undefined) {
|
||||
try {
|
||||
mcpServerConfigs = (await getMCPServersRegistry().getAllServerConfigs(userId)) ?? {};
|
||||
} catch (e) {
|
||||
logger.warn(
|
||||
'[filterAuthorizedTools] MCP registry unavailable, filtering all MCP tools',
|
||||
e.message,
|
||||
);
|
||||
mcpServerConfigs = {};
|
||||
registryUnavailable = true;
|
||||
}
|
||||
}
|
||||
|
||||
const parts = tool.split(Constants.mcp_delimiter);
|
||||
if (parts.length !== 2) {
|
||||
logger.warn(
|
||||
`[filterAuthorizedTools] Rejected malformed MCP tool key "${tool}" for user ${userId}`,
|
||||
);
|
||||
continue;
|
||||
}
|
||||
|
||||
if (registryUnavailable && existingToolSet?.has(tool)) {
|
||||
filteredTools.push(tool);
|
||||
continue;
|
||||
}
|
||||
|
||||
const [, serverName] = parts;
|
||||
if (!serverName || !Object.hasOwn(mcpServerConfigs, serverName)) {
|
||||
logger.warn(
|
||||
`[filterAuthorizedTools] Rejected MCP tool "${tool}" — server "${serverName}" not accessible to user ${userId}`,
|
||||
);
|
||||
continue;
|
||||
}
|
||||
|
||||
filteredTools.push(tool);
|
||||
}
|
||||
|
||||
return filteredTools;
|
||||
};
|
||||
|
||||
/**
|
||||
* Creates an Agent.
|
||||
* @route POST /Agents
|
||||
|
|
@ -75,24 +176,26 @@ const createAgentHandler = async (req, res) => {
|
|||
agentData.model_parameters = removeNullishValues(agentData.model_parameters, true);
|
||||
}
|
||||
|
||||
const { id: userId } = req.user;
|
||||
const { id: userId, role: userRole } = req.user;
|
||||
|
||||
if (agentData.edges?.length) {
|
||||
const unauthorized = await validateEdgeAgentAccess(agentData.edges, userId, userRole);
|
||||
if (unauthorized.length > 0) {
|
||||
return res.status(403).json({
|
||||
error: 'You do not have access to one or more agents referenced in edges',
|
||||
agent_ids: unauthorized,
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
agentData.id = `agent_${nanoid()}`;
|
||||
agentData.author = userId;
|
||||
agentData.tools = [];
|
||||
|
||||
const availableTools = (await getCachedTools()) ?? {};
|
||||
for (const tool of tools) {
|
||||
if (availableTools[tool]) {
|
||||
agentData.tools.push(tool);
|
||||
} else if (systemTools[tool]) {
|
||||
agentData.tools.push(tool);
|
||||
} else if (tool.includes(Constants.mcp_delimiter)) {
|
||||
agentData.tools.push(tool);
|
||||
}
|
||||
}
|
||||
agentData.tools = await filterAuthorizedTools({ tools, userId, availableTools });
|
||||
|
||||
const agent = await createAgent(agentData);
|
||||
const agent = await db.createAgent(agentData);
|
||||
|
||||
try {
|
||||
await Promise.all([
|
||||
|
|
@ -152,7 +255,7 @@ const getAgentHandler = async (req, res, expandProperties = false) => {
|
|||
|
||||
// Permissions are validated by middleware before calling this function
|
||||
// Simply load the agent by ID
|
||||
const agent = await getAgent({ id });
|
||||
const agent = await db.getAgent({ id });
|
||||
|
||||
if (!agent) {
|
||||
return res.status(404).json({ error: 'Agent not found' });
|
||||
|
|
@ -173,9 +276,6 @@ const getAgentHandler = async (req, res, expandProperties = false) => {
|
|||
|
||||
agent.author = agent.author.toString();
|
||||
|
||||
// @deprecated - isCollaborative replaced by ACL permissions
|
||||
agent.isCollaborative = !!agent.isCollaborative;
|
||||
|
||||
// Check if agent is public
|
||||
const isPublic = await hasPublicPermission({
|
||||
resourceType: ResourceType.AGENT,
|
||||
|
|
@ -199,9 +299,6 @@ const getAgentHandler = async (req, res, expandProperties = false) => {
|
|||
author: agent.author,
|
||||
provider: agent.provider,
|
||||
model: agent.model,
|
||||
projectIds: agent.projectIds,
|
||||
// @deprecated - isCollaborative replaced by ACL permissions
|
||||
isCollaborative: agent.isCollaborative,
|
||||
isPublic: agent.isPublic,
|
||||
version: agent.version,
|
||||
// Safe metadata
|
||||
|
|
@ -243,10 +340,21 @@ const updateAgentHandler = async (req, res) => {
|
|||
updateData.avatar = avatarField;
|
||||
}
|
||||
|
||||
if (updateData.edges?.length) {
|
||||
const { id: userId, role: userRole } = req.user;
|
||||
const unauthorized = await validateEdgeAgentAccess(updateData.edges, userId, userRole);
|
||||
if (unauthorized.length > 0) {
|
||||
return res.status(403).json({
|
||||
error: 'You do not have access to one or more agents referenced in edges',
|
||||
agent_ids: unauthorized,
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
// Convert OCR to context in incoming updateData
|
||||
convertOcrToContextInPlace(updateData);
|
||||
|
||||
const existingAgent = await getAgent({ id });
|
||||
const existingAgent = await db.getAgent({ id });
|
||||
|
||||
if (!existingAgent) {
|
||||
return res.status(404).json({ error: 'Agent not found' });
|
||||
|
|
@ -261,9 +369,29 @@ const updateAgentHandler = async (req, res) => {
|
|||
updateData.tools = ocrConversion.tools;
|
||||
}
|
||||
|
||||
if (updateData.tools) {
|
||||
const existingToolSet = new Set(existingAgent.tools ?? []);
|
||||
const newMCPTools = updateData.tools.filter(
|
||||
(t) => !existingToolSet.has(t) && t?.includes(Constants.mcp_delimiter),
|
||||
);
|
||||
|
||||
if (newMCPTools.length > 0) {
|
||||
const availableTools = (await getCachedTools()) ?? {};
|
||||
const approvedNew = await filterAuthorizedTools({
|
||||
tools: newMCPTools,
|
||||
userId: req.user.id,
|
||||
availableTools,
|
||||
});
|
||||
const rejectedSet = new Set(newMCPTools.filter((t) => !approvedNew.includes(t)));
|
||||
if (rejectedSet.size > 0) {
|
||||
updateData.tools = updateData.tools.filter((t) => !rejectedSet.has(t));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
let updatedAgent =
|
||||
Object.keys(updateData).length > 0
|
||||
? await updateAgent({ id }, updateData, {
|
||||
? await db.updateAgent({ id }, updateData, {
|
||||
updatingUserId: req.user.id,
|
||||
})
|
||||
: existingAgent;
|
||||
|
|
@ -313,7 +441,7 @@ const duplicateAgentHandler = async (req, res) => {
|
|||
const sensitiveFields = ['api_key', 'oauth_client_id', 'oauth_client_secret'];
|
||||
|
||||
try {
|
||||
const agent = await getAgent({ id });
|
||||
const agent = await db.getAgent({ id });
|
||||
if (!agent) {
|
||||
return res.status(404).json({
|
||||
error: 'Agent not found',
|
||||
|
|
@ -361,7 +489,7 @@ const duplicateAgentHandler = async (req, res) => {
|
|||
});
|
||||
|
||||
const newActionsList = [];
|
||||
const originalActions = (await getActions({ agent_id: id }, true)) ?? [];
|
||||
const originalActions = (await db.getActions({ agent_id: id }, true)) ?? [];
|
||||
const promises = [];
|
||||
|
||||
/**
|
||||
|
|
@ -371,7 +499,7 @@ const duplicateAgentHandler = async (req, res) => {
|
|||
*/
|
||||
const duplicateAction = async (action) => {
|
||||
const newActionId = nanoid();
|
||||
const [domain] = action.action_id.split(actionDelimiter);
|
||||
const { domain } = action.metadata;
|
||||
const fullActionId = `${domain}${actionDelimiter}${newActionId}`;
|
||||
|
||||
// Sanitize sensitive metadata before persisting
|
||||
|
|
@ -380,8 +508,8 @@ const duplicateAgentHandler = async (req, res) => {
|
|||
delete filteredMetadata[field];
|
||||
}
|
||||
|
||||
const newAction = await updateAction(
|
||||
{ action_id: newActionId },
|
||||
const newAction = await db.updateAction(
|
||||
{ action_id: newActionId, agent_id: newAgentId },
|
||||
{
|
||||
metadata: filteredMetadata,
|
||||
agent_id: newAgentId,
|
||||
|
|
@ -403,7 +531,18 @@ const duplicateAgentHandler = async (req, res) => {
|
|||
|
||||
const agentActions = await Promise.all(promises);
|
||||
newAgentData.actions = agentActions;
|
||||
const newAgent = await createAgent(newAgentData);
|
||||
|
||||
if (newAgentData.tools?.length) {
|
||||
const availableTools = (await getCachedTools()) ?? {};
|
||||
newAgentData.tools = await filterAuthorizedTools({
|
||||
tools: newAgentData.tools,
|
||||
userId,
|
||||
availableTools,
|
||||
existingTools: newAgentData.tools,
|
||||
});
|
||||
}
|
||||
|
||||
const newAgent = await db.createAgent(newAgentData);
|
||||
|
||||
try {
|
||||
await Promise.all([
|
||||
|
|
@ -456,11 +595,11 @@ const duplicateAgentHandler = async (req, res) => {
|
|||
const deleteAgentHandler = async (req, res) => {
|
||||
try {
|
||||
const id = req.params.id;
|
||||
const agent = await getAgent({ id });
|
||||
const agent = await db.getAgent({ id });
|
||||
if (!agent) {
|
||||
return res.status(404).json({ error: 'Agent not found' });
|
||||
}
|
||||
await deleteAgent({ id });
|
||||
await db.deleteAgent({ id });
|
||||
return res.json({ message: 'Agent deleted' });
|
||||
} catch (error) {
|
||||
logger.error('[/Agents/:id] Error deleting Agent', error);
|
||||
|
|
@ -535,7 +674,7 @@ const getListAgentsHandler = async (req, res) => {
|
|||
cachedRefresh != null && typeof cachedRefresh === 'object' && cachedRefresh.urlCache != null;
|
||||
if (!isValidCachedRefresh) {
|
||||
try {
|
||||
const fullList = await getListAgentsByAccess({
|
||||
const fullList = await db.getListAgentsByAccess({
|
||||
accessibleIds,
|
||||
otherParams: {},
|
||||
limit: MAX_AVATAR_REFRESH_AGENTS,
|
||||
|
|
@ -545,7 +684,7 @@ const getListAgentsHandler = async (req, res) => {
|
|||
agents: fullList?.data ?? [],
|
||||
userId,
|
||||
refreshS3Url,
|
||||
updateAgent,
|
||||
updateAgent: db.updateAgent,
|
||||
});
|
||||
cachedRefresh = { urlCache };
|
||||
await cache.set(refreshKey, cachedRefresh, Time.THIRTY_MINUTES);
|
||||
|
|
@ -557,7 +696,7 @@ const getListAgentsHandler = async (req, res) => {
|
|||
}
|
||||
|
||||
// Use the new ACL-aware function
|
||||
const data = await getListAgentsByAccess({
|
||||
const data = await db.getListAgentsByAccess({
|
||||
accessibleIds,
|
||||
otherParams: filter,
|
||||
limit,
|
||||
|
|
@ -622,7 +761,7 @@ const uploadAgentAvatarHandler = async (req, res) => {
|
|||
return res.status(400).json({ message: 'Agent ID is required' });
|
||||
}
|
||||
|
||||
const existingAgent = await getAgent({ id: agent_id });
|
||||
const existingAgent = await db.getAgent({ id: agent_id });
|
||||
|
||||
if (!existingAgent) {
|
||||
return res.status(404).json({ error: 'Agent not found' });
|
||||
|
|
@ -654,7 +793,7 @@ const uploadAgentAvatarHandler = async (req, res) => {
|
|||
const { deleteFile } = getStrategyFunctions(_avatar.source);
|
||||
try {
|
||||
await deleteFile(req, { filepath: _avatar.filepath });
|
||||
await deleteFileByFilter({ user: req.user.id, filepath: _avatar.filepath });
|
||||
await db.deleteFileByFilter({ user: req.user.id, filepath: _avatar.filepath });
|
||||
} catch (error) {
|
||||
logger.error('[/:agent_id/avatar] Error deleting old avatar', error);
|
||||
}
|
||||
|
|
@ -667,7 +806,7 @@ const uploadAgentAvatarHandler = async (req, res) => {
|
|||
},
|
||||
};
|
||||
|
||||
const updatedAgent = await updateAgent({ id: agent_id }, data, {
|
||||
const updatedAgent = await db.updateAgent({ id: agent_id }, data, {
|
||||
updatingUserId: req.user.id,
|
||||
});
|
||||
|
||||
|
|
@ -723,7 +862,7 @@ const revertAgentVersionHandler = async (req, res) => {
|
|||
return res.status(400).json({ error: 'version_index is required' });
|
||||
}
|
||||
|
||||
const existingAgent = await getAgent({ id });
|
||||
const existingAgent = await db.getAgent({ id });
|
||||
|
||||
if (!existingAgent) {
|
||||
return res.status(404).json({ error: 'Agent not found' });
|
||||
|
|
@ -731,7 +870,24 @@ const revertAgentVersionHandler = async (req, res) => {
|
|||
|
||||
// Permissions are enforced via route middleware (ACL EDIT)
|
||||
|
||||
const updatedAgent = await revertAgentVersion({ id }, version_index);
|
||||
let updatedAgent = await db.revertAgentVersion({ id }, version_index);
|
||||
|
||||
if (updatedAgent.tools?.length) {
|
||||
const availableTools = (await getCachedTools()) ?? {};
|
||||
const filteredTools = await filterAuthorizedTools({
|
||||
tools: updatedAgent.tools,
|
||||
userId: req.user.id,
|
||||
availableTools,
|
||||
existingTools: updatedAgent.tools,
|
||||
});
|
||||
if (filteredTools.length !== updatedAgent.tools.length) {
|
||||
updatedAgent = await db.updateAgent(
|
||||
{ id },
|
||||
{ tools: filteredTools },
|
||||
{ updatingUserId: req.user.id },
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
if (updatedAgent.author) {
|
||||
updatedAgent.author = updatedAgent.author.toString();
|
||||
|
|
@ -755,8 +911,8 @@ const revertAgentVersionHandler = async (req, res) => {
|
|||
*/
|
||||
const getAgentCategories = async (_req, res) => {
|
||||
try {
|
||||
const categories = await getCategoriesWithCounts();
|
||||
const promotedCount = await countPromotedAgents();
|
||||
const categories = await db.getCategoriesWithCounts();
|
||||
const promotedCount = await db.countPromotedAgents();
|
||||
const formattedCategories = categories.map((category) => ({
|
||||
value: category.value,
|
||||
label: category.label,
|
||||
|
|
@ -799,4 +955,5 @@ module.exports = {
|
|||
uploadAgentAvatar: uploadAgentAvatarHandler,
|
||||
revertAgentVersion: revertAgentVersionHandler,
|
||||
getAgentCategories,
|
||||
filterAuthorizedTools,
|
||||
};
|
||||
|
|
|
|||
|
|
@ -2,7 +2,7 @@ const mongoose = require('mongoose');
|
|||
const { nanoid } = require('nanoid');
|
||||
const { v4: uuidv4 } = require('uuid');
|
||||
const { agentSchema } = require('@librechat/data-schemas');
|
||||
const { FileSources } = require('librechat-data-provider');
|
||||
const { FileSources, PermissionBits } = require('librechat-data-provider');
|
||||
const { MongoMemoryServer } = require('mongodb-memory-server');
|
||||
|
||||
// Only mock the dependencies that are not database-related
|
||||
|
|
@ -14,10 +14,6 @@ jest.mock('~/server/services/Config', () => ({
|
|||
}),
|
||||
}));
|
||||
|
||||
jest.mock('~/models/Project', () => ({
|
||||
getProjectByName: jest.fn().mockResolvedValue(null),
|
||||
}));
|
||||
|
||||
jest.mock('~/server/services/Files/strategies', () => ({
|
||||
getStrategyFunctions: jest.fn(),
|
||||
}));
|
||||
|
|
@ -26,7 +22,16 @@ jest.mock('~/server/services/Files/images/avatar', () => ({
|
|||
resizeAvatar: jest.fn(),
|
||||
}));
|
||||
|
||||
jest.mock('~/server/services/Files/S3/crud', () => ({
|
||||
jest.mock('sharp', () =>
|
||||
jest.fn(() => ({
|
||||
metadata: jest.fn().mockResolvedValue({}),
|
||||
toFormat: jest.fn().mockReturnThis(),
|
||||
toBuffer: jest.fn().mockResolvedValue(Buffer.alloc(0)),
|
||||
})),
|
||||
);
|
||||
|
||||
jest.mock('@librechat/api', () => ({
|
||||
...jest.requireActual('@librechat/api'),
|
||||
refreshS3Url: jest.fn(),
|
||||
}));
|
||||
|
||||
|
|
@ -34,26 +39,26 @@ jest.mock('~/server/services/Files/process', () => ({
|
|||
filterFile: jest.fn(),
|
||||
}));
|
||||
|
||||
jest.mock('~/models/Action', () => ({
|
||||
updateAction: jest.fn(),
|
||||
getActions: jest.fn().mockResolvedValue([]),
|
||||
}));
|
||||
|
||||
jest.mock('~/models/File', () => ({
|
||||
deleteFileByFilter: jest.fn(),
|
||||
}));
|
||||
|
||||
jest.mock('~/server/services/PermissionService', () => ({
|
||||
findAccessibleResources: jest.fn().mockResolvedValue([]),
|
||||
findPubliclyAccessibleResources: jest.fn().mockResolvedValue([]),
|
||||
getResourcePermissionsMap: jest.fn().mockResolvedValue(new Map()),
|
||||
grantPermission: jest.fn(),
|
||||
hasPublicPermission: jest.fn().mockResolvedValue(false),
|
||||
checkPermission: jest.fn().mockResolvedValue(true),
|
||||
}));
|
||||
|
||||
jest.mock('~/models', () => ({
|
||||
getCategoriesWithCounts: jest.fn(),
|
||||
}));
|
||||
jest.mock('~/models', () => {
|
||||
const mongoose = require('mongoose');
|
||||
const { createMethods } = require('@librechat/data-schemas');
|
||||
const methods = createMethods(mongoose, {
|
||||
removeAllPermissions: jest.fn().mockResolvedValue(undefined),
|
||||
});
|
||||
return {
|
||||
...methods,
|
||||
getCategoriesWithCounts: jest.fn(),
|
||||
deleteFileByFilter: jest.fn(),
|
||||
};
|
||||
});
|
||||
|
||||
// Mock cache for S3 avatar refresh tests
|
||||
const mockCache = {
|
||||
|
|
@ -74,9 +79,10 @@ const {
|
|||
const {
|
||||
findAccessibleResources,
|
||||
findPubliclyAccessibleResources,
|
||||
getResourcePermissionsMap,
|
||||
} = require('~/server/services/PermissionService');
|
||||
|
||||
const { refreshS3Url } = require('~/server/services/Files/S3/crud');
|
||||
const { refreshS3Url } = require('@librechat/api');
|
||||
|
||||
/**
|
||||
* @type {import('mongoose').Model<import('@librechat/data-schemas').IAgent>}
|
||||
|
|
@ -175,7 +181,6 @@ describe('Agent Controllers - Mass Assignment Protection', () => {
|
|||
// Unauthorized fields that should be stripped
|
||||
author: new mongoose.Types.ObjectId().toString(), // Should not be able to set author
|
||||
authorName: 'Hacker', // Should be stripped
|
||||
isCollaborative: true, // Should be stripped on creation
|
||||
versions: [], // Should be stripped
|
||||
_id: new mongoose.Types.ObjectId(), // Should be stripped
|
||||
id: 'custom_agent_id', // Should be overridden
|
||||
|
|
@ -194,7 +199,6 @@ describe('Agent Controllers - Mass Assignment Protection', () => {
|
|||
// Verify unauthorized fields were not set
|
||||
expect(createdAgent.author.toString()).toBe(mockReq.user.id); // Should be the request user, not the malicious value
|
||||
expect(createdAgent.authorName).toBeUndefined();
|
||||
expect(createdAgent.isCollaborative).toBeFalsy();
|
||||
expect(createdAgent.versions).toHaveLength(1); // Should have exactly 1 version from creation
|
||||
expect(createdAgent.id).not.toBe('custom_agent_id'); // Should have generated ID
|
||||
expect(createdAgent.id).toMatch(/^agent_/); // Should have proper prefix
|
||||
|
|
@ -445,7 +449,6 @@ describe('Agent Controllers - Mass Assignment Protection', () => {
|
|||
model: 'gpt-3.5-turbo',
|
||||
author: existingAgentAuthorId,
|
||||
description: 'Original description',
|
||||
isCollaborative: false,
|
||||
versions: [
|
||||
{
|
||||
name: 'Original Agent',
|
||||
|
|
@ -467,7 +470,6 @@ describe('Agent Controllers - Mass Assignment Protection', () => {
|
|||
name: 'Updated Agent',
|
||||
description: 'Updated description',
|
||||
model: 'gpt-4',
|
||||
isCollaborative: true, // This IS allowed in updates
|
||||
};
|
||||
|
||||
await updateAgentHandler(mockReq, mockRes);
|
||||
|
|
@ -480,13 +482,11 @@ describe('Agent Controllers - Mass Assignment Protection', () => {
|
|||
expect(updatedAgent.name).toBe('Updated Agent');
|
||||
expect(updatedAgent.description).toBe('Updated description');
|
||||
expect(updatedAgent.model).toBe('gpt-4');
|
||||
expect(updatedAgent.isCollaborative).toBe(true);
|
||||
expect(updatedAgent.author).toBe(existingAgentAuthorId.toString());
|
||||
|
||||
// Verify in database
|
||||
const agentInDb = await Agent.findOne({ id: existingAgentId });
|
||||
expect(agentInDb.name).toBe('Updated Agent');
|
||||
expect(agentInDb.isCollaborative).toBe(true);
|
||||
});
|
||||
|
||||
test('should reject update with unauthorized fields (mass assignment protection)', async () => {
|
||||
|
|
@ -541,26 +541,6 @@ describe('Agent Controllers - Mass Assignment Protection', () => {
|
|||
expect(updatedAgent.name).toBe('Admin Update');
|
||||
});
|
||||
|
||||
test('should handle projectIds updates', async () => {
|
||||
mockReq.user.id = existingAgentAuthorId.toString();
|
||||
mockReq.params.id = existingAgentId;
|
||||
|
||||
const projectId1 = new mongoose.Types.ObjectId().toString();
|
||||
const projectId2 = new mongoose.Types.ObjectId().toString();
|
||||
|
||||
mockReq.body = {
|
||||
projectIds: [projectId1, projectId2],
|
||||
};
|
||||
|
||||
await updateAgentHandler(mockReq, mockRes);
|
||||
|
||||
expect(mockRes.json).toHaveBeenCalled();
|
||||
|
||||
const updatedAgent = mockRes.json.mock.calls[0][0];
|
||||
expect(updatedAgent).toBeDefined();
|
||||
// Note: updateAgentProjects requires more setup, so we just verify the handler doesn't crash
|
||||
});
|
||||
|
||||
test('should validate tool_resources in updates', async () => {
|
||||
mockReq.user.id = existingAgentAuthorId.toString();
|
||||
mockReq.params.id = existingAgentId;
|
||||
|
|
@ -1647,4 +1627,112 @@ describe('Agent Controllers - Mass Assignment Protection', () => {
|
|||
expect(agent.avatar.filepath).toBe('old-s3-path.jpg');
|
||||
});
|
||||
});
|
||||
|
||||
describe('Edge ACL validation', () => {
|
||||
let targetAgent;
|
||||
|
||||
beforeEach(async () => {
|
||||
targetAgent = await Agent.create({
|
||||
id: `agent_${nanoid()}`,
|
||||
author: new mongoose.Types.ObjectId().toString(),
|
||||
name: 'Target Agent',
|
||||
provider: 'openai',
|
||||
model: 'gpt-4',
|
||||
tools: [],
|
||||
});
|
||||
});
|
||||
|
||||
test('createAgentHandler should return 403 when user lacks VIEW on an edge-referenced agent', async () => {
|
||||
const permMap = new Map();
|
||||
getResourcePermissionsMap.mockResolvedValueOnce(permMap);
|
||||
|
||||
mockReq.body = {
|
||||
name: 'Attacker Agent',
|
||||
provider: 'openai',
|
||||
model: 'gpt-4',
|
||||
edges: [{ from: 'self_placeholder', to: targetAgent.id, edgeType: 'handoff' }],
|
||||
};
|
||||
|
||||
await createAgentHandler(mockReq, mockRes);
|
||||
|
||||
expect(mockRes.status).toHaveBeenCalledWith(403);
|
||||
const response = mockRes.json.mock.calls[0][0];
|
||||
expect(response.agent_ids).toContain(targetAgent.id);
|
||||
});
|
||||
|
||||
test('createAgentHandler should succeed when user has VIEW on all edge-referenced agents', async () => {
|
||||
const permMap = new Map([[targetAgent._id.toString(), 1]]);
|
||||
getResourcePermissionsMap.mockResolvedValueOnce(permMap);
|
||||
|
||||
mockReq.body = {
|
||||
name: 'Legit Agent',
|
||||
provider: 'openai',
|
||||
model: 'gpt-4',
|
||||
edges: [{ from: 'self_placeholder', to: targetAgent.id, edgeType: 'handoff' }],
|
||||
};
|
||||
|
||||
await createAgentHandler(mockReq, mockRes);
|
||||
|
||||
expect(mockRes.status).toHaveBeenCalledWith(201);
|
||||
});
|
||||
|
||||
test('createAgentHandler should allow edges referencing non-existent agents (self-reference at create time)', async () => {
|
||||
mockReq.body = {
|
||||
name: 'Self-Ref Agent',
|
||||
provider: 'openai',
|
||||
model: 'gpt-4',
|
||||
edges: [{ from: 'agent_does_not_exist_yet', to: 'agent_also_new', edgeType: 'handoff' }],
|
||||
};
|
||||
|
||||
await createAgentHandler(mockReq, mockRes);
|
||||
|
||||
expect(mockRes.status).toHaveBeenCalledWith(201);
|
||||
});
|
||||
|
||||
test('updateAgentHandler should return 403 when user lacks VIEW on an edge-referenced agent', async () => {
|
||||
const ownedAgent = await Agent.create({
|
||||
id: `agent_${nanoid()}`,
|
||||
author: mockReq.user.id,
|
||||
name: 'Owned Agent',
|
||||
provider: 'openai',
|
||||
model: 'gpt-4',
|
||||
tools: [],
|
||||
});
|
||||
|
||||
const permMap = new Map([[ownedAgent._id.toString(), PermissionBits.VIEW]]);
|
||||
getResourcePermissionsMap.mockResolvedValueOnce(permMap);
|
||||
|
||||
mockReq.params = { id: ownedAgent.id };
|
||||
mockReq.body = {
|
||||
edges: [{ from: ownedAgent.id, to: targetAgent.id, edgeType: 'handoff' }],
|
||||
};
|
||||
|
||||
await updateAgentHandler(mockReq, mockRes);
|
||||
|
||||
expect(mockRes.status).toHaveBeenCalledWith(403);
|
||||
const response = mockRes.json.mock.calls[0][0];
|
||||
expect(response.agent_ids).toContain(targetAgent.id);
|
||||
expect(response.agent_ids).not.toContain(ownedAgent.id);
|
||||
});
|
||||
|
||||
test('updateAgentHandler should succeed when edges field is absent from payload', async () => {
|
||||
const ownedAgent = await Agent.create({
|
||||
id: `agent_${nanoid()}`,
|
||||
author: mockReq.user.id,
|
||||
name: 'Owned Agent',
|
||||
provider: 'openai',
|
||||
model: 'gpt-4',
|
||||
tools: [],
|
||||
});
|
||||
|
||||
mockReq.params = { id: ownedAgent.id };
|
||||
mockReq.body = { name: 'Renamed Agent' };
|
||||
|
||||
await updateAgentHandler(mockReq, mockRes);
|
||||
|
||||
expect(mockRes.status).not.toHaveBeenCalledWith(403);
|
||||
const response = mockRes.json.mock.calls[0][0];
|
||||
expect(response.name).toBe('Renamed Agent');
|
||||
});
|
||||
});
|
||||
});
|
||||
|
|
|
|||
|
|
@ -1,7 +1,13 @@
|
|||
const { v4 } = require('uuid');
|
||||
const { sleep } = require('@librechat/agents');
|
||||
const { logger } = require('@librechat/data-schemas');
|
||||
const { sendEvent, getBalanceConfig, getModelMaxTokens, countTokens } = require('@librechat/api');
|
||||
const {
|
||||
sendEvent,
|
||||
countTokens,
|
||||
checkBalance,
|
||||
getBalanceConfig,
|
||||
getModelMaxTokens,
|
||||
} = require('@librechat/api');
|
||||
const {
|
||||
Time,
|
||||
Constants,
|
||||
|
|
@ -31,10 +37,14 @@ const { createRun, StreamRunManager } = require('~/server/services/Runs');
|
|||
const { addTitle } = require('~/server/services/Endpoints/assistants');
|
||||
const { createRunBody } = require('~/server/services/createRunBody');
|
||||
const { sendResponse } = require('~/server/middleware/error');
|
||||
const { getTransactions } = require('~/models/Transaction');
|
||||
const { checkBalance } = require('~/models/balanceMethods');
|
||||
const { getConvo } = require('~/models/Conversation');
|
||||
const getLogStores = require('~/cache/getLogStores');
|
||||
const {
|
||||
createAutoRefillTransaction,
|
||||
findBalanceByUser,
|
||||
getTransactions,
|
||||
getMultiplier,
|
||||
getConvo,
|
||||
} = require('~/models');
|
||||
const { logViolation, getLogStores } = require('~/cache');
|
||||
const { getOpenAIClient } = require('./helpers');
|
||||
|
||||
/**
|
||||
|
|
@ -275,16 +285,19 @@ const chatV1 = async (req, res) => {
|
|||
// Count tokens up to the current context window
|
||||
promptTokens = Math.min(promptTokens, getModelMaxTokens(model));
|
||||
|
||||
await checkBalance({
|
||||
req,
|
||||
res,
|
||||
txData: {
|
||||
model,
|
||||
user: req.user.id,
|
||||
tokenType: 'prompt',
|
||||
amount: promptTokens,
|
||||
await checkBalance(
|
||||
{
|
||||
req,
|
||||
res,
|
||||
txData: {
|
||||
model,
|
||||
user: req.user.id,
|
||||
tokenType: 'prompt',
|
||||
amount: promptTokens,
|
||||
},
|
||||
},
|
||||
});
|
||||
{ findBalanceByUser, getMultiplier, createAutoRefillTransaction, logViolation },
|
||||
);
|
||||
};
|
||||
|
||||
const { openai: _openai } = await getOpenAIClient({
|
||||
|
|
|
|||
|
|
@ -1,7 +1,13 @@
|
|||
const { v4 } = require('uuid');
|
||||
const { sleep } = require('@librechat/agents');
|
||||
const { logger } = require('@librechat/data-schemas');
|
||||
const { sendEvent, getBalanceConfig, getModelMaxTokens, countTokens } = require('@librechat/api');
|
||||
const {
|
||||
sendEvent,
|
||||
countTokens,
|
||||
checkBalance,
|
||||
getBalanceConfig,
|
||||
getModelMaxTokens,
|
||||
} = require('@librechat/api');
|
||||
const {
|
||||
Time,
|
||||
Constants,
|
||||
|
|
@ -26,10 +32,14 @@ const validateAuthor = require('~/server/middleware/assistants/validateAuthor');
|
|||
const { createRun, StreamRunManager } = require('~/server/services/Runs');
|
||||
const { addTitle } = require('~/server/services/Endpoints/assistants');
|
||||
const { createRunBody } = require('~/server/services/createRunBody');
|
||||
const { getTransactions } = require('~/models/Transaction');
|
||||
const { checkBalance } = require('~/models/balanceMethods');
|
||||
const { getConvo } = require('~/models/Conversation');
|
||||
const getLogStores = require('~/cache/getLogStores');
|
||||
const {
|
||||
getConvo,
|
||||
getMultiplier,
|
||||
getTransactions,
|
||||
findBalanceByUser,
|
||||
createAutoRefillTransaction,
|
||||
} = require('~/models');
|
||||
const { logViolation, getLogStores } = require('~/cache');
|
||||
const { getOpenAIClient } = require('./helpers');
|
||||
|
||||
/**
|
||||
|
|
@ -148,16 +158,19 @@ const chatV2 = async (req, res) => {
|
|||
// Count tokens up to the current context window
|
||||
promptTokens = Math.min(promptTokens, getModelMaxTokens(model));
|
||||
|
||||
await checkBalance({
|
||||
req,
|
||||
res,
|
||||
txData: {
|
||||
model,
|
||||
user: req.user.id,
|
||||
tokenType: 'prompt',
|
||||
amount: promptTokens,
|
||||
await checkBalance(
|
||||
{
|
||||
req,
|
||||
res,
|
||||
txData: {
|
||||
model,
|
||||
user: req.user.id,
|
||||
tokenType: 'prompt',
|
||||
amount: promptTokens,
|
||||
},
|
||||
},
|
||||
});
|
||||
{ findBalanceByUser, getMultiplier, createAutoRefillTransaction, logViolation },
|
||||
);
|
||||
};
|
||||
|
||||
const { openai: _openai } = await getOpenAIClient({
|
||||
|
|
|
|||
|
|
@ -3,8 +3,8 @@ const { logger } = require('@librechat/data-schemas');
|
|||
const { CacheKeys, ViolationTypes, ContentTypes } = require('librechat-data-provider');
|
||||
const { recordUsage, checkMessageGaps } = require('~/server/services/Threads');
|
||||
const { sendResponse } = require('~/server/middleware/error');
|
||||
const { getConvo } = require('~/models/Conversation');
|
||||
const getLogStores = require('~/cache/getLogStores');
|
||||
const { getConvo } = require('~/models');
|
||||
|
||||
/**
|
||||
* @typedef {Object} ErrorHandlerContext
|
||||
|
|
|
|||
|
|
@ -1,13 +1,14 @@
|
|||
const {
|
||||
SystemRoles,
|
||||
EModelEndpoint,
|
||||
defaultOrderQuery,
|
||||
defaultAssistantsVersion,
|
||||
} = require('librechat-data-provider');
|
||||
const { logger, SystemCapabilities } = require('@librechat/data-schemas');
|
||||
const {
|
||||
initializeClient: initAzureClient,
|
||||
} = require('~/server/services/Endpoints/azureAssistants');
|
||||
const { initializeClient } = require('~/server/services/Endpoints/assistants');
|
||||
const { hasCapability } = require('~/server/middleware/roles/capabilities');
|
||||
const { getEndpointsConfig } = require('~/server/services/Config');
|
||||
|
||||
/**
|
||||
|
|
@ -236,9 +237,19 @@ const fetchAssistants = async ({ req, res, overrideEndpoint }) => {
|
|||
body = await listAssistantsForAzure({ req, res, version, azureConfig, query });
|
||||
}
|
||||
|
||||
if (req.user.role === SystemRoles.ADMIN) {
|
||||
if (!appConfig.endpoints?.[endpoint]) {
|
||||
return body;
|
||||
} else if (!appConfig.endpoints?.[endpoint]) {
|
||||
}
|
||||
|
||||
let canManageAssistants = false;
|
||||
try {
|
||||
canManageAssistants = await hasCapability(req.user, SystemCapabilities.MANAGE_ASSISTANTS);
|
||||
} catch (err) {
|
||||
logger.warn(`[fetchAssistants] capability check failed, denying bypass: ${err.message}`);
|
||||
}
|
||||
|
||||
if (canManageAssistants) {
|
||||
logger.debug(`[fetchAssistants] MANAGE_ASSISTANTS bypass for user ${req.user.id}`);
|
||||
return body;
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -1,15 +1,14 @@
|
|||
const fs = require('fs').promises;
|
||||
const { logger } = require('@librechat/data-schemas');
|
||||
const { FileContext } = require('librechat-data-provider');
|
||||
const { deleteFileByFilter, updateAssistantDoc, getAssistants } = require('~/models');
|
||||
const { uploadImageBuffer, filterFile } = require('~/server/services/Files/process');
|
||||
const validateAuthor = require('~/server/middleware/assistants/validateAuthor');
|
||||
const { getStrategyFunctions } = require('~/server/services/Files/strategies');
|
||||
const { deleteAssistantActions } = require('~/server/services/ActionService');
|
||||
const { updateAssistantDoc, getAssistants } = require('~/models/Assistant');
|
||||
const { getOpenAIClient, fetchAssistants } = require('./helpers');
|
||||
const { getCachedTools } = require('~/server/services/Config');
|
||||
const { manifestToolMap } = require('~/app/clients/tools');
|
||||
const { deleteFileByFilter } = require('~/models');
|
||||
|
||||
/**
|
||||
* Create an assistant.
|
||||
|
|
|
|||
|
|
@ -3,8 +3,8 @@ const { ToolCallTypes } = require('librechat-data-provider');
|
|||
const validateAuthor = require('~/server/middleware/assistants/validateAuthor');
|
||||
const { validateAndUpdateTool } = require('~/server/services/ActionService');
|
||||
const { getCachedTools } = require('~/server/services/Config');
|
||||
const { updateAssistantDoc } = require('~/models/Assistant');
|
||||
const { manifestToolMap } = require('~/app/clients/tools');
|
||||
const { updateAssistantDoc } = require('~/models');
|
||||
const { getOpenAIClient } = require('./helpers');
|
||||
|
||||
/**
|
||||
|
|
|
|||
|
|
@ -4,11 +4,27 @@ const { logger } = require('@librechat/data-schemas');
|
|||
const { logoutUser } = require('~/server/services/AuthService');
|
||||
const { getOpenIdConfig } = require('~/strategies');
|
||||
|
||||
/** Parses and validates OPENID_MAX_LOGOUT_URL_LENGTH, returning defaultValue on invalid input */
|
||||
function parseMaxLogoutUrlLength(defaultValue = 2000) {
|
||||
const raw = process.env.OPENID_MAX_LOGOUT_URL_LENGTH;
|
||||
const trimmed = raw == null ? '' : raw.trim();
|
||||
if (trimmed === '') {
|
||||
return defaultValue;
|
||||
}
|
||||
const parsed = /^\d+$/.test(trimmed) ? Number(trimmed) : NaN;
|
||||
if (!Number.isFinite(parsed) || parsed <= 0) {
|
||||
logger.warn(
|
||||
`[logoutController] Invalid OPENID_MAX_LOGOUT_URL_LENGTH value "${raw}", using default ${defaultValue}`,
|
||||
);
|
||||
return defaultValue;
|
||||
}
|
||||
return parsed;
|
||||
}
|
||||
|
||||
const logoutController = async (req, res) => {
|
||||
const parsedCookies = req.headers.cookie ? cookies.parse(req.headers.cookie) : {};
|
||||
const isOpenIdUser = req.user?.openidId != null && req.user?.provider === 'openid';
|
||||
|
||||
/** For OpenID users, read tokens from session (with cookie fallback) */
|
||||
let refreshToken;
|
||||
let idToken;
|
||||
if (isOpenIdUser && req.session?.openidTokens) {
|
||||
|
|
@ -44,22 +60,64 @@ const logoutController = async (req, res) => {
|
|||
const endSessionEndpoint = openIdConfig.serverMetadata().end_session_endpoint;
|
||||
if (endSessionEndpoint) {
|
||||
const endSessionUrl = new URL(endSessionEndpoint);
|
||||
/** Redirect back to app's login page after IdP logout */
|
||||
const postLogoutRedirectUri =
|
||||
process.env.OPENID_POST_LOGOUT_REDIRECT_URI || `${process.env.DOMAIN_CLIENT}/login`;
|
||||
endSessionUrl.searchParams.set('post_logout_redirect_uri', postLogoutRedirectUri);
|
||||
|
||||
/** Add id_token_hint (preferred) or client_id for OIDC spec compliance */
|
||||
/**
|
||||
* OIDC RP-Initiated Logout cascading strategy:
|
||||
* 1. id_token_hint (most secure, identifies exact session)
|
||||
* 2. logout_hint + client_id (when URL would exceed safe length)
|
||||
* 3. client_id only (when no token available)
|
||||
*
|
||||
* JWT tokens from spec-compliant OIDC providers use base64url
|
||||
* encoding (RFC 7515), whose characters are all URL-safe, so
|
||||
* token length equals URL-encoded length for projection.
|
||||
* Non-compliant issuers using standard base64 (+/=) will cause
|
||||
* underestimation; increase OPENID_MAX_LOGOUT_URL_LENGTH if the
|
||||
* fallback does not trigger as expected.
|
||||
*/
|
||||
const maxLogoutUrlLength = parseMaxLogoutUrlLength();
|
||||
let strategy = 'no_token';
|
||||
if (idToken) {
|
||||
const baseLength = endSessionUrl.toString().length;
|
||||
const projectedLength = baseLength + '&id_token_hint='.length + idToken.length;
|
||||
if (projectedLength > maxLogoutUrlLength) {
|
||||
strategy = 'too_long';
|
||||
logger.debug(
|
||||
`[logoutController] Logout URL too long (${projectedLength} chars, max ${maxLogoutUrlLength}), ` +
|
||||
'switching to logout_hint strategy',
|
||||
);
|
||||
} else {
|
||||
strategy = 'use_token';
|
||||
}
|
||||
}
|
||||
|
||||
if (strategy === 'use_token') {
|
||||
endSessionUrl.searchParams.set('id_token_hint', idToken);
|
||||
} else if (process.env.OPENID_CLIENT_ID) {
|
||||
endSessionUrl.searchParams.set('client_id', process.env.OPENID_CLIENT_ID);
|
||||
} else {
|
||||
logger.warn(
|
||||
'[logoutController] Neither id_token_hint nor OPENID_CLIENT_ID is available. ' +
|
||||
'To enable id_token_hint, set OPENID_REUSE_TOKENS=true. ' +
|
||||
'The OIDC end-session request may be rejected by the identity provider.',
|
||||
);
|
||||
if (strategy === 'too_long') {
|
||||
const logoutHint = req.user?.email || req.user?.username || req.user?.openidId;
|
||||
if (logoutHint) {
|
||||
endSessionUrl.searchParams.set('logout_hint', logoutHint);
|
||||
}
|
||||
}
|
||||
|
||||
if (process.env.OPENID_CLIENT_ID) {
|
||||
endSessionUrl.searchParams.set('client_id', process.env.OPENID_CLIENT_ID);
|
||||
} else if (strategy === 'too_long') {
|
||||
logger.warn(
|
||||
'[logoutController] Logout URL exceeds max length and OPENID_CLIENT_ID is not set. ' +
|
||||
'The OIDC end-session request may be rejected. ' +
|
||||
'Consider setting OPENID_CLIENT_ID or increasing OPENID_MAX_LOGOUT_URL_LENGTH.',
|
||||
);
|
||||
} else {
|
||||
logger.warn(
|
||||
'[logoutController] Neither id_token_hint nor OPENID_CLIENT_ID is available. ' +
|
||||
'To enable id_token_hint, set OPENID_REUSE_TOKENS=true. ' +
|
||||
'The OIDC end-session request may be rejected by the identity provider.',
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
response.redirect = endSessionUrl.toString();
|
||||
|
|
|
|||
|
|
@ -1,7 +1,7 @@
|
|||
const cookies = require('cookie');
|
||||
|
||||
const mockLogoutUser = jest.fn();
|
||||
const mockLogger = { warn: jest.fn(), error: jest.fn() };
|
||||
const mockLogger = { warn: jest.fn(), error: jest.fn(), debug: jest.fn() };
|
||||
const mockIsEnabled = jest.fn();
|
||||
const mockGetOpenIdConfig = jest.fn();
|
||||
|
||||
|
|
@ -256,4 +256,312 @@ describe('LogoutController', () => {
|
|||
expect(res.clearCookie).toHaveBeenCalledWith('token_provider');
|
||||
});
|
||||
});
|
||||
|
||||
describe('URL length limit and logout_hint fallback', () => {
|
||||
it('uses logout_hint when id_token makes URL exceed default limit (2000 chars)', async () => {
|
||||
const longIdToken = 'a'.repeat(3000);
|
||||
const req = buildReq({
|
||||
user: { _id: 'user1', openidId: 'oid1', provider: 'openid', email: 'user@example.com' },
|
||||
session: {
|
||||
openidTokens: { refreshToken: 'srt', idToken: longIdToken },
|
||||
destroy: jest.fn(),
|
||||
},
|
||||
});
|
||||
const res = buildRes();
|
||||
|
||||
await logoutController(req, res);
|
||||
|
||||
const body = res.send.mock.calls[0][0];
|
||||
expect(body.redirect).not.toContain('id_token_hint=');
|
||||
expect(body.redirect).toContain('logout_hint=user%40example.com');
|
||||
expect(body.redirect).toContain('client_id=my-client-id');
|
||||
expect(mockLogger.debug).toHaveBeenCalledWith(expect.stringContaining('Logout URL too long'));
|
||||
});
|
||||
|
||||
it('uses id_token_hint when URL is within default limit', async () => {
|
||||
const shortIdToken = 'short-token';
|
||||
const req = buildReq({
|
||||
session: {
|
||||
openidTokens: { refreshToken: 'srt', idToken: shortIdToken },
|
||||
destroy: jest.fn(),
|
||||
},
|
||||
});
|
||||
const res = buildRes();
|
||||
|
||||
await logoutController(req, res);
|
||||
|
||||
const body = res.send.mock.calls[0][0];
|
||||
expect(body.redirect).toContain('id_token_hint=short-token');
|
||||
expect(body.redirect).not.toContain('logout_hint=');
|
||||
expect(body.redirect).not.toContain('client_id=');
|
||||
});
|
||||
|
||||
it('respects custom OPENID_MAX_LOGOUT_URL_LENGTH', async () => {
|
||||
process.env.OPENID_MAX_LOGOUT_URL_LENGTH = '500';
|
||||
const mediumIdToken = 'a'.repeat(600);
|
||||
const req = buildReq({
|
||||
user: { _id: 'user1', openidId: 'oid1', provider: 'openid', email: 'user@example.com' },
|
||||
session: {
|
||||
openidTokens: { refreshToken: 'srt', idToken: mediumIdToken },
|
||||
destroy: jest.fn(),
|
||||
},
|
||||
});
|
||||
const res = buildRes();
|
||||
|
||||
await logoutController(req, res);
|
||||
|
||||
const body = res.send.mock.calls[0][0];
|
||||
expect(body.redirect).not.toContain('id_token_hint=');
|
||||
expect(body.redirect).toContain('logout_hint=user%40example.com');
|
||||
});
|
||||
|
||||
it('uses username as logout_hint when email is not available', async () => {
|
||||
const longIdToken = 'a'.repeat(3000);
|
||||
const req = buildReq({
|
||||
user: {
|
||||
_id: 'user1',
|
||||
openidId: 'oid1',
|
||||
provider: 'openid',
|
||||
username: 'testuser',
|
||||
},
|
||||
session: {
|
||||
openidTokens: { refreshToken: 'srt', idToken: longIdToken },
|
||||
destroy: jest.fn(),
|
||||
},
|
||||
});
|
||||
const res = buildRes();
|
||||
|
||||
await logoutController(req, res);
|
||||
|
||||
const body = res.send.mock.calls[0][0];
|
||||
expect(body.redirect).toContain('logout_hint=testuser');
|
||||
});
|
||||
|
||||
it('uses openidId as logout_hint when email and username are not available', async () => {
|
||||
const longIdToken = 'a'.repeat(3000);
|
||||
const req = buildReq({
|
||||
user: { _id: 'user1', openidId: 'unique-oid-123', provider: 'openid' },
|
||||
session: {
|
||||
openidTokens: { refreshToken: 'srt', idToken: longIdToken },
|
||||
destroy: jest.fn(),
|
||||
},
|
||||
});
|
||||
const res = buildRes();
|
||||
|
||||
await logoutController(req, res);
|
||||
|
||||
const body = res.send.mock.calls[0][0];
|
||||
expect(body.redirect).toContain('logout_hint=unique-oid-123');
|
||||
});
|
||||
|
||||
it('uses openidId as logout_hint when email and username are explicitly null', async () => {
|
||||
const longIdToken = 'a'.repeat(3000);
|
||||
const req = buildReq({
|
||||
user: {
|
||||
_id: 'user1',
|
||||
openidId: 'oid-without-email',
|
||||
provider: 'openid',
|
||||
email: null,
|
||||
username: null,
|
||||
},
|
||||
session: {
|
||||
openidTokens: { refreshToken: 'srt', idToken: longIdToken },
|
||||
destroy: jest.fn(),
|
||||
},
|
||||
});
|
||||
const res = buildRes();
|
||||
|
||||
await logoutController(req, res);
|
||||
|
||||
const body = res.send.mock.calls[0][0];
|
||||
expect(body.redirect).not.toContain('id_token_hint=');
|
||||
expect(body.redirect).toContain('logout_hint=oid-without-email');
|
||||
expect(body.redirect).toContain('client_id=my-client-id');
|
||||
});
|
||||
|
||||
it('uses only client_id when absolutely no hint is available', async () => {
|
||||
const longIdToken = 'a'.repeat(3000);
|
||||
const req = buildReq({
|
||||
user: {
|
||||
_id: 'user1',
|
||||
openidId: '',
|
||||
provider: 'openid',
|
||||
email: '',
|
||||
username: '',
|
||||
},
|
||||
session: {
|
||||
openidTokens: { refreshToken: 'srt', idToken: longIdToken },
|
||||
destroy: jest.fn(),
|
||||
},
|
||||
});
|
||||
const res = buildRes();
|
||||
|
||||
await logoutController(req, res);
|
||||
|
||||
const body = res.send.mock.calls[0][0];
|
||||
expect(body.redirect).not.toContain('id_token_hint=');
|
||||
expect(body.redirect).not.toContain('logout_hint=');
|
||||
expect(body.redirect).toContain('client_id=my-client-id');
|
||||
});
|
||||
|
||||
it('warns about missing OPENID_CLIENT_ID when URL is too long', async () => {
|
||||
delete process.env.OPENID_CLIENT_ID;
|
||||
const longIdToken = 'a'.repeat(3000);
|
||||
const req = buildReq({
|
||||
user: { _id: 'user1', openidId: 'oid1', provider: 'openid', email: 'user@example.com' },
|
||||
session: {
|
||||
openidTokens: { refreshToken: 'srt', idToken: longIdToken },
|
||||
destroy: jest.fn(),
|
||||
},
|
||||
});
|
||||
const res = buildRes();
|
||||
|
||||
await logoutController(req, res);
|
||||
|
||||
const body = res.send.mock.calls[0][0];
|
||||
expect(body.redirect).not.toContain('id_token_hint=');
|
||||
expect(body.redirect).toContain('logout_hint=');
|
||||
expect(body.redirect).not.toContain('client_id=');
|
||||
expect(mockLogger.warn).toHaveBeenCalledWith(
|
||||
expect.stringContaining('OPENID_CLIENT_ID is not set'),
|
||||
);
|
||||
});
|
||||
|
||||
it('falls back to logout_hint for cookie-sourced long token', async () => {
|
||||
const longCookieToken = 'a'.repeat(3000);
|
||||
cookies.parse.mockReturnValue({
|
||||
refreshToken: 'cookie-rt',
|
||||
openid_id_token: longCookieToken,
|
||||
});
|
||||
const req = buildReq({
|
||||
user: { _id: 'user1', openidId: 'oid1', provider: 'openid', email: 'user@example.com' },
|
||||
session: { destroy: jest.fn() },
|
||||
});
|
||||
const res = buildRes();
|
||||
|
||||
await logoutController(req, res);
|
||||
|
||||
const body = res.send.mock.calls[0][0];
|
||||
expect(body.redirect).not.toContain('id_token_hint=');
|
||||
expect(body.redirect).toContain('logout_hint=user%40example.com');
|
||||
expect(body.redirect).toContain('client_id=my-client-id');
|
||||
});
|
||||
|
||||
it('keeps id_token_hint when projected URL length equals the max', async () => {
|
||||
const baseUrl = new URL('https://idp.example.com/logout');
|
||||
baseUrl.searchParams.set('post_logout_redirect_uri', 'https://app.example.com/login');
|
||||
const baseLength = baseUrl.toString().length;
|
||||
const tokenLength = 2000 - baseLength - '&id_token_hint='.length;
|
||||
const exactToken = 'a'.repeat(tokenLength);
|
||||
|
||||
const req = buildReq({
|
||||
session: {
|
||||
openidTokens: { refreshToken: 'srt', idToken: exactToken },
|
||||
destroy: jest.fn(),
|
||||
},
|
||||
});
|
||||
const res = buildRes();
|
||||
|
||||
await logoutController(req, res);
|
||||
|
||||
const body = res.send.mock.calls[0][0];
|
||||
expect(body.redirect).toContain('id_token_hint=');
|
||||
expect(body.redirect).not.toContain('logout_hint=');
|
||||
});
|
||||
|
||||
it('falls back to logout_hint when projected URL is one char over the max', async () => {
|
||||
const baseUrl = new URL('https://idp.example.com/logout');
|
||||
baseUrl.searchParams.set('post_logout_redirect_uri', 'https://app.example.com/login');
|
||||
const baseLength = baseUrl.toString().length;
|
||||
const tokenLength = 2000 - baseLength - '&id_token_hint='.length + 1;
|
||||
const overToken = 'a'.repeat(tokenLength);
|
||||
|
||||
const req = buildReq({
|
||||
user: { _id: 'user1', openidId: 'oid1', provider: 'openid', email: 'user@example.com' },
|
||||
session: {
|
||||
openidTokens: { refreshToken: 'srt', idToken: overToken },
|
||||
destroy: jest.fn(),
|
||||
},
|
||||
});
|
||||
const res = buildRes();
|
||||
|
||||
await logoutController(req, res);
|
||||
|
||||
const body = res.send.mock.calls[0][0];
|
||||
expect(body.redirect).not.toContain('id_token_hint=');
|
||||
expect(body.redirect).toContain('logout_hint=');
|
||||
});
|
||||
});
|
||||
|
||||
describe('invalid OPENID_MAX_LOGOUT_URL_LENGTH values', () => {
|
||||
it('silently uses default when value is empty', async () => {
|
||||
process.env.OPENID_MAX_LOGOUT_URL_LENGTH = '';
|
||||
const req = buildReq();
|
||||
const res = buildRes();
|
||||
|
||||
await logoutController(req, res);
|
||||
|
||||
expect(mockLogger.warn).not.toHaveBeenCalledWith(
|
||||
expect.stringContaining('Invalid OPENID_MAX_LOGOUT_URL_LENGTH'),
|
||||
);
|
||||
const body = res.send.mock.calls[0][0];
|
||||
expect(body.redirect).toContain('id_token_hint=small-id-token');
|
||||
});
|
||||
|
||||
it('warns and uses default for partial numeric string', async () => {
|
||||
process.env.OPENID_MAX_LOGOUT_URL_LENGTH = '500abc';
|
||||
const req = buildReq();
|
||||
const res = buildRes();
|
||||
|
||||
await logoutController(req, res);
|
||||
|
||||
expect(mockLogger.warn).toHaveBeenCalledWith(
|
||||
expect.stringContaining('Invalid OPENID_MAX_LOGOUT_URL_LENGTH'),
|
||||
);
|
||||
const body = res.send.mock.calls[0][0];
|
||||
expect(body.redirect).toContain('id_token_hint=small-id-token');
|
||||
});
|
||||
|
||||
it('warns and uses default for zero value', async () => {
|
||||
process.env.OPENID_MAX_LOGOUT_URL_LENGTH = '0';
|
||||
const req = buildReq();
|
||||
const res = buildRes();
|
||||
|
||||
await logoutController(req, res);
|
||||
|
||||
expect(mockLogger.warn).toHaveBeenCalledWith(
|
||||
expect.stringContaining('Invalid OPENID_MAX_LOGOUT_URL_LENGTH'),
|
||||
);
|
||||
const body = res.send.mock.calls[0][0];
|
||||
expect(body.redirect).toContain('id_token_hint=small-id-token');
|
||||
});
|
||||
|
||||
it('warns and uses default for negative value', async () => {
|
||||
process.env.OPENID_MAX_LOGOUT_URL_LENGTH = '-1';
|
||||
const req = buildReq();
|
||||
const res = buildRes();
|
||||
|
||||
await logoutController(req, res);
|
||||
|
||||
expect(mockLogger.warn).toHaveBeenCalledWith(
|
||||
expect.stringContaining('Invalid OPENID_MAX_LOGOUT_URL_LENGTH'),
|
||||
);
|
||||
const body = res.send.mock.calls[0][0];
|
||||
expect(body.redirect).toContain('id_token_hint=small-id-token');
|
||||
});
|
||||
|
||||
it('warns and uses default for non-numeric string', async () => {
|
||||
process.env.OPENID_MAX_LOGOUT_URL_LENGTH = 'abc';
|
||||
const req = buildReq();
|
||||
const res = buildRes();
|
||||
|
||||
await logoutController(req, res);
|
||||
|
||||
expect(mockLogger.warn).toHaveBeenCalledWith(
|
||||
expect.stringContaining('Invalid OPENID_MAX_LOGOUT_URL_LENGTH'),
|
||||
);
|
||||
const body = res.send.mock.calls[0][0];
|
||||
expect(body.redirect).toContain('id_token_hint=small-id-token');
|
||||
});
|
||||
});
|
||||
});
|
||||
|
|
|
|||
|
|
@ -7,9 +7,11 @@
|
|||
*/
|
||||
const { logger } = require('@librechat/data-schemas');
|
||||
const {
|
||||
MCPErrorCodes,
|
||||
redactServerSecrets,
|
||||
redactAllServerSecrets,
|
||||
isMCPDomainNotAllowedError,
|
||||
isMCPInspectionFailedError,
|
||||
MCPErrorCodes,
|
||||
} = require('@librechat/api');
|
||||
const { Constants, MCPServerUserInputSchema } = require('librechat-data-provider');
|
||||
const { cacheMCPServerTools, getMCPServerTools } = require('~/server/services/Config');
|
||||
|
|
@ -181,10 +183,8 @@ const getMCPServersList = async (req, res) => {
|
|||
return res.status(401).json({ message: 'Unauthorized' });
|
||||
}
|
||||
|
||||
// 2. Get all server configs from registry (YAML + DB)
|
||||
const serverConfigs = await getMCPServersRegistry().getAllServerConfigs(userId);
|
||||
|
||||
return res.json(serverConfigs);
|
||||
return res.json(redactAllServerSecrets(serverConfigs));
|
||||
} catch (error) {
|
||||
logger.error('[getMCPServersList]', error);
|
||||
res.status(500).json({ error: error.message });
|
||||
|
|
@ -215,7 +215,7 @@ const createMCPServerController = async (req, res) => {
|
|||
);
|
||||
res.status(201).json({
|
||||
serverName: result.serverName,
|
||||
...result.config,
|
||||
...redactServerSecrets(result.config),
|
||||
});
|
||||
} catch (error) {
|
||||
logger.error('[createMCPServer]', error);
|
||||
|
|
@ -243,7 +243,7 @@ const getMCPServerById = async (req, res) => {
|
|||
return res.status(404).json({ message: 'MCP server not found' });
|
||||
}
|
||||
|
||||
res.status(200).json(parsedConfig);
|
||||
res.status(200).json(redactServerSecrets(parsedConfig));
|
||||
} catch (error) {
|
||||
logger.error('[getMCPServerById]', error);
|
||||
res.status(500).json({ message: error.message });
|
||||
|
|
@ -274,7 +274,7 @@ const updateMCPServerController = async (req, res) => {
|
|||
userId,
|
||||
);
|
||||
|
||||
res.status(200).json(parsedConfig);
|
||||
res.status(200).json(redactServerSecrets(parsedConfig));
|
||||
} catch (error) {
|
||||
logger.error('[updateMCPServer]', error);
|
||||
const mcpErrorResponse = handleMCPError(error, res);
|
||||
|
|
|
|||
|
|
@ -9,13 +9,11 @@ const {
|
|||
ToolCallTypes,
|
||||
PermissionTypes,
|
||||
} = require('librechat-data-provider');
|
||||
const { getRoleByName, createToolCall, getToolCallsByConvo, getMessage } = require('~/models');
|
||||
const { processFileURL, uploadImageBuffer } = require('~/server/services/Files/process');
|
||||
const { processCodeOutput } = require('~/server/services/Files/Code/process');
|
||||
const { createToolCall, getToolCallsByConvo } = require('~/models/ToolCall');
|
||||
const { loadAuthValues } = require('~/server/services/Tools/credentials');
|
||||
const { loadTools } = require('~/app/clients/tools/util');
|
||||
const { getRoleByName } = require('~/models/Role');
|
||||
const { getMessage } = require('~/models/Message');
|
||||
|
||||
const fieldsMap = {
|
||||
[Tools.execute_code]: [EnvVar.CODE_API_KEY],
|
||||
|
|
|
|||
|
|
@ -24,14 +24,14 @@ const { connectDb, indexSync } = require('~/db');
|
|||
const initializeOAuthReconnectManager = require('./services/initializeOAuthReconnectManager');
|
||||
const createValidateImageRequest = require('./middleware/validateImageRequest');
|
||||
const { jwtLogin, ldapLogin, passportLogin } = require('~/strategies');
|
||||
const { updateInterfacePermissions } = require('~/models/interface');
|
||||
const { updateInterfacePermissions: updateInterfacePerms } = require('@librechat/api');
|
||||
const { getRoleByName, updateAccessPermissions, seedDatabase } = require('~/models');
|
||||
const { checkMigrations } = require('./services/start/migration');
|
||||
const initializeMCPs = require('./services/initializeMCPs');
|
||||
const configureSocialLogins = require('./socialLogins');
|
||||
const { getAppConfig } = require('./services/Config');
|
||||
const staticCache = require('./utils/staticCache');
|
||||
const noIndex = require('./middleware/noIndex');
|
||||
const { seedDatabase } = require('~/models');
|
||||
const routes = require('./routes');
|
||||
|
||||
const { PORT, HOST, ALLOW_SOCIAL_LOGIN, DISABLE_COMPRESSION, TRUST_PROXY } = process.env ?? {};
|
||||
|
|
@ -222,7 +222,7 @@ if (cluster.isMaster) {
|
|||
const appConfig = await getAppConfig();
|
||||
initializeFileStorage(appConfig);
|
||||
await performStartupChecks(appConfig);
|
||||
await updateInterfacePermissions(appConfig);
|
||||
await updateInterfacePerms({ appConfig, getRoleByName, updateAccessPermissions });
|
||||
|
||||
/** Load index.html for SPA serving */
|
||||
const indexPath = path.join(appConfig.paths.dist, 'index.html');
|
||||
|
|
|
|||
|
|
@ -20,19 +20,20 @@ const {
|
|||
GenerationJobManager,
|
||||
createStreamServices,
|
||||
initializeFileStorage,
|
||||
updateInterfacePermissions,
|
||||
} = require('@librechat/api');
|
||||
const { connectDb, indexSync } = require('~/db');
|
||||
const initializeOAuthReconnectManager = require('./services/initializeOAuthReconnectManager');
|
||||
const { getRoleByName, updateAccessPermissions, seedDatabase } = require('~/models');
|
||||
const { capabilityContextMiddleware } = require('./middleware/roles/capabilities');
|
||||
const createValidateImageRequest = require('./middleware/validateImageRequest');
|
||||
const { jwtLogin, ldapLogin, passportLogin } = require('~/strategies');
|
||||
const { updateInterfacePermissions } = require('~/models/interface');
|
||||
const { checkMigrations } = require('./services/start/migration');
|
||||
const initializeMCPs = require('./services/initializeMCPs');
|
||||
const configureSocialLogins = require('./socialLogins');
|
||||
const { getAppConfig } = require('./services/Config');
|
||||
const staticCache = require('./utils/staticCache');
|
||||
const noIndex = require('./middleware/noIndex');
|
||||
const { seedDatabase } = require('~/models');
|
||||
const routes = require('./routes');
|
||||
|
||||
const { PORT, HOST, ALLOW_SOCIAL_LOGIN, DISABLE_COMPRESSION, TRUST_PROXY } = process.env ?? {};
|
||||
|
|
@ -62,7 +63,7 @@ const startServer = async () => {
|
|||
const appConfig = await getAppConfig();
|
||||
initializeFileStorage(appConfig);
|
||||
await performStartupChecks(appConfig);
|
||||
await updateInterfacePermissions(appConfig);
|
||||
await updateInterfacePermissions({ appConfig, getRoleByName, updateAccessPermissions });
|
||||
|
||||
const indexPath = path.join(appConfig.paths.dist, 'index.html');
|
||||
let indexHTML = fs.readFileSync(indexPath, 'utf8');
|
||||
|
|
@ -133,6 +134,9 @@ const startServer = async () => {
|
|||
await configureSocialLogins(app);
|
||||
}
|
||||
|
||||
/* Per-request capability cache — must be registered before any route that calls hasCapability */
|
||||
app.use(capabilityContextMiddleware);
|
||||
|
||||
app.use('/oauth', routes.oauth);
|
||||
/* API Endpoints */
|
||||
app.use('/api/auth', routes.auth);
|
||||
|
|
|
|||
|
|
@ -1,4 +1,5 @@
|
|||
const { logger } = require('@librechat/data-schemas');
|
||||
const { isAssistantsEndpoint, ErrorTypes } = require('librechat-data-provider');
|
||||
const {
|
||||
isEnabled,
|
||||
sendEvent,
|
||||
|
|
@ -7,14 +8,11 @@ const {
|
|||
recordCollectedUsage,
|
||||
sanitizeMessageForTransmit,
|
||||
} = require('@librechat/api');
|
||||
const { isAssistantsEndpoint, ErrorTypes } = require('librechat-data-provider');
|
||||
const { saveMessage, getConvo, updateBalance, bulkInsertTransactions } = require('~/models');
|
||||
const { spendTokens, spendStructuredTokens } = require('~/models/spendTokens');
|
||||
const { truncateText, smartTruncateText } = require('~/app/clients/prompts');
|
||||
const { getMultiplier, getCacheMultiplier } = require('~/models/tx');
|
||||
const clearPendingReq = require('~/cache/clearPendingReq');
|
||||
const { sendError } = require('~/server/middleware/error');
|
||||
const { abortRun } = require('./abortRun');
|
||||
const db = require('~/models');
|
||||
|
||||
/**
|
||||
* Spend tokens for all models from collected usage.
|
||||
|
|
@ -44,10 +42,10 @@ async function spendCollectedUsage({
|
|||
|
||||
await recordCollectedUsage(
|
||||
{
|
||||
spendTokens,
|
||||
spendStructuredTokens,
|
||||
pricing: { getMultiplier, getCacheMultiplier },
|
||||
bulkWriteOps: { insertMany: bulkInsertTransactions, updateBalance },
|
||||
spendTokens: db.spendTokens,
|
||||
spendStructuredTokens: db.spendStructuredTokens,
|
||||
pricing: { getMultiplier: db.getMultiplier, getCacheMultiplier: db.getCacheMultiplier },
|
||||
bulkWriteOps: { insertMany: db.bulkInsertTransactions, updateBalance: db.updateBalance },
|
||||
},
|
||||
{
|
||||
user: userId,
|
||||
|
|
@ -123,20 +121,24 @@ async function abortMessage(req, res) {
|
|||
});
|
||||
} else {
|
||||
// Fallback: no collected usage, use text-based token counting for primary model only
|
||||
await spendTokens(
|
||||
await db.spendTokens(
|
||||
{ ...responseMessage, context: 'incomplete', user: userId },
|
||||
{ promptTokens, completionTokens },
|
||||
);
|
||||
}
|
||||
|
||||
await saveMessage(
|
||||
req,
|
||||
await db.saveMessage(
|
||||
{
|
||||
userId: req?.user?.id,
|
||||
isTemporary: req?.body?.isTemporary,
|
||||
interfaceConfig: req?.config?.interfaceConfig,
|
||||
},
|
||||
{ ...responseMessage, user: userId },
|
||||
{ context: 'api/server/middleware/abortMiddleware.js' },
|
||||
);
|
||||
|
||||
// Get conversation for title
|
||||
const conversation = await getConvo(userId, conversationId);
|
||||
const conversation = await db.getConvo(userId, conversationId);
|
||||
|
||||
const finalEvent = {
|
||||
title: conversation && !conversation.title ? null : conversation?.title || 'New Chat',
|
||||
|
|
|
|||
|
|
@ -20,16 +20,6 @@ const mockRecordCollectedUsage = jest
|
|||
const mockGetMultiplier = jest.fn().mockReturnValue(1);
|
||||
const mockGetCacheMultiplier = jest.fn().mockReturnValue(null);
|
||||
|
||||
jest.mock('~/models/spendTokens', () => ({
|
||||
spendTokens: (...args) => mockSpendTokens(...args),
|
||||
spendStructuredTokens: (...args) => mockSpendStructuredTokens(...args),
|
||||
}));
|
||||
|
||||
jest.mock('~/models/tx', () => ({
|
||||
getMultiplier: mockGetMultiplier,
|
||||
getCacheMultiplier: mockGetCacheMultiplier,
|
||||
}));
|
||||
|
||||
jest.mock('@librechat/data-schemas', () => ({
|
||||
logger: {
|
||||
debug: jest.fn(),
|
||||
|
|
@ -73,6 +63,10 @@ jest.mock('~/models', () => ({
|
|||
getConvo: jest.fn().mockResolvedValue({ title: 'Test Chat' }),
|
||||
updateBalance: mockUpdateBalance,
|
||||
bulkInsertTransactions: mockBulkInsertTransactions,
|
||||
spendTokens: (...args) => mockSpendTokens(...args),
|
||||
spendStructuredTokens: (...args) => mockSpendStructuredTokens(...args),
|
||||
getMultiplier: mockGetMultiplier,
|
||||
getCacheMultiplier: mockGetCacheMultiplier,
|
||||
}));
|
||||
|
||||
jest.mock('./abortRun', () => ({
|
||||
|
|
|
|||
|
|
@ -3,8 +3,7 @@ const { logger } = require('@librechat/data-schemas');
|
|||
const { CacheKeys, RunStatus, isUUID } = require('librechat-data-provider');
|
||||
const { initializeClient } = require('~/server/services/Endpoints/assistants');
|
||||
const { checkMessageGaps, recordUsage } = require('~/server/services/Threads');
|
||||
const { deleteMessages } = require('~/models/Message');
|
||||
const { getConvo } = require('~/models/Conversation');
|
||||
const { deleteMessages, getConvo } = require('~/models');
|
||||
const getLogStores = require('~/cache/getLogStores');
|
||||
|
||||
const three_minutes = 1000 * 60 * 3;
|
||||
|
|
|
|||
|
|
@ -1,42 +1,145 @@
|
|||
const { logger } = require('@librechat/data-schemas');
|
||||
const {
|
||||
Constants,
|
||||
Permissions,
|
||||
ResourceType,
|
||||
SystemRoles,
|
||||
PermissionTypes,
|
||||
isAgentsEndpoint,
|
||||
isEphemeralAgentId,
|
||||
} = require('librechat-data-provider');
|
||||
const { checkPermission } = require('~/server/services/PermissionService');
|
||||
const { canAccessResource } = require('./canAccessResource');
|
||||
const { getAgent } = require('~/models/Agent');
|
||||
const db = require('~/models');
|
||||
|
||||
const { getRoleByName, getAgent } = db;
|
||||
|
||||
/**
|
||||
* Agent ID resolver function for agent_id from request body
|
||||
* Resolves custom agent ID (e.g., "agent_abc123") to MongoDB ObjectId
|
||||
* This is used specifically for chat routes where agent_id comes from request body
|
||||
*
|
||||
* Resolves custom agent ID (e.g., "agent_abc123") to a MongoDB document.
|
||||
* @param {string} agentCustomId - Custom agent ID from request body
|
||||
* @returns {Promise<Object|null>} Agent document with _id field, or null if not found
|
||||
* @returns {Promise<Object|null>} Agent document with _id field, or null if ephemeral/not found
|
||||
*/
|
||||
const resolveAgentIdFromBody = async (agentCustomId) => {
|
||||
// Handle ephemeral agents - they don't need permission checks
|
||||
// Real agent IDs always start with "agent_", so anything else is ephemeral
|
||||
if (isEphemeralAgentId(agentCustomId)) {
|
||||
return null; // No permission check needed for ephemeral agents
|
||||
return null;
|
||||
}
|
||||
|
||||
return await getAgent({ id: agentCustomId });
|
||||
return getAgent({ id: agentCustomId });
|
||||
};
|
||||
|
||||
/**
|
||||
* Middleware factory that creates middleware to check agent access permissions from request body.
|
||||
* This middleware is specifically designed for chat routes where the agent_id comes from req.body
|
||||
* instead of route parameters.
|
||||
* Creates a `canAccessResource` middleware for the given agent ID
|
||||
* and chains to the provided continuation on success.
|
||||
*
|
||||
* @param {string} agentId - The agent's custom string ID (e.g., "agent_abc123")
|
||||
* @param {number} requiredPermission - Permission bit(s) required
|
||||
* @param {import('express').Request} req
|
||||
* @param {import('express').Response} res - Written on deny; continuation called on allow
|
||||
* @param {Function} continuation - Called when the permission check passes
|
||||
* @returns {Promise<void>}
|
||||
*/
|
||||
const checkAgentResourceAccess = (agentId, requiredPermission, req, res, continuation) => {
|
||||
const middleware = canAccessResource({
|
||||
resourceType: ResourceType.AGENT,
|
||||
requiredPermission,
|
||||
resourceIdParam: 'agent_id',
|
||||
idResolver: () => resolveAgentIdFromBody(agentId),
|
||||
});
|
||||
|
||||
const tempReq = {
|
||||
...req,
|
||||
params: { ...req.params, agent_id: agentId },
|
||||
};
|
||||
|
||||
return middleware(tempReq, res, continuation);
|
||||
};
|
||||
|
||||
/**
|
||||
* Middleware factory that validates MULTI_CONVO:USE role permission and, when
|
||||
* addedConvo.agent_id is a non-ephemeral agent, the same resource-level permission
|
||||
* required for the primary agent (`requiredPermission`). Caches the resolved agent
|
||||
* document on `req.resolvedAddedAgent` to avoid a duplicate DB fetch in `loadAddedAgent`.
|
||||
*
|
||||
* @param {number} requiredPermission - Permission bit(s) to check on the added agent resource
|
||||
* @returns {(req: import('express').Request, res: import('express').Response, next: Function) => Promise<void>}
|
||||
*/
|
||||
const checkAddedConvoAccess = (requiredPermission) => async (req, res, next) => {
|
||||
const addedConvo = req.body?.addedConvo;
|
||||
if (!addedConvo || typeof addedConvo !== 'object' || Array.isArray(addedConvo)) {
|
||||
return next();
|
||||
}
|
||||
|
||||
try {
|
||||
if (!req.user?.role) {
|
||||
return res.status(403).json({
|
||||
error: 'Forbidden',
|
||||
message: 'Insufficient permissions for multi-conversation',
|
||||
});
|
||||
}
|
||||
|
||||
if (req.user.role !== SystemRoles.ADMIN) {
|
||||
const role = await getRoleByName(req.user.role);
|
||||
const hasMultiConvo = role?.permissions?.[PermissionTypes.MULTI_CONVO]?.[Permissions.USE];
|
||||
if (!hasMultiConvo) {
|
||||
return res.status(403).json({
|
||||
error: 'Forbidden',
|
||||
message: 'Multi-conversation feature is not enabled',
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
const addedAgentId = addedConvo.agent_id;
|
||||
if (!addedAgentId || typeof addedAgentId !== 'string' || isEphemeralAgentId(addedAgentId)) {
|
||||
return next();
|
||||
}
|
||||
|
||||
if (req.user.role === SystemRoles.ADMIN) {
|
||||
return next();
|
||||
}
|
||||
|
||||
const agent = await resolveAgentIdFromBody(addedAgentId);
|
||||
if (!agent) {
|
||||
return res.status(404).json({
|
||||
error: 'Not Found',
|
||||
message: `${ResourceType.AGENT} not found`,
|
||||
});
|
||||
}
|
||||
|
||||
const hasPermission = await checkPermission({
|
||||
userId: req.user.id,
|
||||
role: req.user.role,
|
||||
resourceType: ResourceType.AGENT,
|
||||
resourceId: agent._id,
|
||||
requiredPermission,
|
||||
});
|
||||
|
||||
if (!hasPermission) {
|
||||
return res.status(403).json({
|
||||
error: 'Forbidden',
|
||||
message: `Insufficient permissions to access this ${ResourceType.AGENT}`,
|
||||
});
|
||||
}
|
||||
|
||||
req.resolvedAddedAgent = agent;
|
||||
return next();
|
||||
} catch (error) {
|
||||
logger.error('Failed to validate addedConvo access permissions', error);
|
||||
return res.status(500).json({
|
||||
error: 'Internal Server Error',
|
||||
message: 'Failed to validate addedConvo access permissions',
|
||||
});
|
||||
}
|
||||
};
|
||||
|
||||
/**
|
||||
* Middleware factory that checks agent access permissions from request body.
|
||||
* Validates both the primary agent_id and, when present, addedConvo.agent_id
|
||||
* (which also requires MULTI_CONVO:USE role permission).
|
||||
*
|
||||
* @param {Object} options - Configuration options
|
||||
* @param {number} options.requiredPermission - The permission bit required (1=view, 2=edit, 4=delete, 8=share)
|
||||
* @returns {Function} Express middleware function
|
||||
*
|
||||
* @example
|
||||
* // Basic usage for agent chat (requires VIEW permission)
|
||||
* router.post('/chat',
|
||||
* canAccessAgentFromBody({ requiredPermission: PermissionBits.VIEW }),
|
||||
* buildEndpointOption,
|
||||
|
|
@ -46,11 +149,12 @@ const resolveAgentIdFromBody = async (agentCustomId) => {
|
|||
const canAccessAgentFromBody = (options) => {
|
||||
const { requiredPermission } = options;
|
||||
|
||||
// Validate required options
|
||||
if (!requiredPermission || typeof requiredPermission !== 'number') {
|
||||
throw new Error('canAccessAgentFromBody: requiredPermission is required and must be a number');
|
||||
}
|
||||
|
||||
const addedConvoMiddleware = checkAddedConvoAccess(requiredPermission);
|
||||
|
||||
return async (req, res, next) => {
|
||||
try {
|
||||
const { endpoint, agent_id } = req.body;
|
||||
|
|
@ -67,28 +171,13 @@ const canAccessAgentFromBody = (options) => {
|
|||
});
|
||||
}
|
||||
|
||||
// Skip permission checks for ephemeral agents
|
||||
// Real agent IDs always start with "agent_", so anything else is ephemeral
|
||||
const afterPrimaryCheck = () => addedConvoMiddleware(req, res, next);
|
||||
|
||||
if (isEphemeralAgentId(agentId)) {
|
||||
return next();
|
||||
return afterPrimaryCheck();
|
||||
}
|
||||
|
||||
const agentAccessMiddleware = canAccessResource({
|
||||
resourceType: ResourceType.AGENT,
|
||||
requiredPermission,
|
||||
resourceIdParam: 'agent_id', // This will be ignored since we use custom resolver
|
||||
idResolver: () => resolveAgentIdFromBody(agentId),
|
||||
});
|
||||
|
||||
const tempReq = {
|
||||
...req,
|
||||
params: {
|
||||
...req.params,
|
||||
agent_id: agentId,
|
||||
},
|
||||
};
|
||||
|
||||
return agentAccessMiddleware(tempReq, res, next);
|
||||
return checkAgentResourceAccess(agentId, requiredPermission, req, res, afterPrimaryCheck);
|
||||
} catch (error) {
|
||||
logger.error('Failed to validate agent access permissions', error);
|
||||
return res.status(500).json({
|
||||
|
|
|
|||
|
|
@ -0,0 +1,509 @@
|
|||
const mongoose = require('mongoose');
|
||||
const {
|
||||
ResourceType,
|
||||
SystemRoles,
|
||||
PrincipalType,
|
||||
PrincipalModel,
|
||||
} = require('librechat-data-provider');
|
||||
const { MongoMemoryServer } = require('mongodb-memory-server');
|
||||
const { canAccessAgentFromBody } = require('./canAccessAgentFromBody');
|
||||
const { User, Role, AclEntry } = require('~/db/models');
|
||||
const { createAgent } = require('~/models');
|
||||
|
||||
describe('canAccessAgentFromBody middleware', () => {
|
||||
let mongoServer;
|
||||
let req, res, next;
|
||||
let testUser, otherUser;
|
||||
|
||||
beforeAll(async () => {
|
||||
mongoServer = await MongoMemoryServer.create();
|
||||
await mongoose.connect(mongoServer.getUri());
|
||||
});
|
||||
|
||||
afterAll(async () => {
|
||||
await mongoose.disconnect();
|
||||
await mongoServer.stop();
|
||||
});
|
||||
|
||||
beforeEach(async () => {
|
||||
await mongoose.connection.dropDatabase();
|
||||
|
||||
await Role.create({
|
||||
name: 'test-role',
|
||||
permissions: {
|
||||
AGENTS: { USE: true, CREATE: true, SHARE: true },
|
||||
MULTI_CONVO: { USE: true },
|
||||
},
|
||||
});
|
||||
|
||||
await Role.create({
|
||||
name: 'no-multi-convo',
|
||||
permissions: {
|
||||
AGENTS: { USE: true, CREATE: true, SHARE: true },
|
||||
MULTI_CONVO: { USE: false },
|
||||
},
|
||||
});
|
||||
|
||||
await Role.create({
|
||||
name: SystemRoles.ADMIN,
|
||||
permissions: {
|
||||
AGENTS: { USE: true, CREATE: true, SHARE: true },
|
||||
MULTI_CONVO: { USE: true },
|
||||
},
|
||||
});
|
||||
|
||||
testUser = await User.create({
|
||||
email: 'test@example.com',
|
||||
name: 'Test User',
|
||||
username: 'testuser',
|
||||
role: 'test-role',
|
||||
});
|
||||
|
||||
otherUser = await User.create({
|
||||
email: 'other@example.com',
|
||||
name: 'Other User',
|
||||
username: 'otheruser',
|
||||
role: 'test-role',
|
||||
});
|
||||
|
||||
req = {
|
||||
user: { id: testUser._id, role: testUser.role },
|
||||
params: {},
|
||||
body: {
|
||||
endpoint: 'agents',
|
||||
agent_id: 'ephemeral_primary',
|
||||
},
|
||||
};
|
||||
res = {
|
||||
status: jest.fn().mockReturnThis(),
|
||||
json: jest.fn(),
|
||||
};
|
||||
next = jest.fn();
|
||||
|
||||
jest.clearAllMocks();
|
||||
});
|
||||
|
||||
describe('middleware factory', () => {
|
||||
test('throws if requiredPermission is missing', () => {
|
||||
expect(() => canAccessAgentFromBody({})).toThrow(
|
||||
'canAccessAgentFromBody: requiredPermission is required and must be a number',
|
||||
);
|
||||
});
|
||||
|
||||
test('throws if requiredPermission is not a number', () => {
|
||||
expect(() => canAccessAgentFromBody({ requiredPermission: '1' })).toThrow(
|
||||
'canAccessAgentFromBody: requiredPermission is required and must be a number',
|
||||
);
|
||||
});
|
||||
|
||||
test('returns a middleware function', () => {
|
||||
const middleware = canAccessAgentFromBody({ requiredPermission: 1 });
|
||||
expect(typeof middleware).toBe('function');
|
||||
expect(middleware.length).toBe(3);
|
||||
});
|
||||
});
|
||||
|
||||
describe('primary agent checks', () => {
|
||||
test('returns 400 when agent_id is missing on agents endpoint', async () => {
|
||||
req.body.agent_id = undefined;
|
||||
const middleware = canAccessAgentFromBody({ requiredPermission: 1 });
|
||||
await middleware(req, res, next);
|
||||
|
||||
expect(next).not.toHaveBeenCalled();
|
||||
expect(res.status).toHaveBeenCalledWith(400);
|
||||
});
|
||||
|
||||
test('proceeds for ephemeral primary agent without addedConvo', async () => {
|
||||
const middleware = canAccessAgentFromBody({ requiredPermission: 1 });
|
||||
await middleware(req, res, next);
|
||||
|
||||
expect(next).toHaveBeenCalled();
|
||||
expect(res.status).not.toHaveBeenCalled();
|
||||
});
|
||||
|
||||
test('proceeds for non-agents endpoint (ephemeral fallback)', async () => {
|
||||
req.body.endpoint = 'openAI';
|
||||
req.body.agent_id = undefined;
|
||||
const middleware = canAccessAgentFromBody({ requiredPermission: 1 });
|
||||
await middleware(req, res, next);
|
||||
|
||||
expect(next).toHaveBeenCalled();
|
||||
});
|
||||
});
|
||||
|
||||
describe('addedConvo — absent or invalid shape', () => {
|
||||
test('calls next when addedConvo is absent', async () => {
|
||||
const middleware = canAccessAgentFromBody({ requiredPermission: 1 });
|
||||
await middleware(req, res, next);
|
||||
|
||||
expect(next).toHaveBeenCalled();
|
||||
});
|
||||
|
||||
test('calls next when addedConvo is a string', async () => {
|
||||
req.body.addedConvo = 'not-an-object';
|
||||
const middleware = canAccessAgentFromBody({ requiredPermission: 1 });
|
||||
await middleware(req, res, next);
|
||||
|
||||
expect(next).toHaveBeenCalled();
|
||||
});
|
||||
|
||||
test('calls next when addedConvo is an array', async () => {
|
||||
req.body.addedConvo = [{ agent_id: 'agent_something' }];
|
||||
const middleware = canAccessAgentFromBody({ requiredPermission: 1 });
|
||||
await middleware(req, res, next);
|
||||
|
||||
expect(next).toHaveBeenCalled();
|
||||
});
|
||||
});
|
||||
|
||||
describe('addedConvo — MULTI_CONVO permission gate', () => {
|
||||
test('returns 403 when user lacks MULTI_CONVO:USE', async () => {
|
||||
req.user.role = 'no-multi-convo';
|
||||
req.body.addedConvo = { agent_id: 'agent_x', endpoint: 'agents', model: 'gpt-4' };
|
||||
|
||||
const middleware = canAccessAgentFromBody({ requiredPermission: 1 });
|
||||
await middleware(req, res, next);
|
||||
|
||||
expect(next).not.toHaveBeenCalled();
|
||||
expect(res.status).toHaveBeenCalledWith(403);
|
||||
expect(res.json).toHaveBeenCalledWith(
|
||||
expect.objectContaining({ message: 'Multi-conversation feature is not enabled' }),
|
||||
);
|
||||
});
|
||||
|
||||
test('returns 403 when user.role is missing', async () => {
|
||||
req.user = { id: testUser._id };
|
||||
req.body.addedConvo = { agent_id: 'agent_x', endpoint: 'agents', model: 'gpt-4' };
|
||||
|
||||
const middleware = canAccessAgentFromBody({ requiredPermission: 1 });
|
||||
await middleware(req, res, next);
|
||||
|
||||
expect(next).not.toHaveBeenCalled();
|
||||
expect(res.status).toHaveBeenCalledWith(403);
|
||||
});
|
||||
|
||||
test('ADMIN bypasses MULTI_CONVO check', async () => {
|
||||
req.user.role = SystemRoles.ADMIN;
|
||||
req.body.addedConvo = { agent_id: 'ephemeral_x', endpoint: 'agents', model: 'gpt-4' };
|
||||
|
||||
const middleware = canAccessAgentFromBody({ requiredPermission: 1 });
|
||||
await middleware(req, res, next);
|
||||
|
||||
expect(next).toHaveBeenCalled();
|
||||
expect(res.status).not.toHaveBeenCalled();
|
||||
});
|
||||
});
|
||||
|
||||
describe('addedConvo — agent_id shape validation', () => {
|
||||
test('calls next when agent_id is ephemeral', async () => {
|
||||
req.body.addedConvo = { agent_id: 'ephemeral_xyz', endpoint: 'agents', model: 'gpt-4' };
|
||||
|
||||
const middleware = canAccessAgentFromBody({ requiredPermission: 1 });
|
||||
await middleware(req, res, next);
|
||||
|
||||
expect(next).toHaveBeenCalled();
|
||||
});
|
||||
|
||||
test('calls next when agent_id is absent', async () => {
|
||||
req.body.addedConvo = { endpoint: 'agents', model: 'gpt-4' };
|
||||
|
||||
const middleware = canAccessAgentFromBody({ requiredPermission: 1 });
|
||||
await middleware(req, res, next);
|
||||
|
||||
expect(next).toHaveBeenCalled();
|
||||
});
|
||||
|
||||
test('calls next when agent_id is not a string (object injection)', async () => {
|
||||
req.body.addedConvo = { agent_id: { $gt: '' }, endpoint: 'agents', model: 'gpt-4' };
|
||||
|
||||
const middleware = canAccessAgentFromBody({ requiredPermission: 1 });
|
||||
await middleware(req, res, next);
|
||||
|
||||
expect(next).toHaveBeenCalled();
|
||||
});
|
||||
});
|
||||
|
||||
describe('addedConvo — agent resource ACL (IDOR prevention)', () => {
|
||||
let addedAgent;
|
||||
|
||||
beforeEach(async () => {
|
||||
addedAgent = await createAgent({
|
||||
id: `agent_added_${Date.now()}`,
|
||||
name: 'Private Agent',
|
||||
provider: 'openai',
|
||||
model: 'gpt-4',
|
||||
author: otherUser._id,
|
||||
});
|
||||
|
||||
await AclEntry.create({
|
||||
principalType: PrincipalType.USER,
|
||||
principalId: otherUser._id,
|
||||
principalModel: PrincipalModel.USER,
|
||||
resourceType: ResourceType.AGENT,
|
||||
resourceId: addedAgent._id,
|
||||
permBits: 15,
|
||||
grantedBy: otherUser._id,
|
||||
});
|
||||
});
|
||||
|
||||
test('returns 403 when requester has no ACL for the added agent', async () => {
|
||||
req.body.addedConvo = { agent_id: addedAgent.id, endpoint: 'agents', model: 'gpt-4' };
|
||||
|
||||
const middleware = canAccessAgentFromBody({ requiredPermission: 1 });
|
||||
await middleware(req, res, next);
|
||||
|
||||
expect(next).not.toHaveBeenCalled();
|
||||
expect(res.status).toHaveBeenCalledWith(403);
|
||||
expect(res.json).toHaveBeenCalledWith(
|
||||
expect.objectContaining({
|
||||
message: 'Insufficient permissions to access this agent',
|
||||
}),
|
||||
);
|
||||
});
|
||||
|
||||
test('returns 404 when added agent does not exist', async () => {
|
||||
req.body.addedConvo = {
|
||||
agent_id: 'agent_nonexistent_999',
|
||||
endpoint: 'agents',
|
||||
model: 'gpt-4',
|
||||
};
|
||||
|
||||
const middleware = canAccessAgentFromBody({ requiredPermission: 1 });
|
||||
await middleware(req, res, next);
|
||||
|
||||
expect(next).not.toHaveBeenCalled();
|
||||
expect(res.status).toHaveBeenCalledWith(404);
|
||||
});
|
||||
|
||||
test('proceeds when requester has ACL for the added agent', async () => {
|
||||
await AclEntry.create({
|
||||
principalType: PrincipalType.USER,
|
||||
principalId: testUser._id,
|
||||
principalModel: PrincipalModel.USER,
|
||||
resourceType: ResourceType.AGENT,
|
||||
resourceId: addedAgent._id,
|
||||
permBits: 1,
|
||||
grantedBy: otherUser._id,
|
||||
});
|
||||
|
||||
req.body.addedConvo = { agent_id: addedAgent.id, endpoint: 'agents', model: 'gpt-4' };
|
||||
|
||||
const middleware = canAccessAgentFromBody({ requiredPermission: 1 });
|
||||
await middleware(req, res, next);
|
||||
|
||||
expect(next).toHaveBeenCalled();
|
||||
expect(res.status).not.toHaveBeenCalled();
|
||||
});
|
||||
|
||||
test('denies when ACL permission bits are insufficient', async () => {
|
||||
await AclEntry.create({
|
||||
principalType: PrincipalType.USER,
|
||||
principalId: testUser._id,
|
||||
principalModel: PrincipalModel.USER,
|
||||
resourceType: ResourceType.AGENT,
|
||||
resourceId: addedAgent._id,
|
||||
permBits: 1,
|
||||
grantedBy: otherUser._id,
|
||||
});
|
||||
|
||||
req.body.addedConvo = { agent_id: addedAgent.id, endpoint: 'agents', model: 'gpt-4' };
|
||||
|
||||
const middleware = canAccessAgentFromBody({ requiredPermission: 2 });
|
||||
await middleware(req, res, next);
|
||||
|
||||
expect(next).not.toHaveBeenCalled();
|
||||
expect(res.status).toHaveBeenCalledWith(403);
|
||||
});
|
||||
|
||||
test('caches resolved agent on req.resolvedAddedAgent', async () => {
|
||||
await AclEntry.create({
|
||||
principalType: PrincipalType.USER,
|
||||
principalId: testUser._id,
|
||||
principalModel: PrincipalModel.USER,
|
||||
resourceType: ResourceType.AGENT,
|
||||
resourceId: addedAgent._id,
|
||||
permBits: 1,
|
||||
grantedBy: otherUser._id,
|
||||
});
|
||||
|
||||
req.body.addedConvo = { agent_id: addedAgent.id, endpoint: 'agents', model: 'gpt-4' };
|
||||
|
||||
const middleware = canAccessAgentFromBody({ requiredPermission: 1 });
|
||||
await middleware(req, res, next);
|
||||
|
||||
expect(next).toHaveBeenCalled();
|
||||
expect(req.resolvedAddedAgent).toBeDefined();
|
||||
expect(req.resolvedAddedAgent._id.toString()).toBe(addedAgent._id.toString());
|
||||
});
|
||||
|
||||
test('ADMIN bypasses agent resource ACL for addedConvo', async () => {
|
||||
req.user.role = SystemRoles.ADMIN;
|
||||
req.body.addedConvo = { agent_id: addedAgent.id, endpoint: 'agents', model: 'gpt-4' };
|
||||
|
||||
const middleware = canAccessAgentFromBody({ requiredPermission: 1 });
|
||||
await middleware(req, res, next);
|
||||
|
||||
expect(next).toHaveBeenCalled();
|
||||
expect(res.status).not.toHaveBeenCalled();
|
||||
expect(req.resolvedAddedAgent).toBeUndefined();
|
||||
});
|
||||
});
|
||||
|
||||
describe('end-to-end: primary real agent + addedConvo real agent', () => {
|
||||
let primaryAgent, addedAgent;
|
||||
|
||||
beforeEach(async () => {
|
||||
primaryAgent = await createAgent({
|
||||
id: `agent_primary_${Date.now()}`,
|
||||
name: 'Primary Agent',
|
||||
provider: 'openai',
|
||||
model: 'gpt-4',
|
||||
author: testUser._id,
|
||||
});
|
||||
|
||||
await AclEntry.create({
|
||||
principalType: PrincipalType.USER,
|
||||
principalId: testUser._id,
|
||||
principalModel: PrincipalModel.USER,
|
||||
resourceType: ResourceType.AGENT,
|
||||
resourceId: primaryAgent._id,
|
||||
permBits: 15,
|
||||
grantedBy: testUser._id,
|
||||
});
|
||||
|
||||
addedAgent = await createAgent({
|
||||
id: `agent_added_${Date.now()}`,
|
||||
name: 'Added Agent',
|
||||
provider: 'openai',
|
||||
model: 'gpt-4',
|
||||
author: otherUser._id,
|
||||
});
|
||||
|
||||
await AclEntry.create({
|
||||
principalType: PrincipalType.USER,
|
||||
principalId: otherUser._id,
|
||||
principalModel: PrincipalModel.USER,
|
||||
resourceType: ResourceType.AGENT,
|
||||
resourceId: addedAgent._id,
|
||||
permBits: 15,
|
||||
grantedBy: otherUser._id,
|
||||
});
|
||||
|
||||
req.body.agent_id = primaryAgent.id;
|
||||
});
|
||||
|
||||
test('both checks pass when user has ACL for both agents', async () => {
|
||||
await AclEntry.create({
|
||||
principalType: PrincipalType.USER,
|
||||
principalId: testUser._id,
|
||||
principalModel: PrincipalModel.USER,
|
||||
resourceType: ResourceType.AGENT,
|
||||
resourceId: addedAgent._id,
|
||||
permBits: 1,
|
||||
grantedBy: otherUser._id,
|
||||
});
|
||||
|
||||
req.body.addedConvo = { agent_id: addedAgent.id, endpoint: 'agents', model: 'gpt-4' };
|
||||
|
||||
const middleware = canAccessAgentFromBody({ requiredPermission: 1 });
|
||||
await middleware(req, res, next);
|
||||
|
||||
expect(next).toHaveBeenCalled();
|
||||
expect(res.status).not.toHaveBeenCalled();
|
||||
expect(req.resolvedAddedAgent).toBeDefined();
|
||||
});
|
||||
|
||||
test('primary passes but addedConvo denied → 403', async () => {
|
||||
req.body.addedConvo = { agent_id: addedAgent.id, endpoint: 'agents', model: 'gpt-4' };
|
||||
|
||||
const middleware = canAccessAgentFromBody({ requiredPermission: 1 });
|
||||
await middleware(req, res, next);
|
||||
|
||||
expect(next).not.toHaveBeenCalled();
|
||||
expect(res.status).toHaveBeenCalledWith(403);
|
||||
});
|
||||
|
||||
test('primary denied → 403 without reaching addedConvo check', async () => {
|
||||
const foreignAgent = await createAgent({
|
||||
id: `agent_foreign_${Date.now()}`,
|
||||
name: 'Foreign Agent',
|
||||
provider: 'openai',
|
||||
model: 'gpt-4',
|
||||
author: otherUser._id,
|
||||
});
|
||||
|
||||
await AclEntry.create({
|
||||
principalType: PrincipalType.USER,
|
||||
principalId: otherUser._id,
|
||||
principalModel: PrincipalModel.USER,
|
||||
resourceType: ResourceType.AGENT,
|
||||
resourceId: foreignAgent._id,
|
||||
permBits: 15,
|
||||
grantedBy: otherUser._id,
|
||||
});
|
||||
|
||||
req.body.agent_id = foreignAgent.id;
|
||||
req.body.addedConvo = { agent_id: addedAgent.id, endpoint: 'agents', model: 'gpt-4' };
|
||||
|
||||
const middleware = canAccessAgentFromBody({ requiredPermission: 1 });
|
||||
await middleware(req, res, next);
|
||||
|
||||
expect(next).not.toHaveBeenCalled();
|
||||
expect(res.status).toHaveBeenCalledWith(403);
|
||||
});
|
||||
});
|
||||
|
||||
describe('ephemeral primary + real addedConvo agent', () => {
|
||||
let addedAgent;
|
||||
|
||||
beforeEach(async () => {
|
||||
addedAgent = await createAgent({
|
||||
id: `agent_added_${Date.now()}`,
|
||||
name: 'Added Agent',
|
||||
provider: 'openai',
|
||||
model: 'gpt-4',
|
||||
author: otherUser._id,
|
||||
});
|
||||
|
||||
await AclEntry.create({
|
||||
principalType: PrincipalType.USER,
|
||||
principalId: otherUser._id,
|
||||
principalModel: PrincipalModel.USER,
|
||||
resourceType: ResourceType.AGENT,
|
||||
resourceId: addedAgent._id,
|
||||
permBits: 15,
|
||||
grantedBy: otherUser._id,
|
||||
});
|
||||
});
|
||||
|
||||
test('runs full addedConvo ACL check even when primary is ephemeral', async () => {
|
||||
req.body.addedConvo = { agent_id: addedAgent.id, endpoint: 'agents', model: 'gpt-4' };
|
||||
|
||||
const middleware = canAccessAgentFromBody({ requiredPermission: 1 });
|
||||
await middleware(req, res, next);
|
||||
|
||||
expect(next).not.toHaveBeenCalled();
|
||||
expect(res.status).toHaveBeenCalledWith(403);
|
||||
});
|
||||
|
||||
test('proceeds when user has ACL for added agent (ephemeral primary)', async () => {
|
||||
await AclEntry.create({
|
||||
principalType: PrincipalType.USER,
|
||||
principalId: testUser._id,
|
||||
principalModel: PrincipalModel.USER,
|
||||
resourceType: ResourceType.AGENT,
|
||||
resourceId: addedAgent._id,
|
||||
permBits: 1,
|
||||
grantedBy: otherUser._id,
|
||||
});
|
||||
|
||||
req.body.addedConvo = { agent_id: addedAgent.id, endpoint: 'agents', model: 'gpt-4' };
|
||||
|
||||
const middleware = canAccessAgentFromBody({ requiredPermission: 1 });
|
||||
await middleware(req, res, next);
|
||||
|
||||
expect(next).toHaveBeenCalled();
|
||||
expect(res.status).not.toHaveBeenCalled();
|
||||
});
|
||||
});
|
||||
});
|
||||
|
|
@ -1,6 +1,6 @@
|
|||
const { ResourceType } = require('librechat-data-provider');
|
||||
const { canAccessResource } = require('./canAccessResource');
|
||||
const { getAgent } = require('~/models/Agent');
|
||||
const { getAgent } = require('~/models');
|
||||
|
||||
/**
|
||||
* Agent ID resolver function
|
||||
|
|
|
|||
|
|
@ -3,7 +3,7 @@ const { ResourceType, PrincipalType, PrincipalModel } = require('librechat-data-
|
|||
const { MongoMemoryServer } = require('mongodb-memory-server');
|
||||
const { canAccessAgentResource } = require('./canAccessAgentResource');
|
||||
const { User, Role, AclEntry } = require('~/db/models');
|
||||
const { createAgent } = require('~/models/Agent');
|
||||
const { createAgent } = require('~/models');
|
||||
|
||||
describe('canAccessAgentResource middleware', () => {
|
||||
let mongoServer;
|
||||
|
|
@ -373,7 +373,7 @@ describe('canAccessAgentResource middleware', () => {
|
|||
jest.clearAllMocks();
|
||||
|
||||
// Update the agent
|
||||
const { updateAgent } = require('~/models/Agent');
|
||||
const { updateAgent } = require('~/models');
|
||||
await updateAgent({ id: agentId }, { description: 'Updated description' });
|
||||
|
||||
// Test edit access
|
||||
|
|
|
|||
|
|
@ -1,8 +1,9 @@
|
|||
const mongoose = require('mongoose');
|
||||
const { ResourceType, PrincipalType, PrincipalModel } = require('librechat-data-provider');
|
||||
const { SystemCapabilities } = require('@librechat/data-schemas');
|
||||
const { MongoMemoryServer } = require('mongodb-memory-server');
|
||||
const { canAccessMCPServerResource } = require('./canAccessMCPServerResource');
|
||||
const { User, Role, AclEntry } = require('~/db/models');
|
||||
const { User, Role, AclEntry, SystemGrant } = require('~/db/models');
|
||||
const { createMCPServer } = require('~/models');
|
||||
|
||||
describe('canAccessMCPServerResource middleware', () => {
|
||||
|
|
@ -511,7 +512,7 @@ describe('canAccessMCPServerResource middleware', () => {
|
|||
});
|
||||
});
|
||||
|
||||
test('should allow admin users to bypass permission checks', async () => {
|
||||
test('should allow users with MANAGE_MCP_SERVERS capability to bypass permission checks', async () => {
|
||||
const { SystemRoles } = require('librechat-data-provider');
|
||||
|
||||
// Create an MCP server owned by another user
|
||||
|
|
@ -531,6 +532,14 @@ describe('canAccessMCPServerResource middleware', () => {
|
|||
author: otherUser._id,
|
||||
});
|
||||
|
||||
// Seed MANAGE_MCP_SERVERS capability for the ADMIN role
|
||||
await SystemGrant.create({
|
||||
principalType: PrincipalType.ROLE,
|
||||
principalId: SystemRoles.ADMIN,
|
||||
capability: SystemCapabilities.MANAGE_MCP_SERVERS,
|
||||
grantedAt: new Date(),
|
||||
});
|
||||
|
||||
// Set user as admin
|
||||
req.user = { id: testUser._id, role: SystemRoles.ADMIN };
|
||||
req.params.serverName = mcpServer.serverName;
|
||||
|
|
|
|||
|
|
@ -1,6 +1,6 @@
|
|||
const { ResourceType } = require('librechat-data-provider');
|
||||
const { canAccessResource } = require('./canAccessResource');
|
||||
const { getPromptGroup } = require('~/models/Prompt');
|
||||
const { getPromptGroup } = require('~/models');
|
||||
|
||||
/**
|
||||
* PromptGroup ID resolver function
|
||||
|
|
|
|||
|
|
@ -1,6 +1,6 @@
|
|||
const { ResourceType } = require('librechat-data-provider');
|
||||
const { canAccessResource } = require('./canAccessResource');
|
||||
const { getPrompt } = require('~/models/Prompt');
|
||||
const { getPrompt } = require('~/models');
|
||||
|
||||
/**
|
||||
* Prompt to PromptGroup ID resolver function
|
||||
|
|
|
|||
|
|
@ -1,5 +1,5 @@
|
|||
const { logger } = require('@librechat/data-schemas');
|
||||
const { SystemRoles } = require('librechat-data-provider');
|
||||
const { logger, ResourceCapabilityMap } = require('@librechat/data-schemas');
|
||||
const { hasCapability } = require('~/server/middleware/roles/capabilities');
|
||||
const { checkPermission } = require('~/server/services/PermissionService');
|
||||
|
||||
/**
|
||||
|
|
@ -71,8 +71,17 @@ const canAccessResource = (options) => {
|
|||
message: 'Authentication required',
|
||||
});
|
||||
}
|
||||
// if system admin let through
|
||||
if (req.user.role === SystemRoles.ADMIN) {
|
||||
const cap = ResourceCapabilityMap[resourceType];
|
||||
let hasCap = false;
|
||||
try {
|
||||
hasCap = cap != null && (await hasCapability(req.user, cap));
|
||||
} catch (err) {
|
||||
logger.warn(`[canAccessResource] capability check failed, denying bypass: ${err.message}`);
|
||||
}
|
||||
if (hasCap) {
|
||||
logger.debug(
|
||||
`[canAccessResource] ${cap} bypass for user ${req.user.id} on ${resourceType} ${rawResourceId}`,
|
||||
);
|
||||
return next();
|
||||
}
|
||||
const userId = req.user.id;
|
||||
|
|
|
|||
|
|
@ -1,8 +1,7 @@
|
|||
const { logger } = require('@librechat/data-schemas');
|
||||
const { PermissionBits, hasPermissions, ResourceType } = require('librechat-data-provider');
|
||||
const { getEffectivePermissions } = require('~/server/services/PermissionService');
|
||||
const { getAgents } = require('~/models/Agent');
|
||||
const { getFiles } = require('~/models');
|
||||
const { getAgents, getFiles } = require('~/models');
|
||||
|
||||
/**
|
||||
* Checks if user has access to a file through agent permissions
|
||||
|
|
|
|||
|
|
@ -3,8 +3,7 @@ const { ResourceType, PrincipalType, PrincipalModel } = require('librechat-data-
|
|||
const { MongoMemoryServer } = require('mongodb-memory-server');
|
||||
const { fileAccess } = require('./fileAccess');
|
||||
const { User, Role, AclEntry } = require('~/db/models');
|
||||
const { createAgent } = require('~/models/Agent');
|
||||
const { createFile } = require('~/models');
|
||||
const { createAgent, createFile } = require('~/models');
|
||||
|
||||
describe('fileAccess middleware', () => {
|
||||
let mongoServer;
|
||||
|
|
|
|||
Some files were not shown because too many files have changed in this diff Show more
Loading…
Add table
Add a link
Reference in a new issue