siyuan/kernel/server/proxy/publish.go
Yingyi / 颖逸 ff4d215f78
🎨 Add cookie-based auth in publish proxy (#15692)
* chore(publish-auth): Add TODO for cookie-based auth in publish proxy

A TODO comment was added to indicate future implementation of authentication using cookies in the PublishServiceTransport RoundTrip method.

* 🎨 Add session-based authentication for publish proxy

Introduces session management using cookies for the publish reverse proxy server. Adds session ID generation, storage, and validation in kernel/model/auth.go, and updates the proxy transport to check for valid sessions before falling back to basic authentication. Sets a session cookie upon successful basic auth login.

* 🐛 Fixed the issue of repeatedly setting cookies

* 🎨 Dynamically remove invalid session IDs

* ♻️ Revert changes in pnpm-lock.yaml
2025-08-28 16:20:12 +08:00

192 lines
5 KiB
Go

// SiYuan - Refactor your thinking
// Copyright (c) 2020-present, b3log.org
//
// This program is free software: you can redistribute it and/or modify
// it under the terms of the GNU Affero General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// This program is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU Affero General Public License for more details.
//
// You should have received a copy of the GNU Affero General Public License
// along with this program. If not, see <https://www.gnu.org/licenses/>.
package proxy
import (
"fmt"
"net"
"net/http"
"net/http/httputil"
"github.com/siyuan-note/logging"
"github.com/siyuan-note/siyuan/kernel/model"
"github.com/siyuan-note/siyuan/kernel/util"
)
type PublishServiceTransport struct{}
var (
Host = "0.0.0.0"
Port = "0"
listener net.Listener
transport = PublishServiceTransport{}
proxy = &httputil.ReverseProxy{
Rewrite: rewrite,
Transport: transport,
}
)
// InitPublishService reconciles the publish reverse-proxy service with the
// current model.Conf.Publish settings and returns the port it listens on
// (0 when publishing is disabled).
//
//   - Publishing disabled while a listener is open: the listener is closed.
//   - Publishing enabled on a port different from the current one: the old
//     listener is closed and a new one is started.
//   - Publishing enabled with no listener open: a new listener is started.
//
// Failures to close or (re)start the listener are returned instead of being
// masked by a stale port number.
func InitPublishService() (uint16, error) {
	model.InitAccounts()

	if listener != nil {
		if !model.Conf.Publish.Enable {
			// Stop the publish service. A close failure is already logged by
			// closePublishListener.
			closePublishListener()
			return 0, nil
		}

		port, err := util.ParsePort(Port)
		if err != nil {
			return 0, err
		}
		if port != model.Conf.Publish.Port {
			// Stop the publish service on the old port...
			if err = closePublishListener(); err != nil {
				return 0, err
			}
			// ...and restart it on the newly configured port.
			if err = initPublishService(); err != nil {
				return 0, err
			}
		}
	} else {
		if !model.Conf.Publish.Enable {
			return 0, nil
		}
		// Start the publish service on the configured port.
		if err := initPublishService(); err != nil {
			return 0, err
		}
	}

	return util.ParsePort(Port)
}
// initPublishService opens the publish listener and, only if that succeeds,
// launches the serving loop in a background goroutine. Any listener error is
// returned to the caller.
func initPublishService() error {
	if err := initPublishListener(); err != nil {
		return err
	}
	go startPublishReverseProxyService()
	return nil
}
// initPublishListener opens a TCP listener on Host and
// model.Conf.Publish.Port and stores the port actually bound in Port. This
// matters when the configured port is 0 and the OS picks one.
//
// On failure the error is logged and returned, listener is left nil, and Port
// keeps its previous value.
func initPublishListener() (err error) {
	// JoinHostPort brackets IPv6 hosts (e.g. "::"); a plain "%s:%d" would not.
	addr := net.JoinHostPort(Host, fmt.Sprint(model.Conf.Publish.Port))
	if listener, err = net.Listen("tcp", addr); err != nil {
		logging.LogErrorf("start listener failed: %s", err)
		return
	}

	var port string
	if _, port, err = net.SplitHostPort(listener.Addr().String()); err != nil {
		logging.LogErrorf("split host and port failed: %s", err)
		// No goroutine will ever serve this listener, so don't leak it. The
		// close error is secondary to err, which is what we report.
		_ = listener.Close()
		listener = nil
		return
	}
	Port = port
	return
}
// closePublishListener closes the current publish listener. It is a no-op
// returning nil when no listener is open.
//
// listener is set to nil before Close is called. If Close fails, the error is
// logged, listener is restored, and the error is returned.
func closePublishListener() (err error) {
	current := listener
	if current == nil {
		return nil
	}

	listener = nil
	if err = current.Close(); err != nil {
		logging.LogErrorf("close listener %s failed: %s", current.Addr().String(), err)
		listener = current
	}
	return
}
func startPublishReverseProxyService() {
logging.LogInfof("publish service [%s:%s] is running", Host, Port)
// 服务进行时一直阻塞
if err := http.Serve(listener, proxy); err != nil {
if listener != nil {
logging.LogErrorf("boot publish service failed: %s", err)
}
}
logging.LogInfof("publish service [%s:%s] is stopped", Host, Port)
}
// rewrite is the ReverseProxy Rewrite hook. It records the client request in
// X-Forwarded-* headers and retargets the outbound request at util.ServerURL.
// SetURL leaves the outbound Host empty, so the upstream sees the target's
// host. To pass the client's Host through instead, set
// pr.Out.Host = pr.In.Host.
func rewrite(pr *httputil.ProxyRequest) {
	pr.SetXForwarded()
	pr.SetURL(util.ServerURL)
}
// RoundTrip implements http.RoundTripper for the publish reverse proxy.
//
// When publish auth is disabled, every request is forwarded with the token of
// the anonymous account (GetBasicAuthAccount("")). When it is enabled, the
// request is authenticated in this order:
//  1. The session cookie (model.SessionIdCookieName) issued by an earlier
//     successful login.
//  2. HTTP Basic auth against the publish accounts. A successful Basic login
//     creates a new session and returns its ID in a Set-Cookie header, so
//     later requests can use step 1.
//
// Requests that fail both checks are not forwarded. They receive a
// synthesized 401 response carrying the Basic auth challenge header. Errors
// from the upstream transport are returned unchanged.
func (PublishServiceTransport) RoundTrip(request *http.Request) (response *http.Response, err error) {
	if !model.Conf.Publish.Auth.Enable {
		request.Header.Set(model.XAuthTokenKey, model.GetBasicAuthAccount("").Token)
		return http.DefaultTransport.RoundTrip(request)
	}

	// Session auth
	if token, ok := publishSessionToken(request); ok {
		request.Header.Set(model.XAuthTokenKey, token)
		return http.DefaultTransport.RoundTrip(request)
	}

	// Basic auth
	username, password, ok := request.BasicAuth()
	account := model.GetBasicAuthAccount(username)
	if !ok ||
		account == nil ||
		account.Username == "" || // empty username is the anonymous account; it cannot log in
		account.Password != password {
		return publishUnauthorizedResponse(request), nil
	}

	request.Header.Set(model.XAuthTokenKey, account.Token)
	if response, err = http.DefaultTransport.RoundTrip(request); err != nil {
		// response is nil on error; touching its headers would panic.
		return nil, err
	}

	// Register the session only once there is a response to carry its cookie,
	// so a failed upstream call doesn't leave an orphaned session behind.
	sessionID := model.GetNewSessionID()
	model.AddSession(sessionID, username)
	cookie := &http.Cookie{
		Name:     model.SessionIdCookieName,
		Value:    sessionID,
		Path:     "/",
		HttpOnly: true,
	}
	response.Header.Add("Set-Cookie", cookie.String())
	return response, nil
}

// publishSessionToken resolves the publish session cookie on request to the
// token of the account the session belongs to. It reports false when:
//   - the cookie is missing,
//   - the session ID is unknown, or
//   - the session's username no longer maps to an account. In this case the
//     stale session is deleted.
func publishSessionToken(request *http.Request) (token string, ok bool) {
	cookie, err := request.Cookie(model.SessionIdCookieName)
	if err != nil {
		return "", false
	}

	sessionID := cookie.Value
	username := model.GetBasicAuthUsernameBySessionID(sessionID)
	if username == "" {
		return "", false
	}

	account := model.GetBasicAuthAccount(username)
	if account == nil {
		// The account no longer resolves (e.g. it was removed from the
		// configuration), so this session can never be valid again.
		model.DeleteSession(sessionID)
		return "", false
	}
	return account.Token, true
}

// publishUnauthorizedResponse builds the 401 response returned to visitors
// who fail publish authentication. It carries the
// model.BasicAuthHeaderKey/model.BasicAuthHeaderValue challenge header and an
// empty body.
func publishUnauthorizedResponse(request *http.Request) *http.Response {
	return &http.Response{
		StatusCode: http.StatusUnauthorized,
		// net/http convention for Status is "<code> <text>", e.g. "401 Unauthorized".
		Status:     fmt.Sprintf("%d %s", http.StatusUnauthorized, http.StatusText(http.StatusUnauthorized)),
		Proto:      request.Proto,
		ProtoMajor: request.ProtoMajor,
		ProtoMinor: request.ProtoMinor,
		Request:    request,
		Header: http.Header{
			model.BasicAuthHeaderKey: {model.BasicAuthHeaderValue},
		},
		Body:          http.NoBody,
		Close:         false,
		ContentLength: -1,
	}
}