test(auth): improve concurrency tests for token cache

This commit is contained in:
kalvinparker 2025-11-14 14:57:52 +00:00
parent 44236f2a30
commit e0d4ec1b2c

View file

@ -1,101 +1,101 @@
package auth package auth
import ( import (
"sync" "sync"
"testing" "testing"
"time" "time"
) )
// Test concurrent stores and gets to ensure the mutex protects the cache // Test concurrent stores and gets to ensure the mutex protects the cache
func TestTokenCacheConcurrentStoreAndGet(t *testing.T) { func TestTokenCacheConcurrentStoreAndGet(t *testing.T) {
// reset cache safely // reset cache safely
tokenCacheMu.Lock() tokenCacheMu.Lock()
tokenCache = map[string]cachedToken{} tokenCache = map[string]cachedToken{}
tokenCacheMu.Unlock() tokenCacheMu.Unlock()
origNow := now origNow := now
defer func() { now = origNow }() defer func() { now = origNow }()
now = time.Now now = time.Now
key := "concurrent-key" key := "concurrent-key"
token := "tok-concurrent" token := "tok-concurrent"
var wg sync.WaitGroup var wg sync.WaitGroup
storeers := 50 storeers := 50
getters := 50 getters := 50
iters := 100 iters := 100
for i := 0; i < storeers; i++ { for i := 0; i < storeers; i++ {
wg.Add(1) wg.Add(1)
go func() { go func() {
defer wg.Done() defer wg.Done()
for j := 0; j < iters; j++ { for j := 0; j < iters; j++ {
storeToken(key, token, 0) storeToken(key, token, 0)
} }
}() }()
} }
for i := 0; i < getters; i++ { for i := 0; i < getters; i++ {
wg.Add(1) wg.Add(1)
go func() { go func() {
defer wg.Done() defer wg.Done()
for j := 0; j < iters; j++ { for j := 0; j < iters; j++ {
_ = getCachedToken(key) _ = getCachedToken(key)
} }
}() }()
} }
wg.Wait() wg.Wait()
if got := getCachedToken(key); got != token { if got := getCachedToken(key); got != token {
t.Fatalf("expected token %q, got %q", token, got) t.Fatalf("expected token %q, got %q", token, got)
} }
} }
// Test concurrent access while token expires: readers run while time is advanced // Test concurrent access while token expires: readers run while time is advanced
func TestTokenCacheConcurrentExpiry(t *testing.T) { func TestTokenCacheConcurrentExpiry(t *testing.T) {
// reset cache safely // reset cache safely
tokenCacheMu.Lock() tokenCacheMu.Lock()
tokenCache = map[string]cachedToken{} tokenCache = map[string]cachedToken{}
tokenCacheMu.Unlock() tokenCacheMu.Unlock()
// Make now controllable and thread-safe // Make now controllable and thread-safe
origNow := now origNow := now
defer func() { now = origNow }() defer func() { now = origNow }()
base := time.Now() base := time.Now()
var mu sync.Mutex var mu sync.Mutex
current := base current := base
now = func() time.Time { now = func() time.Time {
mu.Lock() mu.Lock()
defer mu.Unlock() defer mu.Unlock()
return current return current
} }
key := "concurrent-expire" key := "concurrent-expire"
storeToken(key, "t", 1) storeToken(key, "t", 1)
var wg sync.WaitGroup var wg sync.WaitGroup
readers := 100 readers := 100
for i := 0; i < readers; i++ { for i := 0; i < readers; i++ {
wg.Add(1) wg.Add(1)
go func() { go func() {
defer wg.Done() defer wg.Done()
for j := 0; j < 100; j++ { for j := 0; j < 100; j++ {
_ = getCachedToken(key) _ = getCachedToken(key)
} }
}() }()
} }
// advance time beyond ttl // advance time beyond ttl
mu.Lock() mu.Lock()
current = current.Add(2 * time.Second) current = current.Add(2 * time.Second)
mu.Unlock() mu.Unlock()
wg.Wait() wg.Wait()
if got := getCachedToken(key); got != "" { if got := getCachedToken(key); got != "" {
t.Fatalf("expected token to be expired, got %q", got) t.Fatalf("expected token to be expired, got %q", got)
} }
} }