mirror of
https://github.com/containrrr/watchtower.git
synced 2025-12-16 23:20:12 +01:00
feat(registry): add support for custom CA certificates and TLS validation
- Introduced `--registry-ca` and `--registry-ca-validate` flags for configuring TLS verification with private registries. - Implemented in-memory token caching with expiration handling. - Updated documentation to reflect new CLI options and usage examples. - Added tests for token cache concurrency and expiry behavior.
This commit is contained in:
parent
76f9cea516
commit
e1f67fc3d0
18 changed files with 738 additions and 17 deletions
|
|
@ -8,6 +8,8 @@ import (
|
|||
"net/http"
|
||||
"net/url"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/containrrr/watchtower/pkg/registry/helpers"
|
||||
"github.com/containrrr/watchtower/pkg/types"
|
||||
|
|
@ -75,12 +77,20 @@ func GetChallengeRequest(URL url.URL) (*http.Request, error) {
|
|||
// GetBearerHeader tries to fetch a bearer token from the registry based on the challenge instructions
|
||||
func GetBearerHeader(challenge string, imageRef ref.Named, registryAuth string) (string, error) {
|
||||
client := http.Client{}
|
||||
authURL, err := GetAuthURL(challenge, imageRef)
|
||||
|
||||
authURL, err := GetAuthURL(challenge, imageRef)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
// Build cache key from the auth realm, service and scope
|
||||
cacheKey := authURL.String()
|
||||
|
||||
// Check cache first
|
||||
if token := getCachedToken(cacheKey); token != "" {
|
||||
return fmt.Sprintf("Bearer %s", token), nil
|
||||
}
|
||||
|
||||
var r *http.Request
|
||||
if r, err = http.NewRequest("GET", authURL.String(), nil); err != nil {
|
||||
return "", err
|
||||
|
|
@ -88,8 +98,6 @@ func GetBearerHeader(challenge string, imageRef ref.Named, registryAuth string)
|
|||
|
||||
if registryAuth != "" {
|
||||
logrus.Debug("Credentials found.")
|
||||
// CREDENTIAL: Uncomment to log registry credentials
|
||||
// logrus.Tracef("Credentials: %v", registryAuth)
|
||||
r.Header.Add("Authorization", fmt.Sprintf("Basic %s", registryAuth))
|
||||
} else {
|
||||
logrus.Debug("No credentials found.")
|
||||
|
|
@ -99,6 +107,7 @@ func GetBearerHeader(challenge string, imageRef ref.Named, registryAuth string)
|
|||
if authResponse, err = client.Do(r); err != nil {
|
||||
return "", err
|
||||
}
|
||||
defer authResponse.Body.Close()
|
||||
|
||||
body, _ := io.ReadAll(authResponse.Body)
|
||||
tokenResponse := &types.TokenResponse{}
|
||||
|
|
@ -108,9 +117,54 @@ func GetBearerHeader(challenge string, imageRef ref.Named, registryAuth string)
|
|||
return "", err
|
||||
}
|
||||
|
||||
// Cache token if ExpiresIn provided
|
||||
if tokenResponse.Token != "" {
|
||||
storeToken(cacheKey, tokenResponse.Token, tokenResponse.ExpiresIn)
|
||||
}
|
||||
|
||||
return fmt.Sprintf("Bearer %s", tokenResponse.Token), nil
|
||||
}
|
||||
|
||||
// token cache implementation
|
||||
type cachedToken struct {
|
||||
token string
|
||||
expiresAt time.Time
|
||||
}
|
||||
|
||||
var (
|
||||
tokenCache = map[string]cachedToken{}
|
||||
tokenCacheMu = &sync.Mutex{}
|
||||
)
|
||||
|
||||
// now is a package-level function returning current time. It is a variable so tests
|
||||
// can override it for deterministic behavior.
|
||||
var now = time.Now
|
||||
|
||||
// getCachedToken returns token string if present and not expired, otherwise empty
|
||||
func getCachedToken(key string) string {
|
||||
tokenCacheMu.Lock()
|
||||
defer tokenCacheMu.Unlock()
|
||||
if ct, ok := tokenCache[key]; ok {
|
||||
if ct.expiresAt.IsZero() || now().Before(ct.expiresAt) {
|
||||
return ct.token
|
||||
}
|
||||
// expired
|
||||
delete(tokenCache, key)
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
// storeToken stores token with optional ttl (seconds). ttl<=0 means no expiry.
|
||||
func storeToken(key, token string, ttl int) {
|
||||
tokenCacheMu.Lock()
|
||||
defer tokenCacheMu.Unlock()
|
||||
ct := cachedToken{token: token}
|
||||
if ttl > 0 {
|
||||
ct.expiresAt = now().Add(time.Duration(ttl) * time.Second)
|
||||
}
|
||||
tokenCache[key] = ct
|
||||
}
|
||||
|
||||
// GetAuthURL from the instructions in the challenge
|
||||
func GetAuthURL(challenge string, imageRef ref.Named) (*url.URL, error) {
|
||||
loweredChallenge := strings.ToLower(challenge)
|
||||
|
|
|
|||
101
pkg/registry/auth/auth_cache_concurrency_test.go
Normal file
101
pkg/registry/auth/auth_cache_concurrency_test.go
Normal file
|
|
@ -0,0 +1,101 @@
|
|||
package auth
|
||||
|
||||
import (
|
||||
"sync"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
// Test concurrent stores and gets to ensure the mutex protects the cache
|
||||
func TestTokenCacheConcurrentStoreAndGet(t *testing.T) {
|
||||
// reset cache safely
|
||||
tokenCacheMu.Lock()
|
||||
tokenCache = map[string]cachedToken{}
|
||||
tokenCacheMu.Unlock()
|
||||
|
||||
origNow := now
|
||||
defer func() { now = origNow }()
|
||||
now = time.Now
|
||||
|
||||
key := "concurrent-key"
|
||||
token := "tok-concurrent"
|
||||
|
||||
var wg sync.WaitGroup
|
||||
storeers := 50
|
||||
getters := 50
|
||||
iters := 100
|
||||
|
||||
for i := 0; i < storeers; i++ {
|
||||
wg.Add(1)
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
for j := 0; j < iters; j++ {
|
||||
storeToken(key, token, 0)
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
for i := 0; i < getters; i++ {
|
||||
wg.Add(1)
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
for j := 0; j < iters; j++ {
|
||||
_ = getCachedToken(key)
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
wg.Wait()
|
||||
|
||||
if got := getCachedToken(key); got != token {
|
||||
t.Fatalf("expected token %q, got %q", token, got)
|
||||
}
|
||||
}
|
||||
|
||||
// Test concurrent access while token expires: readers run while time is advanced
|
||||
func TestTokenCacheConcurrentExpiry(t *testing.T) {
|
||||
// reset cache safely
|
||||
tokenCacheMu.Lock()
|
||||
tokenCache = map[string]cachedToken{}
|
||||
tokenCacheMu.Unlock()
|
||||
|
||||
// Make now controllable and thread-safe
|
||||
origNow := now
|
||||
defer func() { now = origNow }()
|
||||
|
||||
base := time.Now()
|
||||
var mu sync.Mutex
|
||||
current := base
|
||||
now = func() time.Time {
|
||||
mu.Lock()
|
||||
defer mu.Unlock()
|
||||
return current
|
||||
}
|
||||
|
||||
key := "concurrent-expire"
|
||||
storeToken(key, "t", 1)
|
||||
|
||||
var wg sync.WaitGroup
|
||||
readers := 100
|
||||
|
||||
for i := 0; i < readers; i++ {
|
||||
wg.Add(1)
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
for j := 0; j < 100; j++ {
|
||||
_ = getCachedToken(key)
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
// advance time beyond ttl
|
||||
mu.Lock()
|
||||
current = current.Add(2 * time.Second)
|
||||
mu.Unlock()
|
||||
|
||||
wg.Wait()
|
||||
|
||||
if got := getCachedToken(key); got != "" {
|
||||
t.Fatalf("expected token to be expired, got %q", got)
|
||||
}
|
||||
}
|
||||
54
pkg/registry/auth/auth_cache_test.go
Normal file
54
pkg/registry/auth/auth_cache_test.go
Normal file
|
|
@ -0,0 +1,54 @@
|
|||
package auth
|
||||
|
||||
import (
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
func TestTokenCacheStoreAndGetHitAndMiss(t *testing.T) {
|
||||
// save and restore original now
|
||||
origNow := now
|
||||
defer func() { now = origNow }()
|
||||
|
||||
// deterministic fake time
|
||||
base := time.Date(2025, time.November, 13, 12, 0, 0, 0, time.UTC)
|
||||
now = func() time.Time { return base }
|
||||
|
||||
key := "https://auth.example.com/?service=example&scope=repository:repo:pull"
|
||||
// ensure empty at start
|
||||
if got := getCachedToken(key); got != "" {
|
||||
t.Fatalf("expected empty cache initially, got %q", got)
|
||||
}
|
||||
|
||||
// store with no expiry (ttl <= 0)
|
||||
storeToken(key, "tok-123", 0)
|
||||
if got := getCachedToken(key); got != "tok-123" {
|
||||
t.Fatalf("expected token tok-123, got %q", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestTokenCacheExpiry(t *testing.T) {
|
||||
// save and restore original now
|
||||
origNow := now
|
||||
defer func() { now = origNow }()
|
||||
|
||||
// deterministic fake time that can be moved forward
|
||||
base := time.Date(2025, time.November, 13, 12, 0, 0, 0, time.UTC)
|
||||
current := base
|
||||
now = func() time.Time { return current }
|
||||
|
||||
key := "https://auth.example.com/?service=example&scope=repository:repo2:pull"
|
||||
// store with short ttl (1 second)
|
||||
storeToken(key, "short-tok", 1)
|
||||
|
||||
if got := getCachedToken(key); got != "short-tok" {
|
||||
t.Fatalf("expected token short-tok immediately after store, got %q", got)
|
||||
}
|
||||
|
||||
// advance time beyond ttl
|
||||
current = current.Add(2 * time.Second)
|
||||
|
||||
if got := getCachedToken(key); got != "" {
|
||||
t.Fatalf("expected token to be expired and removed, got %q", got)
|
||||
}
|
||||
}
|
||||
|
|
@ -12,6 +12,7 @@ import (
|
|||
"time"
|
||||
|
||||
"github.com/containrrr/watchtower/internal/meta"
|
||||
"github.com/containrrr/watchtower/pkg/registry"
|
||||
"github.com/containrrr/watchtower/pkg/registry/auth"
|
||||
"github.com/containrrr/watchtower/pkg/registry/manifest"
|
||||
"github.com/containrrr/watchtower/pkg/types"
|
||||
|
|
@ -76,19 +77,7 @@ func TransformAuth(registryAuth string) string {
|
|||
|
||||
// GetDigest from registry using a HEAD request to prevent rate limiting
|
||||
func GetDigest(url string, token string) (string, error) {
|
||||
tr := &http.Transport{
|
||||
Proxy: http.ProxyFromEnvironment,
|
||||
DialContext: (&net.Dialer{
|
||||
Timeout: 30 * time.Second,
|
||||
KeepAlive: 30 * time.Second,
|
||||
}).DialContext,
|
||||
ForceAttemptHTTP2: true,
|
||||
MaxIdleConns: 100,
|
||||
IdleConnTimeout: 90 * time.Second,
|
||||
TLSHandshakeTimeout: 10 * time.Second,
|
||||
ExpectContinueTimeout: 1 * time.Second,
|
||||
TLSClientConfig: &tls.Config{InsecureSkipVerify: true},
|
||||
}
|
||||
tr := newTransport()
|
||||
client := &http.Client{Transport: tr}
|
||||
|
||||
req, _ := http.NewRequest("HEAD", url, nil)
|
||||
|
|
@ -124,3 +113,35 @@ func GetDigest(url string, token string) (string, error) {
|
|||
}
|
||||
return res.Header.Get(ContentDigestHeader), nil
|
||||
}
|
||||
|
||||
// newTransport constructs an *http.Transport used for registry HEAD/token requests.
|
||||
// It respects the package-level `registry.InsecureSkipVerify` toggle.
|
||||
func newTransport() *http.Transport {
|
||||
tr := &http.Transport{
|
||||
Proxy: http.ProxyFromEnvironment,
|
||||
DialContext: (&net.Dialer{
|
||||
Timeout: 30 * time.Second,
|
||||
KeepAlive: 30 * time.Second,
|
||||
}).DialContext,
|
||||
ForceAttemptHTTP2: true,
|
||||
MaxIdleConns: 100,
|
||||
IdleConnTimeout: 90 * time.Second,
|
||||
TLSHandshakeTimeout: 10 * time.Second,
|
||||
ExpectContinueTimeout: 1 * time.Second,
|
||||
}
|
||||
|
||||
certPool := registry.GetRegistryCertPool()
|
||||
if registry.InsecureSkipVerify {
|
||||
// Insecure mode requested: disable verification entirely
|
||||
tr.TLSClientConfig = &tls.Config{InsecureSkipVerify: true}
|
||||
} else if certPool != nil {
|
||||
// Create TLS config with custom root CAs merged into system pool
|
||||
tr.TLSClientConfig = &tls.Config{RootCAs: certPool}
|
||||
}
|
||||
return tr
|
||||
}
|
||||
|
||||
// NewTransportForTest exposes the transport construction for unit tests.
|
||||
func NewTransportForTest() *http.Transport {
|
||||
return newTransport()
|
||||
}
|
||||
|
|
|
|||
27
pkg/registry/digest/digest_transport_test.go
Normal file
27
pkg/registry/digest/digest_transport_test.go
Normal file
|
|
@ -0,0 +1,27 @@
|
|||
package digest_test
|
||||
|
||||
import (
|
||||
"github.com/containrrr/watchtower/pkg/registry"
|
||||
"github.com/containrrr/watchtower/pkg/registry/digest"
|
||||
. "github.com/onsi/ginkgo"
|
||||
. "github.com/onsi/gomega"
|
||||
)
|
||||
|
||||
var _ = Describe("Digest transport configuration", func() {
|
||||
AfterEach(func() {
|
||||
// Reset to default after each test
|
||||
registry.InsecureSkipVerify = false
|
||||
})
|
||||
|
||||
It("should have nil TLSClientConfig by default", func() {
|
||||
registry.InsecureSkipVerify = false
|
||||
tr := digest.NewTransportForTest()
|
||||
Expect(tr.TLSClientConfig).To(BeNil())
|
||||
})
|
||||
|
||||
It("should set TLSClientConfig when insecure flag is true", func() {
|
||||
registry.InsecureSkipVerify = true
|
||||
tr := digest.NewTransportForTest()
|
||||
Expect(tr.TLSClientConfig).ToNot(BeNil())
|
||||
})
|
||||
})
|
||||
|
|
@ -1,6 +1,9 @@
|
|||
package registry
|
||||
|
||||
import (
|
||||
"crypto/x509"
|
||||
"io/ioutil"
|
||||
|
||||
"github.com/containrrr/watchtower/pkg/registry/helpers"
|
||||
watchtowerTypes "github.com/containrrr/watchtower/pkg/types"
|
||||
ref "github.com/distribution/reference"
|
||||
|
|
@ -8,6 +11,18 @@ import (
|
|||
log "github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
// InsecureSkipVerify controls whether registry HTTPS connections used for
|
||||
// manifest HEAD/token requests disable certificate verification. Default is false.
|
||||
// This is exposed so callers (e.g. CLI flag handling) can toggle it.
|
||||
var InsecureSkipVerify = false
|
||||
|
||||
// RegistryCABundle is an optional filesystem path to a PEM bundle that will be
|
||||
// used as additional trusted CAs when validating registry TLS certificates.
|
||||
var RegistryCABundle string
|
||||
|
||||
// registryCertPool caches the loaded cert pool when RegistryCABundle is set
|
||||
var registryCertPool *x509.CertPool
|
||||
|
||||
// GetPullOptions creates a struct with all options needed for pulling images from a registry
|
||||
func GetPullOptions(imageName string) (types.ImagePullOptions, error) {
|
||||
auth, err := EncodedAuth(imageName)
|
||||
|
|
@ -59,3 +74,29 @@ func WarnOnAPIConsumption(container watchtowerTypes.Container) bool {
|
|||
|
||||
return false
|
||||
}
|
||||
|
||||
// GetRegistryCertPool returns a cert pool that includes system roots plus any
|
||||
// additional CAs provided via RegistryCABundle. The resulting pool is cached.
|
||||
func GetRegistryCertPool() *x509.CertPool {
|
||||
if RegistryCABundle == "" {
|
||||
return nil
|
||||
}
|
||||
if registryCertPool != nil {
|
||||
return registryCertPool
|
||||
}
|
||||
// Try to load file
|
||||
data, err := ioutil.ReadFile(RegistryCABundle)
|
||||
if err != nil {
|
||||
log.WithField("path", RegistryCABundle).Errorf("Failed to load registry CA bundle: %v", err)
|
||||
return nil
|
||||
}
|
||||
pool, err := x509.SystemCertPool()
|
||||
if err != nil || pool == nil {
|
||||
pool = x509.NewCertPool()
|
||||
}
|
||||
if ok := pool.AppendCertsFromPEM(data); !ok {
|
||||
log.WithField("path", RegistryCABundle).Warn("No certs appended from registry CA bundle; file may be empty or invalid PEM")
|
||||
}
|
||||
registryCertPool = pool
|
||||
return registryCertPool
|
||||
}
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue