From ffd4ceb0d96b6e92edb4f1a7780a875f701bfe80 Mon Sep 17 00:00:00 2001 From: Daniel <845765@qq.com> Date: Mon, 29 Dec 2025 19:35:31 +0800 Subject: [PATCH] :art: Unable to switch the publish service between multiple workspaces https://github.com/siyuan-note/siyuan/issues/16587 Signed-off-by: Daniel <845765@qq.com> --- kernel/server/proxy/publish.go | 90 ++++++++++++++++++---------------- kernel/util/net.go | 6 +-- 2 files changed, 51 insertions(+), 45 deletions(-) diff --git a/kernel/server/proxy/publish.go b/kernel/server/proxy/publish.go index 8f244f651..c10a06905 100644 --- a/kernel/server/proxy/publish.go +++ b/kernel/server/proxy/publish.go @@ -17,10 +17,12 @@ package proxy import ( + "context" "fmt" "net" "net/http" "net/http/httputil" + "time" "github.com/siyuan-note/logging" "github.com/siyuan-note/siyuan/kernel/model" @@ -34,6 +36,7 @@ var ( Port = "0" listener net.Listener + server *http.Server transport = PublishServiceTransport{} ) @@ -42,7 +45,6 @@ func InitPublishService() (uint16, error) { if listener != nil { if !model.Conf.Publish.Enable { - // 关闭发布服务 closePublishListener() return 0, nil } @@ -50,12 +52,7 @@ func InitPublishService() (uint16, error) { if port, err := util.ParsePort(Port); err != nil { return 0, err } else if port != model.Conf.Publish.Port { - // 关闭原端口的发布服务 - if err = closePublishListener(); err != nil { - return 0, err - } - - // 重新启动新端口的发布服务 + closePublishListener() initPublishService() } } else { @@ -69,15 +66,14 @@ func InitPublishService() (uint16, error) { return util.ParsePort(Port) } -func initPublishService() (err error) { - if err = initPublishListener(); err == nil { +func initPublishService() { + if err := initPublishListener(); err == nil { go startPublishReverseProxyService() } return } func initPublishListener() (err error) { - // Start new listener listener, err = net.Listen("tcp", fmt.Sprintf("%s:%d", Host, model.Conf.Publish.Port)) if err != nil { logging.LogErrorf("start listener failed: %s", err) @@ -92,24 +88,35 @@ func initPublishListener() (err error) { return } -func closePublishListener() (err error) { - listener_ := listener - listener = nil - if err = listener_.Close(); err != nil { - logging.LogErrorf("close listener %s failed: %s", listener_.Addr().String(), err) - listener = listener_ +func closePublishListener() { + if server == nil { + return } + + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + if err := server.Shutdown(ctx); err != nil { + logging.LogErrorf("shutdown server failed: %s", err) + } + cancel() + + if err := server.Close(); err != nil { + logging.LogErrorf("close server failed: %s", err) + } + server, listener = nil, nil return } func startPublishReverseProxyService() { logging.LogInfof("publish service [%s:%s] is running", Host, Port) - proxy := &httputil.ReverseProxy{ - Rewrite: rewrite, - Transport: transport, + server = &http.Server{ + Handler: &httputil.ReverseProxy{ + Rewrite: rewrite, + Transport: transport, + }, } - if err := http.Serve(listener, proxy); err != nil { + + if err := server.Serve(listener); err != nil { if listener != nil { logging.LogErrorf("boot publish service failed: %s", err) } @@ -138,10 +145,10 @@ func (PublishServiceTransport) RoundTrip(request *http.Request) (response *http. request.Header.Set(model.XAuthTokenKey, account.Token) response, err = http.DefaultTransport.RoundTrip(request) return - } else { - // Invalid account, remove session - model.DeleteSession(sessionID) } + + // Invalid account, remove session + model.DeleteSession(sessionID) } } @@ -167,27 +174,26 @@ func (PublishServiceTransport) RoundTrip(request *http.Request) (response *http. Close: false, ContentLength: -1, }, nil - } else { - // set session cookie - sessionID := model.GetNewSessionID() - cookie := &http.Cookie{ - Name: model.SessionIdCookieName, - Value: sessionID, - Path: "/", - HttpOnly: true, - } - model.AddSession(sessionID, username) - - // set JWT - request.Header.Set(model.XAuthTokenKey, account.Token) - response, err = http.DefaultTransport.RoundTrip(request) - - response.Header.Add("Set-Cookie", cookie.String()) - return } - } else { - request.Header.Set(model.XAuthTokenKey, model.GetBasicAuthAccount("").Token) + + // set session cookie + sessionID := model.GetNewSessionID() + cookie := &http.Cookie{ + Name: model.SessionIdCookieName, + Value: sessionID, + Path: "/", + HttpOnly: true, + } + model.AddSession(sessionID, username) + + // set JWT + request.Header.Set(model.XAuthTokenKey, account.Token) response, err = http.DefaultTransport.RoundTrip(request) + response.Header.Add("Set-Cookie", cookie.String()) return } + + request.Header.Set(model.XAuthTokenKey, model.GetBasicAuthAccount("").Token) + response, err = http.DefaultTransport.RoundTrip(request) + return } diff --git a/kernel/util/net.go b/kernel/util/net.go index acd607b80..9e79b6295 100644 --- a/kernel/util/net.go +++ b/kernel/util/net.go @@ -171,10 +171,10 @@ func initHttpClient() { } func ParsePort(portString string) (uint16, error) { - if port, err := strconv.ParseUint(portString, 10, 16); err != nil { + port, err := strconv.ParseUint(portString, 10, 16) + if err != nil { logging.LogErrorf("parse port [%s] failed: %s", portString, err) return 0, err - } else { - return uint16(port), nil } + return uint16(port), nil }