From 43ea6757d5f8115137dd6d5bf0fe2cd83b5e1dea Mon Sep 17 00:00:00 2001 From: Davide Garberi Date: Tue, 27 Jan 2026 05:59:11 +0100 Subject: [PATCH] :art: Implement HTTPS network serving (#16912) * Add use TLS for network serving configuration option * kernel: Implement TLS certificate generation * kernel: server: Use https for fixed port proxy when needed * Allow exporting the CA Certificate file * Implement import and export of CA Certs --- app/appearance/langs/en_US.json | 9 + app/appearance/langs/zh_CN.json | 9 + app/src/config/about.ts | 99 ++++++++- app/src/mobile/settings/about.ts | 90 ++++++++- app/src/types/config.d.ts | 4 + kernel/api/router.go | 4 + kernel/api/system.go | 167 +++++++++++++++ kernel/conf/system.go | 5 +- kernel/server/proxy/fixedport.go | 23 ++- kernel/server/serve.go | 22 +- kernel/util/cert.go | 337 +++++++++++++++++++++++++++++++ 11 files changed, 759 insertions(+), 10 deletions(-) create mode 100644 kernel/util/cert.go diff --git a/app/appearance/langs/en_US.json b/app/appearance/langs/en_US.json index 56d215d10..956121ea3 100644 --- a/app/appearance/langs/en_US.json +++ b/app/appearance/langs/en_US.json @@ -1297,6 +1297,15 @@ "about8": "When enabled, the application will be automatically locked when locking the system screen", "about11": "Network serving", "about12": "When enabled, other devices in the same LAN will be allowed to access. The application will be closed automatically after modification, please restart manually", + "networkServeTLS": "Enable HTTPS", + "networkServeTLSTip": "When enabled, network connections will be encrypted with TLS using auto-generated self-signed certificates. Browsers will show a security warning that must be accepted. Requires restart", + "exportCACert": "Export CA Certificate", + "exportCACertTip": "Export the CA certificate (ca.crt) file. Install this certificate on client devices to trust the self-signed HTTPS connection", + "exportCABundle": "Export CA Bundle", + "exportCABundleTip": "Export the CA certificate and private key for sharing with other SiYuan devices. All devices using the same CA will be trusted by clients that import it", + "importCABundle": "Import CA Bundle", + "importCABundleTip": "Import a CA bundle from another SiYuan device. After importing, this device will use the same CA, allowing clients to trust certificates from all devices", + "importCABundleSuccess": "CA bundle imported successfully. Please restart the application to apply changes", "about13": "API token", "about14": "The token needs to be authenticated when calling the API
HTTP request header Authorization: token ${token}", "about17": "Do not enable proxy when set to Direct connection", diff --git a/app/appearance/langs/zh_CN.json b/app/appearance/langs/zh_CN.json index 2e9b6c07c..7582f3240 100644 --- a/app/appearance/langs/zh_CN.json +++ b/app/appearance/langs/zh_CN.json @@ -1297,6 +1297,15 @@ "about8": "启用后将会在系统锁屏时自动锁定应用", "about11": "网络伺服", "about12": "启用后将允许同一局域网内的其他设备进行访问。修改后会自动关闭应用,请手动重启", + "networkServeTLS": "启用 HTTPS", + "networkServeTLSTip": "启用后网络连接将使用自动生成的自签名证书进行 TLS 加密。浏览器会显示安全警告,需要手动接受。启用后需重启应用", + "exportCACert": "导出 CA 证书", + "exportCACertTip": "导出 CA 证书(ca.crt)文件。将此证书安装到客户端设备以信任自签名 HTTPS 连接", + "exportCABundle": "导出 CA 证书包", + "exportCABundleTip": "导出 CA 证书和私钥以便与其他思源设备共享。使用相同 CA 的所有设备将被导入该证书的客户端信任", + "importCABundle": "导入 CA 证书包", + "importCABundleTip": "从另一台思源设备导入 CA 证书包。导入后,此设备将使用相同的 CA,允许客户端信任所有设备的证书", + "importCABundleSuccess": "CA 证书包导入成功。请重启应用以应用更改", "about13": "API token", "about14": "调用 API 时需要通过该 token 进行鉴权
HTTP 请求标头 Authorization: token ${token}", "about17": "设置为 直接连接 时不启用代理", diff --git a/app/src/config/about.ts b/app/src/config/about.ts index ca3ff2058..a486d35b8 100644 --- a/app/src/config/about.ts +++ b/app/src/config/about.ts @@ -64,6 +64,44 @@ export const about = {
+ +
+
+ ${window.siyuan.languages.exportCACert} +
${window.siyuan.languages.exportCACertTip}
+
+
+ +
+
+
+ ${window.siyuan.languages.exportCABundle} +
${window.siyuan.languages.exportCABundleTip}
+
+
+ +
+
+
+ ${window.siyuan.languages.importCABundle} +
${window.siyuan.languages.importCABundleTip}
+
+
+ +
@@ -102,7 +140,7 @@ export const about = {
${window.siyuan.languages.about18}
-
@@ -379,7 +417,12 @@ ${checkUpdateHTML} }); }); const networkServeElement = about.element.querySelector("#networkServe") as HTMLInputElement; + const networkServeTLSElement = about.element.querySelector("#networkServeTLS") as HTMLInputElement; networkServeElement.addEventListener("change", () => { + networkServeTLSElement.disabled = !networkServeElement.checked; + if (!networkServeElement.checked) { + networkServeTLSElement.checked = false; + } fetchPost("/api/system/setNetworkServe", {networkServe: networkServeElement.checked}, () => { exportLayout({ errorExit: true, @@ -387,6 +430,60 @@ ${checkUpdateHTML} }); }); }); + networkServeTLSElement.addEventListener("change", () => { + const exportCACertSection = about.element.querySelector("#exportCACertSection"); + const exportCABundleSection = about.element.querySelector("#exportCABundleSection"); + const importCABundleSection = about.element.querySelector("#importCABundleSection"); + if (exportCACertSection && exportCABundleSection && importCABundleSection) { + if (networkServeTLSElement.checked) { + exportCACertSection.classList.remove("fn__none"); + exportCABundleSection.classList.remove("fn__none"); + importCABundleSection.classList.remove("fn__none"); + } else { + exportCACertSection.classList.add("fn__none"); + exportCABundleSection.classList.add("fn__none"); + importCABundleSection.classList.add("fn__none"); + } + } + fetchPost("/api/system/setNetworkServeTLS", {networkServeTLS: networkServeTLSElement.checked}, () => { + exportLayout({ + errorExit: true, + cb: exitSiYuan + }); + }); + }); + about.element.querySelector("#exportCACert")?.addEventListener("click", () => { + fetchPost("/api/system/exportTLSCACert", {}, (response) => { + openByMobile(response.data.path); + }); + }); + about.element.querySelector("#exportCABundle")?.addEventListener("click", () => { + fetchPost("/api/system/exportTLSCABundle", {}, (response) => { + openByMobile(response.data.path); + }); + }); + about.element.querySelector("#importCABundle")?.addEventListener("click", () => { + const input = document.createElement("input"); + input.type = "file"; + input.accept = ".zip"; + input.onchange = () => { + if (input.files && input.files[0]) { + const formData = new FormData(); + formData.append("file", input.files[0]); + fetch("/api/system/importTLSCABundle", { + method: "POST", + body: formData, + }).then(res => res.json()).then((response) => { + if (response.code === 0) { + showMessage(window.siyuan.languages.importCABundleSuccess); + } else { + showMessage(response.msg, 6000, "error"); + } + }); + } + }; + input.click(); + }); const lockScreenModeElement = about.element.querySelector("#lockScreenMode") as HTMLInputElement; lockScreenModeElement.addEventListener("change", () => { fetchPost("/api/system/setFollowSystemLockScreen", {lockScreenMode: lockScreenModeElement.checked ? 1 : 0}, () => { diff --git a/app/src/mobile/settings/about.ts b/app/src/mobile/settings/about.ts index e5f86e0a3..58827339c 100644 --- a/app/src/mobile/settings/about.ts +++ b/app/src/mobile/settings/about.ts @@ -24,10 +24,42 @@ export const initAbout = () => {
+ +
+ ${window.siyuan.languages.exportCACert} +
+ +
${window.siyuan.languages.exportCACertTip}
+
+
+ ${window.siyuan.languages.exportCABundle} +
+ +
${window.siyuan.languages.exportCABundleTip}
+
+
+ ${window.siyuan.languages.importCABundle} +
+ +
${window.siyuan.languages.importCABundleTip}
+
${window.siyuan.languages.about2}
- + ${window.siyuan.languages.about4}
${window.siyuan.languages.about3.replace("${port}", location.port)}
@@ -451,11 +483,67 @@ export const initAbout = () => { }); }); const networkServeElement = modelMainElement.querySelector("#networkServe") as HTMLInputElement; + const networkServeTLSElement = modelMainElement.querySelector("#networkServeTLS") as HTMLInputElement; networkServeElement.addEventListener("change", () => { + networkServeTLSElement.disabled = !networkServeElement.checked; + if (!networkServeElement.checked) { + networkServeTLSElement.checked = false; + } fetchPost("/api/system/setNetworkServe", {networkServe: networkServeElement.checked}, () => { exitSiYuan(); }); }); + networkServeTLSElement.addEventListener("change", () => { + const exportCACertSection = modelMainElement.querySelector("#exportCACertSection"); + const exportCABundleSection = modelMainElement.querySelector("#exportCABundleSection"); + const importCABundleSection = modelMainElement.querySelector("#importCABundleSection"); + if (exportCACertSection && exportCABundleSection && importCABundleSection) { + if (networkServeTLSElement.checked) { + exportCACertSection.classList.remove("fn__none"); + exportCABundleSection.classList.remove("fn__none"); + importCABundleSection.classList.remove("fn__none"); + } else { + exportCACertSection.classList.add("fn__none"); + exportCABundleSection.classList.add("fn__none"); + importCABundleSection.classList.add("fn__none"); + } + } + fetchPost("/api/system/setNetworkServeTLS", {networkServeTLS: networkServeTLSElement.checked}, () => { + exitSiYuan(); + }); + }); + modelMainElement.querySelector("#exportCACert")?.addEventListener("click", () => { + fetchPost("/api/system/exportTLSCACert", {}, (response) => { + openByMobile(response.data.path); + }); + }); + modelMainElement.querySelector("#exportCABundle")?.addEventListener("click", () => { + fetchPost("/api/system/exportTLSCABundle", {}, (response) => { + openByMobile(response.data.path); + }); + }); + modelMainElement.querySelector("#importCABundle")?.addEventListener("click", () => { + const input = document.createElement("input"); + input.type = "file"; + input.accept = ".zip"; + input.onchange = () => { + if (input.files && input.files[0]) { + const formData = new FormData(); + formData.append("file", input.files[0]); + fetch("/api/system/importTLSCABundle", { + method: "POST", + body: formData, + }).then(res => res.json()).then((response) => { + if (response.code === 0) { + showMessage(window.siyuan.languages.importCABundleSuccess); + } else { + showMessage(response.msg, 6000, "error"); + } + }); + } + }; + input.click(); + }); const tokenElement = modelMainElement.querySelector("#token") as HTMLInputElement; tokenElement.addEventListener("change", () => { fetchPost("/api/system/setAPIToken", {token: tokenElement.value}, () => { diff --git a/app/src/types/config.d.ts b/app/src/types/config.d.ts index efd08c075..cd90be0a1 100644 --- a/app/src/types/config.d.ts +++ b/app/src/types/config.d.ts @@ -1623,6 +1623,10 @@ declare namespace Config { * Whether to enable network serve (whether to allow connections from other devices) */ networkServe: boolean; + /** + * Whether to enable HTTPS for network serve (TLS encryption) + */ + networkServeTLS: boolean; /** * The operating system name determined at compile time (obtained using the command `go tool * dist list`) diff --git a/kernel/api/router.go b/kernel/api/router.go index 2b03ea394..b5d6ebac2 100644 --- a/kernel/api/router.go +++ b/kernel/api/router.go @@ -42,6 +42,10 @@ func ServeAPI(ginServer *gin.Engine) { ginServer.Handle("POST", "/api/system/setAccessAuthCode", model.CheckAuth, model.CheckAdminRole, model.CheckReadonly, setAccessAuthCode) ginServer.Handle("POST", "/api/system/setFollowSystemLockScreen", model.CheckAuth, model.CheckAdminRole, model.CheckReadonly, setFollowSystemLockScreen) ginServer.Handle("POST", "/api/system/setNetworkServe", model.CheckAuth, model.CheckAdminRole, model.CheckReadonly, setNetworkServe) + ginServer.Handle("POST", "/api/system/setNetworkServeTLS", model.CheckAuth, model.CheckAdminRole, model.CheckReadonly, setNetworkServeTLS) + ginServer.Handle("POST", "/api/system/exportTLSCACert", model.CheckAuth, model.CheckAdminRole, exportTLSCACert) + ginServer.Handle("POST", "/api/system/exportTLSCABundle", model.CheckAuth, model.CheckAdminRole, exportTLSCABundle) + ginServer.Handle("POST", "/api/system/importTLSCABundle", model.CheckAuth, model.CheckAdminRole, model.CheckReadonly, importTLSCABundle) ginServer.Handle("POST", "/api/system/setAutoLaunch", model.CheckAuth, model.CheckAdminRole, model.CheckReadonly, setAutoLaunch) ginServer.Handle("POST", "/api/system/setDownloadInstallPkg", model.CheckAuth, model.CheckAdminRole, model.CheckReadonly, setDownloadInstallPkg) ginServer.Handle("POST", "/api/system/setNetworkProxy", model.CheckAuth, model.CheckAdminRole, model.CheckReadonly, setNetworkProxy) diff --git a/kernel/api/system.go b/kernel/api/system.go index a6f097e27..314ed0e82 100644 --- a/kernel/api/system.go +++ b/kernel/api/system.go @@ -720,6 +720,173 @@ func setNetworkServe(c *gin.Context) { time.Sleep(time.Second * 3) } +func setNetworkServeTLS(c *gin.Context) { + ret := gulu.Ret.NewResult() + defer c.JSON(http.StatusOK, ret) + + arg, ok := util.JsonArg(c, ret) + if !ok { + return + } + + networkServeTLS := arg["networkServeTLS"].(bool) + model.Conf.System.NetworkServeTLS = networkServeTLS + model.Conf.Save() + + util.PushMsg(model.Conf.Language(42), 1000*15) + time.Sleep(time.Second * 3) +} + +func exportTLSCACert(c *gin.Context) { + ret := gulu.Ret.NewResult() + defer c.JSON(http.StatusOK, ret) + + caCertPath := filepath.Join(util.ConfDir, util.TLSCACertFilename) + if !gulu.File.IsExist(caCertPath) { + ret.Code = -1 + ret.Msg = "CA certificate not found" + return + } + + tmpDir := filepath.Join(util.TempDir, "export") + if err := os.MkdirAll(tmpDir, 0755); err != nil { + ret.Code = -1 + ret.Msg = err.Error() + return + } + + exportPath := filepath.Join(tmpDir, util.TLSCACertFilename) + if err := gulu.File.CopyFile(caCertPath, exportPath); err != nil { + ret.Code = -1 + ret.Msg = err.Error() + return + } + + ret.Data = map[string]interface{}{ + "path": "/export/" + util.TLSCACertFilename, + } +} + +func exportTLSCABundle(c *gin.Context) { + ret := gulu.Ret.NewResult() + defer c.JSON(http.StatusOK, ret) + + caCertPath := filepath.Join(util.ConfDir, util.TLSCACertFilename) + caKeyPath := filepath.Join(util.ConfDir, util.TLSCAKeyFilename) + + if !gulu.File.IsExist(caCertPath) || !gulu.File.IsExist(caKeyPath) { + ret.Code = -1 + ret.Msg = "CA certificate not found, please enable TLS first" + return + } + + tmpDir := filepath.Join(util.TempDir, "export", "ca-bundle") + os.RemoveAll(tmpDir) + if err := os.MkdirAll(tmpDir, 0755); err != nil { + ret.Code = -1 + ret.Msg = err.Error() + return + } + defer os.RemoveAll(tmpDir) + + if err := gulu.File.CopyFile(caCertPath, filepath.Join(tmpDir, util.TLSCACertFilename)); err != nil { + ret.Code = -1 + ret.Msg = err.Error() + return + } + if err := gulu.File.CopyFile(caKeyPath, filepath.Join(tmpDir, util.TLSCAKeyFilename)); err != nil { + ret.Code = -1 + ret.Msg = err.Error() + return + } + + zipPath := filepath.Join(util.TempDir, "export", "ca-bundle.zip") + zipFile, err := gulu.Zip.Create(zipPath) + if err != nil { + ret.Code = -1 + ret.Msg = err.Error() + return + } + + if err := zipFile.AddDirectory("", tmpDir); err != nil { + ret.Code = -1 + ret.Msg = err.Error() + return + } + + if err := zipFile.Close(); err != nil { + ret.Code = -1 + ret.Msg = err.Error() + return + } + + ret.Data = map[string]interface{}{ + "path": "/export/ca-bundle.zip", + } +} + +func importTLSCABundle(c *gin.Context) { + ret := gulu.Ret.NewResult() + defer c.JSON(http.StatusOK, ret) + + file, err := c.FormFile("file") + if err != nil { + ret.Code = -1 + ret.Msg = "file is required: " + err.Error() + return + } + + tmpDir := filepath.Join(util.TempDir, "import") + if err := os.MkdirAll(tmpDir, 0755); err != nil { + ret.Code = -1 + ret.Msg = err.Error() + return + } + + tmpZipPath := filepath.Join(tmpDir, "ca-bundle.zip") + if err := c.SaveUploadedFile(file, tmpZipPath); err != nil { + ret.Code = -1 + ret.Msg = err.Error() + return + } + defer os.Remove(tmpZipPath) + + extractDir := filepath.Join(tmpDir, "ca-bundle") + os.RemoveAll(extractDir) + if err := gulu.Zip.Unzip(tmpZipPath, extractDir); err != nil { + ret.Code = -1 + ret.Msg = "failed to extract zip file: " + err.Error() + return + } + defer os.RemoveAll(extractDir) + + caCertPath := filepath.Join(extractDir, util.TLSCACertFilename) + caCertPEM, err := os.ReadFile(caCertPath) + if err != nil { + ret.Code = -1 + ret.Msg = "ca.crt not found in zip file" + return + } + + caKeyPath := filepath.Join(extractDir, util.TLSCAKeyFilename) + caKeyPEM, err := os.ReadFile(caKeyPath) + if err != nil { + ret.Code = -1 + ret.Msg = "ca.key not found in zip file" + return + } + + if err := util.ImportCABundle(string(caCertPEM), string(caKeyPEM)); err != nil { + ret.Code = -1 + ret.Msg = err.Error() + return + } + + ret.Data = map[string]interface{}{ + "msg": "CA bundle imported successfully. Please restart to apply changes.", + } +} + func setAutoLaunch(c *gin.Context) { ret := gulu.Ret.NewResult() defer c.JSON(http.StatusOK, ret) diff --git a/kernel/conf/system.go b/kernel/conf/system.go index 58b0475a2..08069ac4e 100644 --- a/kernel/conf/system.go +++ b/kernel/conf/system.go @@ -36,8 +36,9 @@ type System struct { ConfDir string `json:"confDir"` DataDir string `json:"dataDir"` - NetworkServe bool `json:"networkServe"` // 是否开启网络伺服 - NetworkProxy *NetworkProxy `json:"networkProxy"` + NetworkServe bool `json:"networkServe"` // 是否开启网络伺服 + NetworkServeTLS bool `json:"networkServeTLS"` // 是否开启 HTTPS 网络伺服 + NetworkProxy *NetworkProxy `json:"networkProxy"` DownloadInstallPkg bool `json:"downloadInstallPkg"` AutoLaunch2 int `json:"autoLaunch2"` // 0:不自动启动,1:自动启动,2:自动启动+隐藏主窗口 diff --git a/kernel/server/proxy/fixedport.go b/kernel/server/proxy/fixedport.go index dface6bc1..254d459c1 100644 --- a/kernel/server/proxy/fixedport.go +++ b/kernel/server/proxy/fixedport.go @@ -17,6 +17,7 @@ package proxy import ( + "crypto/tls" "net/http" "net/http/httputil" @@ -24,7 +25,7 @@ import ( "github.com/siyuan-note/siyuan/kernel/util" ) -func InitFixedPortService(host string) { +func InitFixedPortService(host string, useTLS bool, certPath, keyPath string) { if util.FixedPort != util.ServerPort { if util.IsPortOpen(util.FixedPort) { return @@ -32,9 +33,23 @@ func InitFixedPortService(host string) { // 启动一个固定 6806 端口的反向代理服务器,这样浏览器扩展才能直接使用 127.0.0.1:6806,不用配置端口 proxy := httputil.NewSingleHostReverseProxy(util.ServerURL) - logging.LogInfof("fixed port service [%s:%s] is running", host, util.FixedPort) - if proxyErr := http.ListenAndServe(host+":"+util.FixedPort, proxy); nil != proxyErr { - logging.LogWarnf("boot fixed port service [%s] failed: %s", util.ServerURL, proxyErr) + + if useTLS { + proxy.Transport = &http.Transport{ + TLSClientConfig: &tls.Config{InsecureSkipVerify: true}, + } + } + + if useTLS { + logging.LogInfof("fixed port service [%s:%s] is running with TLS", host, util.FixedPort) + if proxyErr := http.ListenAndServeTLS(host+":"+util.FixedPort, certPath, keyPath, proxy); nil != proxyErr { + logging.LogWarnf("boot fixed port service [%s] failed: %s", util.ServerURL, proxyErr) + } + } else { + logging.LogInfof("fixed port service [%s:%s] is running", host, util.FixedPort) + if proxyErr := http.ListenAndServe(host+":"+util.FixedPort, proxy); nil != proxyErr { + logging.LogWarnf("boot fixed port service [%s] failed: %s", util.ServerURL, proxyErr) + } } logging.LogInfof("fixed port service [%s:%s] is stopped", host, util.FixedPort) } diff --git a/kernel/server/serve.go b/kernel/server/serve.go index 40fb560af..702215dba 100644 --- a/kernel/server/serve.go +++ b/kernel/server/serve.go @@ -210,14 +210,32 @@ func Serve(fastMode bool, cookieKey string) { if !fastMode { rewritePortJSON(pid, port) } - logging.LogInfof("kernel [pid=%s] http server [%s] is booting", pid, host+":"+port) + + // Prepare TLS if enabled + var certPath, keyPath string + useTLS := model.Conf.System.NetworkServeTLS && model.Conf.System.NetworkServe + if useTLS { + // Ensure TLS certificates exist (proxy will use them directly) + var tlsErr error + certPath, keyPath, tlsErr = util.GetOrCreateTLSCert() + if tlsErr != nil { + logging.LogErrorf("failed to get TLS certificates: %s", tlsErr) + if !fastMode { + os.Exit(logging.ExitCodeUnavailablePort) + } + return + } + logging.LogInfof("kernel [pid=%s] http server [%s] is booting (TLS will be enabled on fixed port proxy)", pid, host+":"+port) + } else { + logging.LogInfof("kernel [pid=%s] http server [%s] is booting", pid, host+":"+port) + } util.HttpServing = true go util.HookUILoaded() go func() { time.Sleep(1 * time.Second) - go proxy.InitFixedPortService(host) + go proxy.InitFixedPortService(host, useTLS, certPath, keyPath) go proxy.InitPublishService() // 反代服务器启动失败不影响核心服务器启动 }() diff --git a/kernel/util/cert.go b/kernel/util/cert.go new file mode 100644 index 000000000..76d991c58 --- /dev/null +++ b/kernel/util/cert.go @@ -0,0 +1,337 @@ +// SiYuan - Refactor your thinking +// Copyright (c) 2020-present, b3log.org +// +// This program is free software: you can redistribute it and/or modify +// it under the terms of the GNU Affero General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// This program is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Affero General Public License for more details. +// +// You should have received a copy of the GNU Affero General Public License +// along with this program. If not, see . + +package util + +import ( + "crypto/ecdsa" + "crypto/elliptic" + "crypto/rand" + "crypto/x509" + "crypto/x509/pkix" + "encoding/pem" + "fmt" + "math/big" + "net" + "os" + "path/filepath" + "time" + + "github.com/88250/gulu" + "github.com/siyuan-note/logging" +) + +const ( + TLSCACertFilename = "ca.crt" + TLSCAKeyFilename = "ca.key" + TLSCertFilename = "cert.pem" + TLSKeyFilename = "key.pem" +) + +// Returns paths to existing TLS certificates or generates new ones signed by a local CA. +// Certificates are stored in the conf directory of the workspace. +func GetOrCreateTLSCert() (certPath, keyPath string, err error) { + certPath = filepath.Join(ConfDir, TLSCertFilename) + keyPath = filepath.Join(ConfDir, TLSKeyFilename) + caCertPath := filepath.Join(ConfDir, TLSCACertFilename) + caKeyPath := filepath.Join(ConfDir, TLSCAKeyFilename) + + if !gulu.File.IsExist(caCertPath) || !gulu.File.IsExist(caKeyPath) { + logging.LogInfof("generating local CA for TLS...") + if err = generateCACert(caCertPath, caKeyPath); err != nil { + logging.LogErrorf("failed to generate CA certificates: %s", err) + return "", "", err + } + } + + if gulu.File.IsExist(certPath) && gulu.File.IsExist(keyPath) { + if validateCert(certPath) { + logging.LogInfof("using existing TLS certificates from [%s]", ConfDir) + return certPath, keyPath, nil + } + logging.LogInfof("existing TLS certificates are invalid or expired, regenerating...") + } + + caCert, caKey, err := loadCA(caCertPath, caKeyPath) + if err != nil { + logging.LogErrorf("failed to load CA certificates: %s", err) + return "", "", err + } + + logging.LogInfof("generating TLS server certificates signed by local CA...") + if err = generateServerCert(certPath, keyPath, caCert, caKey); err != nil { + logging.LogErrorf("failed to generate TLS certificates: %s", err) + return "", "", err + } + + logging.LogInfof("generated TLS certificates at [%s]", ConfDir) + return certPath, keyPath, nil +} + +// Checks if the certificate file exists, is not expired, and contains all current IP addresses +func validateCert(certPath string) bool { + certPEM, err := os.ReadFile(certPath) + if err != nil { + return false + } + + block, _ := pem.Decode(certPEM) + if block == nil { + return false + } + + cert, err := x509.ParseCertificate(block.Bytes) + if err != nil { + return false + } + + // Check if certificate is still valid, with 7 day buffer + if !time.Now().Add(7 * 24 * time.Hour).Before(cert.NotAfter) { + return false + } + + // Check if certificate contains all current IP addresses + currentIPs := GetServerAddrs() + certIPMap := make(map[string]bool) + for _, ip := range cert.IPAddresses { + certIPMap[ip.String()] = true + } + + for _, ipStr := range currentIPs { + ipStr = trimIPv6Brackets(ipStr) + ip := net.ParseIP(ipStr) + if ip == nil { + continue + } + + if !certIPMap[ip.String()] { + logging.LogInfof("certificate missing current IP address [%s], will regenerate", ip.String()) + return false + } + } + + return true +} + +// Creates a new self-signed CA certificate +func generateCACert(certPath, keyPath string) error { + privateKey, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader) + if err != nil { + return err + } + + serialNumber, err := rand.Int(rand.Reader, new(big.Int).Lsh(big.NewInt(1), 128)) + if err != nil { + return err + } + + notBefore := time.Now() + notAfter := notBefore.Add(10 * 365 * 24 * time.Hour) + + template := x509.Certificate{ + SerialNumber: serialNumber, + Subject: pkix.Name{ + Organization: []string{"SiYuan"}, + CommonName: "SiYuan Local CA", + }, + NotBefore: notBefore, + NotAfter: notAfter, + KeyUsage: x509.KeyUsageCertSign | x509.KeyUsageCRLSign, + BasicConstraintsValid: true, + IsCA: true, + } + + certDER, err := x509.CreateCertificate(rand.Reader, &template, &template, &privateKey.PublicKey, privateKey) + if err != nil { + return err + } + + return writeCertAndKey(certPath, keyPath, certDER, privateKey) +} + +// Creates a new server certificate signed by the CA +func generateServerCert(certPath, keyPath string, caCert *x509.Certificate, caKey any) error { + privateKey, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader) + if err != nil { + return err + } + + serialNumber, err := rand.Int(rand.Reader, new(big.Int).Lsh(big.NewInt(1), 128)) + if err != nil { + return err + } + + notBefore := time.Now() + notAfter := notBefore.Add(365 * 24 * time.Hour) + + ipAddresses := []net.IP{ + net.ParseIP("127.0.0.1"), + net.IPv6loopback, + } + + localIPs := GetServerAddrs() + for _, ipStr := range localIPs { + ipStr = trimIPv6Brackets(ipStr) + if ip := net.ParseIP(ipStr); ip != nil { + ipAddresses = append(ipAddresses, ip) + } + } + + template := x509.Certificate{ + SerialNumber: serialNumber, + Subject: pkix.Name{ + Organization: []string{"SiYuan"}, + CommonName: "SiYuan Local Server", + }, + NotBefore: notBefore, + NotAfter: notAfter, + KeyUsage: x509.KeyUsageKeyEncipherment | x509.KeyUsageDigitalSignature, + ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth}, + BasicConstraintsValid: true, + DNSNames: []string{"localhost"}, + IPAddresses: ipAddresses, + } + + certDER, err := x509.CreateCertificate(rand.Reader, &template, caCert, &privateKey.PublicKey, caKey) + if err != nil { + return err + } + + return writeCertAndKey(certPath, keyPath, certDER, privateKey) +} + +// Loads the CA certificate and private key from files +func loadCA(certPath, keyPath string) (*x509.Certificate, any, error) { + certPEM, err := os.ReadFile(certPath) + if err != nil { + return nil, nil, err + } + + block, _ := pem.Decode(certPEM) + if block == nil { + return nil, nil, fmt.Errorf("failed to decode CA certificate PEM") + } + + caCert, err := x509.ParseCertificate(block.Bytes) + if err != nil { + return nil, nil, err + } + + keyPEM, err := os.ReadFile(keyPath) + if err != nil { + return nil, nil, err + } + + block, _ = pem.Decode(keyPEM) + if block == nil { + return nil, nil, fmt.Errorf("failed to decode CA key PEM") + } + + caKey, err := x509.ParseECPrivateKey(block.Bytes) + if err != nil { + return nil, nil, err + } + + return caCert, caKey, nil +} + +func writeCertAndKey(certPath, keyPath string, certDER []byte, privateKey *ecdsa.PrivateKey) error { + certFile, err := os.Create(certPath) + if err != nil { + return err + } + defer certFile.Close() + + if err = pem.Encode(certFile, &pem.Block{Type: "CERTIFICATE", Bytes: certDER}); err != nil { + return err + } + + keyFile, err := os.Create(keyPath) + if err != nil { + return err + } + defer keyFile.Close() + + keyDER, err := x509.MarshalECPrivateKey(privateKey) + if err != nil { + return err + } + + if err = pem.Encode(keyFile, &pem.Block{Type: "EC PRIVATE KEY", Bytes: keyDER}); err != nil { + return err + } + + return nil +} + +// Imports a CA certificate and private key from PEM-encoded strings. +func ImportCABundle(caCertPEM, caKeyPEM string) error { + certBlock, _ := pem.Decode([]byte(caCertPEM)) + if certBlock == nil { + return fmt.Errorf("failed to decode CA certificate PEM") + } + + caCert, err := x509.ParseCertificate(certBlock.Bytes) + if err != nil { + return fmt.Errorf("failed to parse CA certificate: %w", err) + } + + if !caCert.IsCA { + return fmt.Errorf("the provided certificate is not a CA certificate") + } + + keyBlock, _ := pem.Decode([]byte(caKeyPEM)) + if keyBlock == nil { + return fmt.Errorf("failed to decode CA private key PEM") + } + + _, err = x509.ParseECPrivateKey(keyBlock.Bytes) + if err != nil { + return fmt.Errorf("failed to parse CA private key: %w", err) + } + + caCertPath := filepath.Join(ConfDir, TLSCACertFilename) + caKeyPath := filepath.Join(ConfDir, TLSCAKeyFilename) + + if err := os.WriteFile(caCertPath, []byte(caCertPEM), 0644); err != nil { + return fmt.Errorf("failed to write CA certificate: %w", err) + } + + if err := os.WriteFile(caKeyPath, []byte(caKeyPEM), 0600); err != nil { + return fmt.Errorf("failed to write CA private key: %w", err) + } + + certPath := filepath.Join(ConfDir, TLSCertFilename) + keyPath := filepath.Join(ConfDir, TLSKeyFilename) + + if gulu.File.IsExist(certPath) { + os.Remove(certPath) + } + if gulu.File.IsExist(keyPath) { + os.Remove(keyPath) + } + + logging.LogInfof("imported CA bundle, server certificate will be regenerated on next TLS initialization") + return nil +} + +// trimIPv6Brackets removes brackets from IPv6 address strings like "[::1]" +func trimIPv6Brackets(ip string) string { + if len(ip) > 2 && ip[0] == '[' && ip[len(ip)-1] == ']' { + return ip[1 : len(ip)-1] + } + return ip +}