diff --git a/kernel/model/session.go b/kernel/model/session.go index 96198f9ec..9cebb8c01 100644 --- a/kernel/model/session.go +++ b/kernel/model/session.go @@ -98,7 +98,7 @@ func LoginAuth(c *gin.Context) { if err := session.Save(c); nil != err { logging.LogErrorf("save session failed: " + err.Error()) - c.Status(500) + c.Status(http.StatusInternalServerError) return } return @@ -109,7 +109,7 @@ func LoginAuth(c *gin.Context) { workspaceSession.Captcha = gulu.Rand.String(7) if err := session.Save(c); nil != err { logging.LogErrorf("save session failed: " + err.Error()) - c.Status(500) + c.Status(http.StatusInternalServerError) return } } @@ -123,7 +123,7 @@ func GetCaptcha(c *gin.Context) { }) if nil != err { logging.LogErrorf("generates captcha failed: " + err.Error()) - c.Status(500) + c.Status(http.StatusInternalServerError) return } @@ -132,16 +132,16 @@ func GetCaptcha(c *gin.Context) { workspaceSession.Captcha = img.Text if err = session.Save(c); nil != err { logging.LogErrorf("save session failed: " + err.Error()) - c.Status(500) + c.Status(http.StatusInternalServerError) return } if err = img.WriteImage(c.Writer); nil != err { logging.LogErrorf("writes captcha image failed: " + err.Error()) - c.Status(500) + c.Status(http.StatusInternalServerError) return } - c.Status(200) + c.Status(http.StatusOK) } func CheckReadonly(c *gin.Context) { @@ -150,7 +150,7 @@ func CheckReadonly(c *gin.Context) { result.Code = -1 result.Msg = Conf.Language(34) result.Data = map[string]interface{}{"closeTimeout": 5000} - c.JSON(200, result) + c.JSON(http.StatusOK, result) c.Abort() return } @@ -158,38 +158,21 @@ func CheckReadonly(c *gin.Context) { func CheckAuth(c *gin.Context) { //logging.LogInfof("check auth for [%s]", c.Request.RequestURI) + localhost := util.IsLocalHost(c.Request.RemoteAddr) + // 未设置访问授权码 if "" == Conf.AccessAuthCode { - if origin := c.GetHeader("Origin"); "" != origin { - // Authenticate requests with the Origin header other than 127.0.0.1 https://github.com/siyuan-note/siyuan/issues/9180 - u, parseErr := url.Parse(origin) - if nil != parseErr { - logging.LogWarnf("parse origin [%s] failed: %s", origin, parseErr) - c.JSON(401, map[string]interface{}{"code": -1, "msg": "Auth failed: parse req header [Origin] failed"}) - c.Abort() - return - - } - - if "chrome-extension" == strings.ToLower(u.Scheme) { - c.Next() - return - } - - if !strings.HasPrefix(u.Host, util.LocalHost) && !strings.HasPrefix(u.Host, "[::1]") { - c.JSON(401, map[string]interface{}{"code": -1, "msg": "Auth failed: for security reasons, please set [Access authorization code] when using non-127.0.0.1 access\n\n为安全起见,使用非 127.0.0.1 访问时请设置 [访问授权码]"}) - c.Abort() - return - } - } - - if !strings.HasPrefix(c.Request.RemoteAddr, util.LocalHost) && !strings.HasPrefix(c.Request.RemoteAddr, "[::1]") { - // Authenticate requests of assets other than 127.0.0.1 https://github.com/siyuan-note/siyuan/issues/9388 - if strings.HasPrefix(c.Request.RequestURI, "/assets/") { - c.JSON(401, map[string]interface{}{"code": -1, "msg": "Auth failed: for security reasons, please set [Access authorization code] when using non-127.0.0.1 access\n\n为安全起见,使用非 127.0.0.1 访问时请设置 [访问授权码]"}) - c.Abort() - return - } + // Authenticate requests with the Origin header other than 127.0.0.1 https://github.com/siyuan-note/siyuan/issues/9180 + host := c.GetHeader("Host") + origin := c.GetHeader("Origin") + forwardedHost := c.GetHeader("X-Forwarded-Host") + if !localhost || + ("" != host && !util.IsLocalHost(host)) || + ("" != origin && !util.IsLocalOrigin(origin) && !strings.HasPrefix(origin, "chrome-extension://")) || + ("" != forwardedHost && !util.IsLocalHost(forwardedHost)) { + c.JSON(http.StatusUnauthorized, map[string]interface{}{"code": -1, "msg": "Auth failed: for security reasons, please set [Access authorization code] when using non-127.0.0.1 access\n\n为安全起见,使用非 127.0.0.1 访问时请设置 [访问授权码]"}) + c.Abort() + return } c.Next() @@ -206,7 +189,7 @@ func CheckAuth(c *gin.Context) { } // 放过来自本机的某些请求 - if strings.HasPrefix(c.Request.RemoteAddr, util.LocalHost) || strings.HasPrefix(c.Request.RemoteAddr, "[::1]") { + if localhost { if strings.HasPrefix(c.Request.RequestURI, "/assets/") { c.Next() return @@ -234,7 +217,7 @@ func CheckAuth(c *gin.Context) { return } - c.JSON(401, map[string]interface{}{"code": -1, "msg": "Auth failed"}) + c.JSON(http.StatusUnauthorized, map[string]interface{}{"code": -1, "msg": "Auth failed"}) c.Abort() return } @@ -247,7 +230,7 @@ func CheckAuth(c *gin.Context) { return } - c.JSON(401, map[string]interface{}{"code": -1, "msg": "Auth failed"}) + c.JSON(http.StatusUnauthorized, map[string]interface{}{"code": -1, "msg": "Auth failed"}) c.Abort() return } @@ -261,7 +244,7 @@ func CheckAuth(c *gin.Context) { userAgentHeader := c.GetHeader("User-Agent") if strings.HasPrefix(userAgentHeader, "SiYuan/") || strings.HasPrefix(userAgentHeader, "Mozilla/") { if "GET" != c.Request.Method { - c.JSON(401, map[string]interface{}{"code": -1, "msg": Conf.Language(156)}) + c.JSON(http.StatusUnauthorized, map[string]interface{}{"code": -1, "msg": Conf.Language(156)}) c.Abort() return } @@ -271,12 +254,13 @@ func CheckAuth(c *gin.Context) { queryParams.Set("to", c.Request.URL.String()) location.RawQuery = queryParams.Encode() location.Path = "/check-auth" - c.Redirect(302, location.String()) + + c.Redirect(http.StatusFound, location.String()) c.Abort() return } - c.JSON(401, map[string]interface{}{"code": -1, "msg": "Auth failed"}) + c.JSON(http.StatusUnauthorized, map[string]interface{}{"code": -1, "msg": "Auth failed"}) c.Abort() return } @@ -316,7 +300,7 @@ func Timing(c *gin.Context) { func Recover(c *gin.Context) { defer func() { logging.Recover() - c.Status(500) + c.Status(http.StatusInternalServerError) }() c.Next() diff --git a/kernel/util/net.go b/kernel/util/net.go index 2fb5d2618..f97a8b9a7 100644 --- a/kernel/util/net.go +++ b/kernel/util/net.go @@ -17,6 +17,7 @@ package util import ( + "net" "net/http" "net/url" "strings" @@ -31,6 +32,58 @@ import ( "github.com/siyuan-note/logging" ) +func ValidOptionalPort(port string) bool { + if port == "" { + return true + } + if port[0] != ':' { + return false + } + for _, b := range port[1:] { + if b < '0' || b > '9' { + return false + } + } + return true +} + +func SplitHost(host string) (hostname, port string) { + hostname = host + + colon := strings.LastIndexByte(hostname, ':') + if colon != -1 && ValidOptionalPort(hostname[colon:]) { + hostname, port = hostname[:colon], hostname[colon+1:] + } + + if strings.HasPrefix(hostname, "[") && strings.HasSuffix(hostname, "]") { + hostname = hostname[1 : len(hostname)-1] + } + + return +} + +func IsLocalHostname(hostname string) bool { + if "localhost" == hostname { + return true + } + if ip := net.ParseIP(hostname); nil != ip { + return ip.IsLoopback() + } + return false +} + +func IsLocalHost(host string) bool { + hostname, _ := SplitHost(host) + return IsLocalHostname(hostname) +} + +func IsLocalOrigin(origin string) bool { + if url, err := url.Parse(origin); nil == err { + return IsLocalHostname(url.Hostname()) + } + return false +} + func IsOnline(checkURL string, skipTlsVerify bool) bool { _, err := url.Parse(checkURL) if nil != err {