feat(registry): add support for custom CA certificates and TLS validation

- Introduced `--registry-ca` and `--registry-ca-validate` flags for configuring TLS verification with private registries.
- Implemented in-memory token caching with expiration handling.
- Updated documentation to reflect new CLI options and usage examples.
- Added tests for token cache concurrency and expiry behavior.
This commit is contained in:
kalvinparker 2025-11-14 14:30:37 +00:00
parent 76f9cea516
commit e1f67fc3d0
18 changed files with 738 additions and 17 deletions

View file

@ -8,6 +8,8 @@ import (
"net/http"
"net/url"
"strings"
"sync"
"time"
"github.com/containrrr/watchtower/pkg/registry/helpers"
"github.com/containrrr/watchtower/pkg/types"
@ -75,12 +77,20 @@ func GetChallengeRequest(URL url.URL) (*http.Request, error) {
// GetBearerHeader tries to fetch a bearer token from the registry based on the challenge instructions
func GetBearerHeader(challenge string, imageRef ref.Named, registryAuth string) (string, error) {
client := http.Client{}
authURL, err := GetAuthURL(challenge, imageRef)
authURL, err := GetAuthURL(challenge, imageRef)
if err != nil {
return "", err
}
// Build cache key from the auth realm, service and scope
cacheKey := authURL.String()
// Check cache first
if token := getCachedToken(cacheKey); token != "" {
return fmt.Sprintf("Bearer %s", token), nil
}
var r *http.Request
if r, err = http.NewRequest("GET", authURL.String(), nil); err != nil {
return "", err
@ -88,8 +98,6 @@ func GetBearerHeader(challenge string, imageRef ref.Named, registryAuth string)
if registryAuth != "" {
logrus.Debug("Credentials found.")
// CREDENTIAL: Uncomment to log registry credentials
// logrus.Tracef("Credentials: %v", registryAuth)
r.Header.Add("Authorization", fmt.Sprintf("Basic %s", registryAuth))
} else {
logrus.Debug("No credentials found.")
@ -99,6 +107,7 @@ func GetBearerHeader(challenge string, imageRef ref.Named, registryAuth string)
if authResponse, err = client.Do(r); err != nil {
return "", err
}
defer authResponse.Body.Close()
body, _ := io.ReadAll(authResponse.Body)
tokenResponse := &types.TokenResponse{}
@ -108,9 +117,54 @@ func GetBearerHeader(challenge string, imageRef ref.Named, registryAuth string)
return "", err
}
// Cache token if ExpiresIn provided
if tokenResponse.Token != "" {
storeToken(cacheKey, tokenResponse.Token, tokenResponse.ExpiresIn)
}
return fmt.Sprintf("Bearer %s", tokenResponse.Token), nil
}
// token cache implementation
type cachedToken struct {
token string
expiresAt time.Time
}
var (
tokenCache = map[string]cachedToken{}
tokenCacheMu = &sync.Mutex{}
)
// now is a package-level function returning current time. It is a variable so tests
// can override it for deterministic behavior.
var now = time.Now
// getCachedToken returns token string if present and not expired, otherwise empty
func getCachedToken(key string) string {
tokenCacheMu.Lock()
defer tokenCacheMu.Unlock()
if ct, ok := tokenCache[key]; ok {
if ct.expiresAt.IsZero() || now().Before(ct.expiresAt) {
return ct.token
}
// expired
delete(tokenCache, key)
}
return ""
}
// storeToken stores token with optional ttl (seconds). ttl<=0 means no expiry.
func storeToken(key, token string, ttl int) {
tokenCacheMu.Lock()
defer tokenCacheMu.Unlock()
ct := cachedToken{token: token}
if ttl > 0 {
ct.expiresAt = now().Add(time.Duration(ttl) * time.Second)
}
tokenCache[key] = ct
}
// GetAuthURL from the instructions in the challenge
func GetAuthURL(challenge string, imageRef ref.Named) (*url.URL, error) {
loweredChallenge := strings.ToLower(challenge)

View file

@ -0,0 +1,101 @@
package auth
import (
"sync"
"testing"
"time"
)
// Test concurrent stores and gets to ensure the mutex protects the cache
func TestTokenCacheConcurrentStoreAndGet(t *testing.T) {
// reset cache safely
tokenCacheMu.Lock()
tokenCache = map[string]cachedToken{}
tokenCacheMu.Unlock()
origNow := now
defer func() { now = origNow }()
now = time.Now
key := "concurrent-key"
token := "tok-concurrent"
var wg sync.WaitGroup
storeers := 50
getters := 50
iters := 100
for i := 0; i < storeers; i++ {
wg.Add(1)
go func() {
defer wg.Done()
for j := 0; j < iters; j++ {
storeToken(key, token, 0)
}
}()
}
for i := 0; i < getters; i++ {
wg.Add(1)
go func() {
defer wg.Done()
for j := 0; j < iters; j++ {
_ = getCachedToken(key)
}
}()
}
wg.Wait()
if got := getCachedToken(key); got != token {
t.Fatalf("expected token %q, got %q", token, got)
}
}
// Test concurrent access while token expires: readers run while time is advanced
func TestTokenCacheConcurrentExpiry(t *testing.T) {
// reset cache safely
tokenCacheMu.Lock()
tokenCache = map[string]cachedToken{}
tokenCacheMu.Unlock()
// Make now controllable and thread-safe
origNow := now
defer func() { now = origNow }()
base := time.Now()
var mu sync.Mutex
current := base
now = func() time.Time {
mu.Lock()
defer mu.Unlock()
return current
}
key := "concurrent-expire"
storeToken(key, "t", 1)
var wg sync.WaitGroup
readers := 100
for i := 0; i < readers; i++ {
wg.Add(1)
go func() {
defer wg.Done()
for j := 0; j < 100; j++ {
_ = getCachedToken(key)
}
}()
}
// advance time beyond ttl
mu.Lock()
current = current.Add(2 * time.Second)
mu.Unlock()
wg.Wait()
if got := getCachedToken(key); got != "" {
t.Fatalf("expected token to be expired, got %q", got)
}
}

View file

@ -0,0 +1,54 @@
package auth
import (
"testing"
"time"
)
func TestTokenCacheStoreAndGetHitAndMiss(t *testing.T) {
// save and restore original now
origNow := now
defer func() { now = origNow }()
// deterministic fake time
base := time.Date(2025, time.November, 13, 12, 0, 0, 0, time.UTC)
now = func() time.Time { return base }
key := "https://auth.example.com/?service=example&scope=repository:repo:pull"
// ensure empty at start
if got := getCachedToken(key); got != "" {
t.Fatalf("expected empty cache initially, got %q", got)
}
// store with no expiry (ttl <= 0)
storeToken(key, "tok-123", 0)
if got := getCachedToken(key); got != "tok-123" {
t.Fatalf("expected token tok-123, got %q", got)
}
}
func TestTokenCacheExpiry(t *testing.T) {
// save and restore original now
origNow := now
defer func() { now = origNow }()
// deterministic fake time that can be moved forward
base := time.Date(2025, time.November, 13, 12, 0, 0, 0, time.UTC)
current := base
now = func() time.Time { return current }
key := "https://auth.example.com/?service=example&scope=repository:repo2:pull"
// store with short ttl (1 second)
storeToken(key, "short-tok", 1)
if got := getCachedToken(key); got != "short-tok" {
t.Fatalf("expected token short-tok immediately after store, got %q", got)
}
// advance time beyond ttl
current = current.Add(2 * time.Second)
if got := getCachedToken(key); got != "" {
t.Fatalf("expected token to be expired and removed, got %q", got)
}
}

View file

@ -12,6 +12,7 @@ import (
"time"
"github.com/containrrr/watchtower/internal/meta"
"github.com/containrrr/watchtower/pkg/registry"
"github.com/containrrr/watchtower/pkg/registry/auth"
"github.com/containrrr/watchtower/pkg/registry/manifest"
"github.com/containrrr/watchtower/pkg/types"
@ -76,19 +77,7 @@ func TransformAuth(registryAuth string) string {
// GetDigest from registry using a HEAD request to prevent rate limiting
func GetDigest(url string, token string) (string, error) {
tr := &http.Transport{
Proxy: http.ProxyFromEnvironment,
DialContext: (&net.Dialer{
Timeout: 30 * time.Second,
KeepAlive: 30 * time.Second,
}).DialContext,
ForceAttemptHTTP2: true,
MaxIdleConns: 100,
IdleConnTimeout: 90 * time.Second,
TLSHandshakeTimeout: 10 * time.Second,
ExpectContinueTimeout: 1 * time.Second,
TLSClientConfig: &tls.Config{InsecureSkipVerify: true},
}
tr := newTransport()
client := &http.Client{Transport: tr}
req, _ := http.NewRequest("HEAD", url, nil)
@ -124,3 +113,35 @@ func GetDigest(url string, token string) (string, error) {
}
return res.Header.Get(ContentDigestHeader), nil
}
// newTransport constructs an *http.Transport used for registry HEAD/token requests.
// It respects the package-level `registry.InsecureSkipVerify` toggle.
func newTransport() *http.Transport {
tr := &http.Transport{
Proxy: http.ProxyFromEnvironment,
DialContext: (&net.Dialer{
Timeout: 30 * time.Second,
KeepAlive: 30 * time.Second,
}).DialContext,
ForceAttemptHTTP2: true,
MaxIdleConns: 100,
IdleConnTimeout: 90 * time.Second,
TLSHandshakeTimeout: 10 * time.Second,
ExpectContinueTimeout: 1 * time.Second,
}
certPool := registry.GetRegistryCertPool()
if registry.InsecureSkipVerify {
// Insecure mode requested: disable verification entirely
tr.TLSClientConfig = &tls.Config{InsecureSkipVerify: true}
} else if certPool != nil {
// Create TLS config with custom root CAs merged into system pool
tr.TLSClientConfig = &tls.Config{RootCAs: certPool}
}
return tr
}
// NewTransportForTest exposes the transport construction for unit tests.
func NewTransportForTest() *http.Transport {
return newTransport()
}

View file

@ -0,0 +1,27 @@
package digest_test
import (
"github.com/containrrr/watchtower/pkg/registry"
"github.com/containrrr/watchtower/pkg/registry/digest"
. "github.com/onsi/ginkgo"
. "github.com/onsi/gomega"
)
var _ = Describe("Digest transport configuration", func() {
AfterEach(func() {
// Reset to default after each test
registry.InsecureSkipVerify = false
})
It("should have nil TLSClientConfig by default", func() {
registry.InsecureSkipVerify = false
tr := digest.NewTransportForTest()
Expect(tr.TLSClientConfig).To(BeNil())
})
It("should set TLSClientConfig when insecure flag is true", func() {
registry.InsecureSkipVerify = true
tr := digest.NewTransportForTest()
Expect(tr.TLSClientConfig).ToNot(BeNil())
})
})

View file

@ -1,6 +1,9 @@
package registry
import (
"crypto/x509"
"io/ioutil"
"github.com/containrrr/watchtower/pkg/registry/helpers"
watchtowerTypes "github.com/containrrr/watchtower/pkg/types"
ref "github.com/distribution/reference"
@ -8,6 +11,18 @@ import (
log "github.com/sirupsen/logrus"
)
// InsecureSkipVerify controls whether registry HTTPS connections used for
// manifest HEAD/token requests disable certificate verification. Default is false.
// This is exposed so callers (e.g. CLI flag handling) can toggle it.
var InsecureSkipVerify = false
// RegistryCABundle is an optional filesystem path to a PEM bundle that will be
// used as additional trusted CAs when validating registry TLS certificates.
var RegistryCABundle string
// registryCertPool caches the loaded cert pool when RegistryCABundle is set
var registryCertPool *x509.CertPool
// GetPullOptions creates a struct with all options needed for pulling images from a registry
func GetPullOptions(imageName string) (types.ImagePullOptions, error) {
auth, err := EncodedAuth(imageName)
@ -59,3 +74,29 @@ func WarnOnAPIConsumption(container watchtowerTypes.Container) bool {
return false
}
// GetRegistryCertPool returns a cert pool that includes system roots plus any
// additional CAs provided via RegistryCABundle. The resulting pool is cached.
func GetRegistryCertPool() *x509.CertPool {
if RegistryCABundle == "" {
return nil
}
if registryCertPool != nil {
return registryCertPool
}
// Try to load file
data, err := ioutil.ReadFile(RegistryCABundle)
if err != nil {
log.WithField("path", RegistryCABundle).Errorf("Failed to load registry CA bundle: %v", err)
return nil
}
pool, err := x509.SystemCertPool()
if err != nil || pool == nil {
pool = x509.NewCertPool()
}
if ok := pool.AppendCertsFromPEM(data); !ok {
log.WithField("path", RegistryCABundle).Warn("No certs appended from registry CA bundle; file may be empty or invalid PEM")
}
registryCertPool = pool
return registryCertPool
}