feat(registry): add support for custom CA certificates and TLS validation

- Introduced `--registry-ca` and `--registry-ca-validate` flags for configuring TLS verification with private registries.
- Implemented in-memory token caching with expiration handling.
- Updated documentation to reflect new CLI options and usage examples.
- Added tests for token cache concurrency and expiry behavior.
This commit is contained in:
kalvinparker 2025-11-14 14:30:37 +00:00
parent 76f9cea516
commit e1f67fc3d0
18 changed files with 738 additions and 17 deletions

View file

@ -8,6 +8,8 @@ import (
"net/http"
"net/url"
"strings"
"sync"
"time"
"github.com/containrrr/watchtower/pkg/registry/helpers"
"github.com/containrrr/watchtower/pkg/types"
@ -75,12 +77,20 @@ func GetChallengeRequest(URL url.URL) (*http.Request, error) {
// GetBearerHeader tries to fetch a bearer token from the registry based on the challenge instructions
func GetBearerHeader(challenge string, imageRef ref.Named, registryAuth string) (string, error) {
client := http.Client{}
authURL, err := GetAuthURL(challenge, imageRef)
authURL, err := GetAuthURL(challenge, imageRef)
if err != nil {
return "", err
}
// Build cache key from the auth realm, service and scope
cacheKey := authURL.String()
// Check cache first
if token := getCachedToken(cacheKey); token != "" {
return fmt.Sprintf("Bearer %s", token), nil
}
var r *http.Request
if r, err = http.NewRequest("GET", authURL.String(), nil); err != nil {
return "", err
@ -88,8 +98,6 @@ func GetBearerHeader(challenge string, imageRef ref.Named, registryAuth string)
if registryAuth != "" {
logrus.Debug("Credentials found.")
// CREDENTIAL: Uncomment to log registry credentials
// logrus.Tracef("Credentials: %v", registryAuth)
r.Header.Add("Authorization", fmt.Sprintf("Basic %s", registryAuth))
} else {
logrus.Debug("No credentials found.")
@ -99,6 +107,7 @@ func GetBearerHeader(challenge string, imageRef ref.Named, registryAuth string)
if authResponse, err = client.Do(r); err != nil {
return "", err
}
defer authResponse.Body.Close()
body, _ := io.ReadAll(authResponse.Body)
tokenResponse := &types.TokenResponse{}
@ -108,9 +117,54 @@ func GetBearerHeader(challenge string, imageRef ref.Named, registryAuth string)
return "", err
}
// Cache token if ExpiresIn provided
if tokenResponse.Token != "" {
storeToken(cacheKey, tokenResponse.Token, tokenResponse.ExpiresIn)
}
return fmt.Sprintf("Bearer %s", tokenResponse.Token), nil
}
// token cache implementation
type cachedToken struct {
token string
expiresAt time.Time
}
var (
tokenCache = map[string]cachedToken{}
tokenCacheMu = &sync.Mutex{}
)
// now is a package-level function returning current time. It is a variable so tests
// can override it for deterministic behavior.
var now = time.Now
// getCachedToken returns token string if present and not expired, otherwise empty
func getCachedToken(key string) string {
tokenCacheMu.Lock()
defer tokenCacheMu.Unlock()
if ct, ok := tokenCache[key]; ok {
if ct.expiresAt.IsZero() || now().Before(ct.expiresAt) {
return ct.token
}
// expired
delete(tokenCache, key)
}
return ""
}
// storeToken stores token with optional ttl (seconds). ttl<=0 means no expiry.
func storeToken(key, token string, ttl int) {
tokenCacheMu.Lock()
defer tokenCacheMu.Unlock()
ct := cachedToken{token: token}
if ttl > 0 {
ct.expiresAt = now().Add(time.Duration(ttl) * time.Second)
}
tokenCache[key] = ct
}
// GetAuthURL from the instructions in the challenge
func GetAuthURL(challenge string, imageRef ref.Named) (*url.URL, error) {
loweredChallenge := strings.ToLower(challenge)

View file

@ -0,0 +1,101 @@
package auth
import (
"sync"
"testing"
"time"
)
// Test concurrent stores and gets to ensure the mutex protects the cache
func TestTokenCacheConcurrentStoreAndGet(t *testing.T) {
// reset cache safely
tokenCacheMu.Lock()
tokenCache = map[string]cachedToken{}
tokenCacheMu.Unlock()
origNow := now
defer func() { now = origNow }()
now = time.Now
key := "concurrent-key"
token := "tok-concurrent"
var wg sync.WaitGroup
storeers := 50
getters := 50
iters := 100
for i := 0; i < storeers; i++ {
wg.Add(1)
go func() {
defer wg.Done()
for j := 0; j < iters; j++ {
storeToken(key, token, 0)
}
}()
}
for i := 0; i < getters; i++ {
wg.Add(1)
go func() {
defer wg.Done()
for j := 0; j < iters; j++ {
_ = getCachedToken(key)
}
}()
}
wg.Wait()
if got := getCachedToken(key); got != token {
t.Fatalf("expected token %q, got %q", token, got)
}
}
// Test concurrent access while token expires: readers run while time is advanced
func TestTokenCacheConcurrentExpiry(t *testing.T) {
// reset cache safely
tokenCacheMu.Lock()
tokenCache = map[string]cachedToken{}
tokenCacheMu.Unlock()
// Make now controllable and thread-safe
origNow := now
defer func() { now = origNow }()
base := time.Now()
var mu sync.Mutex
current := base
now = func() time.Time {
mu.Lock()
defer mu.Unlock()
return current
}
key := "concurrent-expire"
storeToken(key, "t", 1)
var wg sync.WaitGroup
readers := 100
for i := 0; i < readers; i++ {
wg.Add(1)
go func() {
defer wg.Done()
for j := 0; j < 100; j++ {
_ = getCachedToken(key)
}
}()
}
// advance time beyond ttl
mu.Lock()
current = current.Add(2 * time.Second)
mu.Unlock()
wg.Wait()
if got := getCachedToken(key); got != "" {
t.Fatalf("expected token to be expired, got %q", got)
}
}

View file

@ -0,0 +1,54 @@
package auth
import (
"testing"
"time"
)
func TestTokenCacheStoreAndGetHitAndMiss(t *testing.T) {
// save and restore original now
origNow := now
defer func() { now = origNow }()
// deterministic fake time
base := time.Date(2025, time.November, 13, 12, 0, 0, 0, time.UTC)
now = func() time.Time { return base }
key := "https://auth.example.com/?service=example&scope=repository:repo:pull"
// ensure empty at start
if got := getCachedToken(key); got != "" {
t.Fatalf("expected empty cache initially, got %q", got)
}
// store with no expiry (ttl <= 0)
storeToken(key, "tok-123", 0)
if got := getCachedToken(key); got != "tok-123" {
t.Fatalf("expected token tok-123, got %q", got)
}
}
func TestTokenCacheExpiry(t *testing.T) {
// save and restore original now
origNow := now
defer func() { now = origNow }()
// deterministic fake time that can be moved forward
base := time.Date(2025, time.November, 13, 12, 0, 0, 0, time.UTC)
current := base
now = func() time.Time { return current }
key := "https://auth.example.com/?service=example&scope=repository:repo2:pull"
// store with short ttl (1 second)
storeToken(key, "short-tok", 1)
if got := getCachedToken(key); got != "short-tok" {
t.Fatalf("expected token short-tok immediately after store, got %q", got)
}
// advance time beyond ttl
current = current.Add(2 * time.Second)
if got := getCachedToken(key); got != "" {
t.Fatalf("expected token to be expired and removed, got %q", got)
}
}