mirror of
https://github.com/containrrr/watchtower.git
synced 2025-09-21 21:30:48 +02:00
fix(api): return appropriate status for unauthorized requests (#1116)
This commit is contained in:
parent
c0fd77d357
commit
81036b078b
3 changed files with 71 additions and 9 deletions
|
@ -25,9 +25,12 @@ func New(token string) *API {
|
|||
// RequireToken is wrapper around http.HandleFunc that checks token validity
|
||||
func (api *API) RequireToken(fn http.HandlerFunc) http.HandlerFunc {
|
||||
return func(w http.ResponseWriter, r *http.Request) {
|
||||
if r.Header.Get("Authorization") != fmt.Sprintf("Bearer %s", api.Token) {
|
||||
log.Tracef("Invalid token \"%s\"", r.Header.Get("Authorization"))
|
||||
log.Tracef("Expected token to be \"%s\"", api.Token)
|
||||
auth := r.Header.Get("Authorization")
|
||||
want := fmt.Sprintf("Bearer %s", api.Token)
|
||||
if auth != want {
|
||||
log.Tracef("Invalid Authorization header \"%s\"", auth)
|
||||
log.Tracef("Expected Authorization header to be \"%s\"", want)
|
||||
w.WriteHeader(http.StatusUnauthorized)
|
||||
return
|
||||
}
|
||||
log.Debug("Valid token found.")
|
||||
|
|
65
pkg/api/api_test.go
Normal file
65
pkg/api/api_test.go
Normal file
|
@ -0,0 +1,65 @@
|
|||
package api
|
||||
|
||||
import (
|
||||
"io"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
|
||||
. "github.com/onsi/ginkgo"
|
||||
. "github.com/onsi/gomega"
|
||||
)
|
||||
|
||||
const (
|
||||
token = "123123123"
|
||||
)
|
||||
|
||||
func TestAPI(t *testing.T) {
|
||||
RegisterFailHandler(Fail)
|
||||
RunSpecs(t, "API Suite")
|
||||
}
|
||||
|
||||
var _ = Describe("API", func() {
|
||||
api := New(token)
|
||||
|
||||
Describe("RequireToken middleware", func() {
|
||||
It("should return 401 Unauthorized when token is not provided", func() {
|
||||
handlerFunc := api.RequireToken(testHandler)
|
||||
|
||||
rec := httptest.NewRecorder()
|
||||
req := httptest.NewRequest("GET", "/hello", nil)
|
||||
|
||||
handlerFunc(rec, req)
|
||||
|
||||
Expect(rec.Code).To(Equal(http.StatusUnauthorized))
|
||||
})
|
||||
|
||||
It("should return 401 Unauthorized when token is invalid", func() {
|
||||
handlerFunc := api.RequireToken(testHandler)
|
||||
|
||||
rec := httptest.NewRecorder()
|
||||
req := httptest.NewRequest("GET", "/hello", nil)
|
||||
req.Header.Set("Authorization", "Bearer 123")
|
||||
|
||||
handlerFunc(rec, req)
|
||||
|
||||
Expect(rec.Code).To(Equal(http.StatusUnauthorized))
|
||||
})
|
||||
|
||||
It("should return 200 OK when token is valid", func() {
|
||||
handlerFunc := api.RequireToken(testHandler)
|
||||
|
||||
rec := httptest.NewRecorder()
|
||||
req := httptest.NewRequest("GET", "/hello", nil)
|
||||
req.Header.Set("Authorization", "Bearer " + token)
|
||||
|
||||
handlerFunc(rec, req)
|
||||
|
||||
Expect(rec.Code).To(Equal(http.StatusOK))
|
||||
})
|
||||
})
|
||||
})
|
||||
|
||||
func testHandler(w http.ResponseWriter, req *http.Request) {
|
||||
_, _ = io.WriteString(w, "Hello!")
|
||||
}
|
Loading…
Add table
Add a link
Reference in a new issue