feat(tls): implement dynamic TLS certificate reloading with file watching (#3792)

Fixes issue #3747

Signed-off-by: Ramkumar Chinchani <rchincha.dev@gmail.com>
This commit is contained in:
Ramkumar Chinchani
2026-02-15 13:01:50 -08:00
committed by GitHub
parent 2c110d2c20
commit 47659c11b2
9 changed files with 1372 additions and 5 deletions
+2
View File
@@ -204,4 +204,6 @@ var (
ErrOIDCEmptyValidationMsg = errors.New("validation error message is empty")
ErrOIDCValidationFailed = errors.New("OIDC claim validation failed")
ErrOIDCAudienceMismatch = errors.New("token audience does not match any of the expected audiences")
ErrCertificateNotLoaded = errors.New("tls certificate not yet loaded")
ErrCertificateWatcherAlreadyRunning = errors.New("certificate watcher is already running")
)
+44 -1
View File
@@ -12,6 +12,7 @@ import (
"os"
"strconv"
"strings"
"sync/atomic"
"time"
"github.com/gorilla/mux"
@@ -58,6 +59,8 @@ type Controller struct {
Healthz *common.Healthz
// runtime params
chosenPort int // kernel-chosen port
// TLS certificate management
TlsWatcher atomic.Pointer[TlsConfigWatcher]
}
func NewController(appConfig *config.Config) *Controller {
@@ -242,6 +245,12 @@ func (c *Controller) Run() error {
MinVersion: tls.VersionTLS12,
}
// Load CA certificate for mTLS client verification
// Note: CA certificate is loaded statically. Unlike the server certificate which is reloaded
// dynamically through the file watcher, CA certificate changes require a server restart.
// If dynamic CA certificate rotation is needed in the future (e.g., for rotating mTLS client certificates),
// the TlsConfigWatcher can be extended to also monitor and reload the CA certificate and update
// server.TLSConfig.ClientCAs accordingly.
if tlsConfig.CACert != "" {
caCert, err := os.ReadFile(tlsConfig.CACert)
if err != nil {
@@ -264,9 +273,37 @@ func (c *Controller) Run() error {
server.TLSConfig.ClientCAs = caCertPool
}
// Store TLS config paths in watcher for dynamic reloading
tlsCertPath := tlsConfig.Cert
tlsKeyPath := tlsConfig.Key
// Create and start certificate watcher
// Note: The watcher is stored in c.TlsWatcher before calling Start() so it can be properly
// cleaned up during Shutdown even if Start fails. The Stop() method gracefully handles
// the case where Start() was never called or failed by checking if the done channel is nil.
watcher := NewTlsConfigWatcher(tlsCertPath, tlsKeyPath, c.Log)
c.TlsWatcher.Store(watcher)
defer watcher.Stop()
// Load initial certificate
if err := watcher.ReloadCertificate(); err != nil {
c.Log.Error().Err(err).Msg("failed to load initial certificate")
return err
}
// Start file watching for certificate reloading
if err := watcher.Start(); err != nil {
c.Log.Warn().Err(err).Msg("failed to start certificate watcher, will use fallback polling")
}
// Set GetCertificate callback for dynamic certificate reloading
server.TLSConfig.GetCertificate = watcher.GetCertificate
c.Healthz.Ready()
return server.ServeTLS(listener, tlsConfig.Cert, tlsConfig.Key)
// Pass empty strings to ServeTLS since GetCertificate handles certificate loading
return server.ServeTLS(listener, "", "")
}
c.Healthz.Ready()
@@ -455,6 +492,12 @@ func (c *Controller) LoadNewConfig(newConfig *config.Config) {
}
func (c *Controller) Shutdown() {
// Stop certificate watcher if it's running
watcher := c.TlsWatcher.Load()
if watcher != nil {
watcher.Stop()
}
// stop all background tasks
c.StopBackgroundTasks()
+405
View File
@@ -0,0 +1,405 @@
//go:build sync && scrub && metrics && search && lint && userprefs && mgmt && imagetrust && ui
package api
import (
goerrors "errors"
"os"
"path"
"sync"
"testing"
"time"
"zotregistry.dev/zot/v2/pkg/log"
tlsutils "zotregistry.dev/zot/v2/pkg/test/tls"
)
// errGetCertificateFailed is reported by test workers when GetCertificate returns
// neither a certificate nor an error.
var errGetCertificateFailed = goerrors.New("GetCertificate failed")
func TestReloadCertificateStatFailureKeepsModTimes(t *testing.T) {
logger := log.NewLogger("debug", "")
tempDir := t.TempDir()
certPath := path.Join(tempDir, "cert.pem")
keyPath := path.Join(tempDir, "key.pem")
caOpts := &tlsutils.CertificateOptions{
CommonName: "*",
NotAfter: time.Now().AddDate(10, 0, 0),
KeyType: tlsutils.KeyTypeECDSA,
}
caCertPEM, caKeyPEM, err := tlsutils.GenerateCACert(caOpts)
if err != nil {
t.Fatalf("failed to generate CA cert: %v", err)
}
serverOpts := &tlsutils.CertificateOptions{
Hostname: "127.0.0.1",
CommonName: "*",
OrganizationalUnit: "TestServer",
NotAfter: time.Now().AddDate(10, 0, 0),
KeyType: tlsutils.KeyTypeECDSA,
}
if err := tlsutils.GenerateServerCertToFile(caCertPEM, caKeyPEM, certPath, keyPath, serverOpts); err != nil {
t.Fatalf("failed to generate server cert: %v", err)
}
watcher := NewTlsConfigWatcher(certPath, keyPath, logger)
if err := watcher.ReloadCertificate(); err != nil {
t.Fatalf("failed to load initial certificate: %v", err)
}
watcher.mu.RLock()
initialCertModTime := watcher.tlsCertModTime
initialKeyModTime := watcher.tlsKeyModTime
watcher.mu.RUnlock()
if initialCertModTime.IsZero() || initialKeyModTime.IsZero() {
t.Fatal("expected initial mod times to be set")
}
originalStat := tlsFileStat
tlsFileStat = func(string) (os.FileInfo, error) {
return nil, os.ErrNotExist
}
t.Cleanup(func() {
tlsFileStat = originalStat
})
time.Sleep(10 * time.Millisecond)
if err := watcher.ReloadCertificate(); err != nil {
t.Fatalf("unexpected reload error when stat fails: %v", err)
}
watcher.mu.RLock()
updatedCertModTime := watcher.tlsCertModTime
updatedKeyModTime := watcher.tlsKeyModTime
watcher.mu.RUnlock()
if !updatedCertModTime.Equal(initialCertModTime) {
t.Fatal("certificate mod time changed despite stat failure")
}
if !updatedKeyModTime.Equal(initialKeyModTime) {
t.Fatal("key mod time changed despite stat failure")
}
}
// TestGetCertificateReloadConcurrency runs one goroutine that repeatedly reloads the
// certificate alongside two goroutines that repeatedly call GetCertificate, and fails if
// any call returns an error or a nil certificate.
func TestGetCertificateReloadConcurrency(t *testing.T) {
	logger := log.NewLogger("debug", "")
	tempDir := t.TempDir()
	certPath := path.Join(tempDir, "cert.pem")
	keyPath := path.Join(tempDir, "key.pem")

	caOpts := &tlsutils.CertificateOptions{
		CommonName: "*",
		NotAfter:   time.Now().AddDate(10, 0, 0),
		KeyType:    tlsutils.KeyTypeECDSA,
	}

	caCertPEM, caKeyPEM, err := tlsutils.GenerateCACert(caOpts)
	if err != nil {
		t.Fatalf("failed to generate CA cert: %v", err)
	}

	serverOpts := &tlsutils.CertificateOptions{
		Hostname:           "127.0.0.1",
		CommonName:         "*",
		OrganizationalUnit: "TestServer",
		NotAfter:           time.Now().AddDate(10, 0, 0),
		KeyType:            tlsutils.KeyTypeECDSA,
	}

	if err := tlsutils.GenerateServerCertToFile(caCertPEM, caKeyPEM, certPath, keyPath, serverOpts); err != nil {
		t.Fatalf("failed to generate server cert: %v", err)
	}

	watcher := NewTlsConfigWatcher(certPath, keyPath, logger)
	if err := watcher.ReloadCertificate(); err != nil {
		t.Fatalf("failed to load initial certificate: %v", err)
	}

	var wg sync.WaitGroup

	// Each worker sends at most one error, so the buffer can never fill up.
	errorCh := make(chan error, 32)

	reloadWorker := func(iterations int) {
		defer wg.Done()

		for i := 0; i < iterations; i++ {
			if err := watcher.ReloadCertificate(); err != nil {
				errorCh <- err

				return
			}
		}
	}

	getWorker := func(iterations int) {
		defer wg.Done()

		for i := 0; i < iterations; i++ {
			cert, err := watcher.GetCertificate(nil)
			if err != nil {
				errorCh <- err

				return
			}

			// A nil certificate with a nil error must be reported as a failure too;
			// forwarding the nil error would be ignored by the drain loop below.
			if cert == nil {
				errorCh <- errGetCertificateFailed

				return
			}
		}
	}

	wg.Add(3)

	go reloadWorker(50)
	go getWorker(100)
	go getWorker(100)

	wg.Wait()
	close(errorCh)

	for err := range errorCh {
		if err != nil {
			t.Fatalf("concurrent TLS operations failed: %v", err)
		}
	}
}
func TestGetCertificateInitialLoadConcurrency(t *testing.T) {
logger := log.NewLogger("debug", "")
tempDir := t.TempDir()
certPath := path.Join(tempDir, "cert.pem")
keyPath := path.Join(tempDir, "key.pem")
caOpts := &tlsutils.CertificateOptions{
CommonName: "*",
NotAfter: time.Now().AddDate(10, 0, 0),
KeyType: tlsutils.KeyTypeECDSA,
}
caCertPEM, caKeyPEM, err := tlsutils.GenerateCACert(caOpts)
if err != nil {
t.Fatalf("failed to generate CA cert: %v", err)
}
serverOpts := &tlsutils.CertificateOptions{
Hostname: "127.0.0.1",
CommonName: "*",
OrganizationalUnit: "TestServer",
NotAfter: time.Now().AddDate(10, 0, 0),
KeyType: tlsutils.KeyTypeECDSA,
}
if err := tlsutils.GenerateServerCertToFile(caCertPEM, caKeyPEM, certPath, keyPath, serverOpts); err != nil {
t.Fatalf("failed to generate server cert: %v", err)
}
watcher := NewTlsConfigWatcher(certPath, keyPath, logger)
var wg sync.WaitGroup
results := make(chan error, 16)
worker := func(iterations int) {
defer wg.Done()
for i := 0; i < iterations; i++ {
cert, err := watcher.GetCertificate(nil)
if err == nil && cert == nil {
results <- errGetCertificateFailed
return
}
if err != nil {
results <- err
return
}
}
results <- nil
}
for range 8 {
wg.Add(1)
go worker(10)
}
wg.Wait()
close(results)
for err := range results {
if err != nil {
t.Fatalf("concurrent initial TLS load failed: %v", err)
}
}
}
func TestStartCertificateWatcherAddFailureFallsBack(t *testing.T) {
logger := log.NewLogger("debug", "")
tempDir := t.TempDir()
keyPath := path.Join(tempDir, "key.pem")
if err := os.WriteFile(keyPath, []byte("key"), 0o600); err != nil {
t.Fatalf("failed to write key file: %v", err)
}
watcher := NewTlsConfigWatcher(path.Join(tempDir, "missing-cert.pem"), keyPath, logger)
if err := watcher.Start(); err == nil {
t.Fatal("expected watcher setup to fail for missing certificate file")
}
if watcher.UseInotify() {
t.Fatal("expected file watching to be disabled after watcher add failure")
}
}
// TestCertificateWatcherHandlesAtomicRename starts the watcher, replaces the cert and key
// files by renaming freshly generated temp files over them, and expects the watcher's
// recorded certificate mod time to advance within two seconds.
func TestCertificateWatcherHandlesAtomicRename(t *testing.T) {
	logger := log.NewLogger("debug", "")
	tempDir := t.TempDir()
	certPath := path.Join(tempDir, "cert.pem")
	keyPath := path.Join(tempDir, "key.pem")
	// Generate initial certificates
	caOpts := &tlsutils.CertificateOptions{
		CommonName: "*",
		NotAfter:   time.Now().AddDate(10, 0, 0),
		KeyType:    tlsutils.KeyTypeECDSA,
	}
	caCertPEM, caKeyPEM, err := tlsutils.GenerateCACert(caOpts)
	if err != nil {
		t.Fatalf("failed to generate CA cert: %v", err)
	}
	serverOpts := &tlsutils.CertificateOptions{
		Hostname:           "127.0.0.1",
		CommonName:         "*",
		OrganizationalUnit: "TestServer-Initial",
		NotAfter:           time.Now().AddDate(10, 0, 0),
		KeyType:            tlsutils.KeyTypeECDSA,
	}
	if err := tlsutils.GenerateServerCertToFile(caCertPEM, caKeyPEM, certPath, keyPath, serverOpts); err != nil {
		t.Fatalf("failed to generate initial server cert: %v", err)
	}
	watcher := NewTlsConfigWatcher(certPath, keyPath, logger)
	if err := watcher.ReloadCertificate(); err != nil {
		t.Fatalf("failed to load initial certificate: %v", err)
	}
	if err := watcher.Start(); err != nil {
		t.Fatalf("failed to start certificate watcher: %v", err)
	}
	defer watcher.Stop()
	// Remember the mod time recorded by the initial load; the poll loop below
	// compares against it.
	watcher.mu.RLock()
	initialModTime := watcher.tlsCertModTime
	watcher.mu.RUnlock()
	// Simulate atomic certificate replacement via rename (common pattern in k8s, cert-manager, etc.)
	// Sleep long enough to ensure filesystem records different timestamps (at least 1 second
	// to account for systems with second-level timestamp precision)
	time.Sleep(1100 * time.Millisecond)
	newServerOpts := &tlsutils.CertificateOptions{
		Hostname:           "127.0.0.2",
		CommonName:         "*",
		OrganizationalUnit: "TestServer-Rotated",
		NotAfter:           time.Now().AddDate(10, 0, 0),
		KeyType:            tlsutils.KeyTypeECDSA,
	}
	// Write to temp files first
	tempCertPath := certPath + ".tmp"
	tempKeyPath := keyPath + ".tmp"
	if err := tlsutils.GenerateServerCertToFile(caCertPEM, caKeyPEM, tempCertPath, tempKeyPath, newServerOpts); err != nil {
		t.Fatalf("failed to generate temp server cert: %v", err)
	}
	// Atomic rename (this is how most certificate rotation works)
	if err := os.Rename(tempCertPath, certPath); err != nil {
		t.Fatalf("failed to rename cert: %v", err)
	}
	if err := os.Rename(tempKeyPath, keyPath); err != nil {
		t.Fatalf("failed to rename key: %v", err)
	}
	// Poll until the certificate is reloaded or timeout is reached.
	// This accounts for fsnotify delivery time plus the 150ms debounce interval,
	// which can exceed fixed sleeps on slower systems.
	deadline := time.Now().Add(2 * time.Second)
	var newModTime time.Time
	reloaded := false
	for time.Now().Before(deadline) {
		watcher.mu.RLock()
		newModTime = watcher.tlsCertModTime
		watcher.mu.RUnlock()
		// A strictly newer recorded mod time means a reload picked up the renamed file.
		if newModTime.After(initialModTime) {
			reloaded = true
			break
		}
		time.Sleep(50 * time.Millisecond)
	}
	if !reloaded {
		t.Fatalf("expected certificate to be reloaded after atomic rename within 2s, initial: %v, new: %v",
			initialModTime, newModTime)
	}
}
// TestCertificateWatcherCanRestart checks the Start -> Stop -> Start sequence: UseInotify
// must be true while running, false after Stop, and true again after the second Start.
func TestCertificateWatcherCanRestart(t *testing.T) {
	logger := log.NewLogger("debug", "")
	tempDir := t.TempDir()
	certPath := path.Join(tempDir, "cert.pem")
	keyPath := path.Join(tempDir, "key.pem")
	caOpts := &tlsutils.CertificateOptions{
		CommonName: "*",
		NotAfter:   time.Now().AddDate(10, 0, 0),
		KeyType:    tlsutils.KeyTypeECDSA,
	}
	caCertPEM, caKeyPEM, err := tlsutils.GenerateCACert(caOpts)
	if err != nil {
		t.Fatalf("failed to generate CA cert: %v", err)
	}
	serverOpts := &tlsutils.CertificateOptions{
		Hostname:           "127.0.0.1",
		CommonName:         "*",
		OrganizationalUnit: "TestServer",
		NotAfter:           time.Now().AddDate(10, 0, 0),
		KeyType:            tlsutils.KeyTypeECDSA,
	}
	if err := tlsutils.GenerateServerCertToFile(caCertPEM, caKeyPEM, certPath, keyPath, serverOpts); err != nil {
		t.Fatalf("failed to generate server cert: %v", err)
	}
	watcher := NewTlsConfigWatcher(certPath, keyPath, logger)
	// Load initial certificate
	if err := watcher.ReloadCertificate(); err != nil {
		t.Fatalf("failed to load initial certificate: %v", err)
	}
	// Start the watcher
	if err := watcher.Start(); err != nil {
		t.Fatalf("failed to start certificate watcher: %v", err)
	}
	// Verify UseInotify returns true while running
	if !watcher.UseInotify() {
		t.Fatalf("expected UseInotify() to be true after Start()")
	}
	// Stop the watcher
	watcher.Stop()
	// Give the watcher goroutine time to finish its deferred cleanup before the
	// restart below.
	time.Sleep(100 * time.Millisecond) // Allow cleanup to complete
	// Verify UseInotify returns false after Stop
	if watcher.UseInotify() {
		t.Fatalf("expected UseInotify() to be false after Stop()")
	}
	// Verify we can call Start() again (should not return error)
	if err := watcher.Start(); err != nil {
		t.Fatalf("failed to restart certificate watcher: %v", err)
	}
	// Verify UseInotify is true again
	if !watcher.UseInotify() {
		t.Fatalf("expected UseInotify() to be true after restart")
	}
	watcher.Stop()
}
+160
View File
@@ -13607,3 +13607,163 @@ func readTagsFromStorage(rootDir, repoName string, digest godigest.Digest) ([]st
return result, nil
}
// errGetCertificateFailed is sent by the concurrent GetCertificate goroutines in
// TestDynamicTLSCertificateReloading when a call returns an error or a nil certificate.
var errGetCertificateFailed = goerrors.New("GetCertificate failed")
// TestDynamicTLSCertificateReloading exercises the exported TlsConfigWatcher API without
// starting the fsnotify watcher, so change detection goes through the stat-based path:
// initial load, repeated GetCertificate calls, reload after the files are rewritten,
// CheckCertificateModTime, load errors, and concurrent GetCertificate calls.
func TestDynamicTLSCertificateReloading(t *testing.T) {
	Convey("Test dynamic TLS certificate reloading", t, func() {
		logger := log.NewLogger("debug", "")
		tempDir := t.TempDir()

		// Generate initial certificate and key
		certPath := path.Join(tempDir, "cert.pem")
		keyPath := path.Join(tempDir, "key.pem")

		caOpts := &tlsutils.CertificateOptions{
			CommonName: "*",
			NotAfter:   time.Now().AddDate(10, 0, 0),
			KeyType:    tlsutils.KeyTypeECDSA,
		}
		caCertPEM, caKeyPEM, err := tlsutils.GenerateCACert(caOpts)
		So(err, ShouldBeNil)

		serverOpts := &tlsutils.CertificateOptions{
			Hostname:           "127.0.0.1",
			CommonName:         "*",
			OrganizationalUnit: "TestServer",
			NotAfter:           time.Now().AddDate(10, 0, 0),
			KeyType:            tlsutils.KeyTypeECDSA,
		}
		err = tlsutils.GenerateServerCertToFile(caCertPEM, caKeyPEM, certPath, keyPath, serverOpts)
		So(err, ShouldBeNil)

		// Create a watcher and set certificate paths
		watcher := api.NewTlsConfigWatcher(certPath, keyPath, logger)

		// Test 1: Load initial certificate
		Convey("Load initial certificate successfully", func() {
			err := watcher.ReloadCertificate()
			So(err, ShouldBeNil)

			cert, err := watcher.GetCertificate(nil)
			So(err, ShouldBeNil)
			So(cert, ShouldNotBeNil)
		})

		// Test 2: GetCertificate returns the loaded certificate
		Convey("GetCertificate returns the loaded certificate", func() {
			err := watcher.ReloadCertificate()
			So(err, ShouldBeNil)

			cert, err := watcher.GetCertificate(nil)
			So(err, ShouldBeNil)
			So(cert, ShouldNotBeNil)

			certAgain, err := watcher.GetCertificate(nil)
			So(err, ShouldBeNil)
			So(certAgain, ShouldEqual, cert)
		})

		// Test 3: Certificate change detection via stat-based fallback
		Convey("Detect certificate change and reload via stat fallback", func() {
			// Load initial certificate
			err := watcher.ReloadCertificate()
			So(err, ShouldBeNil)

			oldCert, err := watcher.GetCertificate(nil)
			So(err, ShouldBeNil)
			So(oldCert, ShouldNotBeNil)

			oldLeaf, err := x509.ParseCertificate(oldCert.Certificate[0])
			So(err, ShouldBeNil)
			So(oldLeaf, ShouldNotBeNil)

			// Wait long enough to ensure different modification time on coarse timestamp filesystems
			time.Sleep(1100 * time.Millisecond)

			// Generate a new certificate
			newServerOpts := &tlsutils.CertificateOptions{
				Hostname:           "127.0.0.2",
				CommonName:         "*",
				OrganizationalUnit: "TestServer-New",
				NotAfter:           time.Now().AddDate(10, 0, 0),
				KeyType:            tlsutils.KeyTypeECDSA,
			}
			err = tlsutils.GenerateServerCertToFile(caCertPEM, caKeyPEM, certPath, keyPath, newServerOpts)
			So(err, ShouldBeNil)

			// Call GetCertificate which should trigger reload via stat-based checking
			cert, err := watcher.GetCertificate(nil)
			So(err, ShouldBeNil)
			So(cert, ShouldNotBeNil)

			newLeaf, err := x509.ParseCertificate(cert.Certificate[0])
			So(err, ShouldBeNil)
			So(newLeaf, ShouldNotBeNil)

			// Verify that the certificate content was updated
			So(newLeaf.Subject.OrganizationalUnit, ShouldNotResemble, oldLeaf.Subject.OrganizationalUnit)
		})

		// Test 4: checkCertificateModTime correctly detects changes
		Convey("checkCertificateModTime correctly detects file modifications", func() {
			// Load initial certificate
			err := watcher.ReloadCertificate()
			So(err, ShouldBeNil)

			// No changes yet, should return false
			needsReload := watcher.CheckCertificateModTime()
			So(needsReload, ShouldBeFalse)

			// Wait long enough for the rewritten files to get a different modification
			// time even on filesystems with second-level timestamp precision (same
			// margin as the stat fallback case above); a shorter wait can land in the
			// same timestamp tick and make this check flaky.
			time.Sleep(1100 * time.Millisecond)

			newServerOpts := &tlsutils.CertificateOptions{
				Hostname:           "test.example.com",
				CommonName:         "*",
				OrganizationalUnit: "TestServer-Changed",
				NotAfter:           time.Now().AddDate(10, 0, 0),
				KeyType:            tlsutils.KeyTypeECDSA,
			}
			err = tlsutils.GenerateServerCertToFile(caCertPEM, caKeyPEM, certPath, keyPath, newServerOpts)
			So(err, ShouldBeNil)

			// Now should detect the change
			needsReload = watcher.CheckCertificateModTime()
			So(needsReload, ShouldBeTrue)
		})

		// Test 5: LoadX509KeyPair error handling
		Convey("Handle certificate loading errors gracefully", func() {
			badWatcher := api.NewTlsConfigWatcher("/nonexistent/cert.pem", "/nonexistent/key.pem", logger)
			err := badWatcher.ReloadCertificate()
			So(err, ShouldNotBeNil)
		})

		// Test 6: Concurrent GetCertificate calls with certificate reloading
		Convey("Handle concurrent GetCertificate calls safely", func() {
			err := watcher.ReloadCertificate()
			So(err, ShouldBeNil)

			// Simulate concurrent calls; the channel is buffered so no goroutine blocks
			// on send.
			done := make(chan error, 5)

			for range 5 {
				go func() {
					cert, err := watcher.GetCertificate(nil)
					if err != nil || cert == nil {
						done <- errGetCertificateFailed
					} else {
						done <- nil
					}
				}()
			}

			// Wait for all goroutines
			for range 5 {
				err := <-done
				So(err, ShouldBeNil)
			}
		})
	})
}
+7 -3
View File
@@ -588,7 +588,7 @@ func getReferrers(ctx context.Context, routeHandler *RouteHandler,
imgStore storageTypes.ImageStore, name string, digest godigest.Digest,
artifactTypes []string,
) (ispec.Index, error) {
if isSyncOnDemandEnabled(*routeHandler.c) {
if isSyncOnDemandEnabled(routeHandler.c) {
routeHandler.c.Log.Info().Str("repository", name).Str("reference", digest.String()).
Msg("trying to get updated referrers by syncing on demand")
@@ -2167,7 +2167,7 @@ func (rh *RouteHandler) getImageStore(name string) storageTypes.ImageStore {
func getImageManifest(ctx context.Context, routeHandler *RouteHandler, imgStore storageTypes.ImageStore, name,
reference string,
) ([]byte, godigest.Digest, string, error) {
syncEnabled := isSyncOnDemandEnabled(*routeHandler.c)
syncEnabled := isSyncOnDemandEnabled(routeHandler.c)
_, digestErr := godigest.Parse(reference)
if digestErr == nil {
@@ -2400,7 +2400,11 @@ func getBlobUploadLocation(url *url.URL, name string, digest godigest.Digest) st
return url.String()
}
func isSyncOnDemandEnabled(ctlr Controller) bool {
func isSyncOnDemandEnabled(ctlr *Controller) bool {
if ctlr == nil {
return false
}
extensionsConfig := ctlr.Config.CopyExtensionsConfig()
if extensionsConfig.IsSyncEnabled() &&
fmt.Sprintf("%v", ctlr.SyncOnDemand) != fmt.Sprintf("%v", nil) {
+472
View File
@@ -0,0 +1,472 @@
package api
import (
"crypto/tls"
"fmt"
"os"
"sync"
"sync/atomic"
"time"
"github.com/fsnotify/fsnotify"
"zotregistry.dev/zot/v2/errors"
"zotregistry.dev/zot/v2/pkg/log"
)
// tlsCertificateEventDebounceInterval is how long the watcher waits after a file event
// before reloading, so that back-to-back cert and key events collapse into one reload.
const tlsCertificateEventDebounceInterval = 150 * time.Millisecond

// tlsCertificateStatCheckInterval is the minimum spacing between stat-based change checks
// performed from GetCertificate when file watching is not active.
const tlsCertificateStatCheckInterval = time.Second

// tlsFileStat is the stat function used for cert/key files; tests replace it to simulate
// stat failures.
var tlsFileStat = os.Stat //nolint:gochecknoglobals // test hook for os.Stat
// TlsConfigWatcher watches TLS cert/key files and reloads certificates on change.
type TlsConfigWatcher struct {
mu sync.RWMutex
watcher *fsnotify.Watcher
done chan struct{}
stopOnce sync.Once
useInotify bool
certPath string
keyPath string
log log.Logger
debounceTimer *time.Timer
debounceMutex sync.Mutex
// Certificate management fields
tlsCert *tls.Certificate
tlsCertModTime time.Time
tlsKeyModTime time.Time
tlsCertReloadInProgress atomic.Bool
tlsCertLastStatCheckTime atomic.Int64 // Unix timestamp in nanoseconds for last stat check
}
// NewTlsConfigWatcher returns a watcher bound to certPath and keyPath that logs through
// logger. Nothing is loaded or watched until ReloadCertificate / Start are called.
func NewTlsConfigWatcher(certPath, keyPath string, logger log.Logger) *TlsConfigWatcher {
	configWatcher := new(TlsConfigWatcher)
	configWatcher.certPath = certPath
	configWatcher.keyPath = keyPath
	configWatcher.log = logger

	return configWatcher
}
// Start begins watching the certificate and key files for changes.
//
// It returns ErrCertificateWatcherAlreadyRunning if a watcher is already installed, or
// the underlying fsnotify error if the watcher cannot be created or either file cannot
// be added; in the fsnotify failure cases file watching is marked as disabled. Start may
// be called again after Stop.
func (w *TlsConfigWatcher) Start() error {
	// Cheap early exit so an already-running watcher does not pay for creating a
	// throwaway fsnotify instance. This check alone is not sufficient; see the re-check
	// under the write lock below.
	w.mu.RLock()
	running := w.done != nil
	w.mu.RUnlock()

	if running {
		return errors.ErrCertificateWatcherAlreadyRunning
	}

	watcher, err := fsnotify.NewWatcher()
	if err != nil {
		w.log.Error().Err(err).Msg("failed to create fsnotify watcher")
		w.disableUseInotify()

		return err
	}

	w.log.Debug().Str("cert", w.certPath).Str("key", w.keyPath).
		Msg("starting TLS certificate watcher")

	if err := watcher.Add(w.certPath); err != nil {
		w.log.Warn().Err(err).Str("cert", w.certPath).Msg("failed to watch certificate file")

		_ = watcher.Close()

		w.disableUseInotify()

		return err
	}

	if err := watcher.Add(w.keyPath); err != nil {
		w.log.Warn().Err(err).Str("key", w.keyPath).Msg("failed to watch key file")

		_ = watcher.Close()

		w.disableUseInotify()

		return err
	}

	w.mu.Lock()

	// Re-check under the write lock: another Start may have installed its watcher
	// between the early check above and this point. Without this, both calls would
	// spawn a loop and the first watcher would be overwritten and leaked.
	if w.done != nil {
		w.mu.Unlock()

		_ = watcher.Close()

		return errors.ErrCertificateWatcherAlreadyRunning
	}

	w.watcher = watcher
	w.useInotify = true
	w.done = make(chan struct{})
	w.stopOnce = sync.Once{}
	w.mu.Unlock()

	go w.loop()

	w.log.Info().Msg("TLS certificate watcher started using fsnotify")

	return nil
}
// Stop signals the watcher goroutine to exit and releases the fsnotify watcher.
//
// It is safe to call when Start was never called or failed, and safe to call more than
// once or concurrently: only the call that finds an installed watcher performs the
// shutdown, every other call returns immediately.
func (w *TlsConfigWatcher) Stop() {
	// Capture and clear the lifecycle state in a single critical section. Whichever
	// caller observes a non-nil done channel owns the shutdown, so no separate
	// once-guard is needed and nothing here touches state that Start reassigns.
	w.mu.Lock()
	done := w.done
	watcher := w.watcher

	if done != nil {
		w.done = nil
		w.watcher = nil
		w.useInotify = false
	}
	w.mu.Unlock()

	if done == nil {
		w.log.Debug().Msg("TLS certificate watcher stop requested with no active watcher")

		return
	}

	// Clean up any pending debounce timer so no reload fires after shutdown.
	w.debounceMutex.Lock()
	if w.debounceTimer != nil {
		w.debounceTimer.Stop()
		w.debounceTimer = nil
	}
	w.debounceMutex.Unlock()

	// Close fsnotify watcher to terminate the goroutine promptly
	if watcher != nil {
		_ = watcher.Close()
	}

	// Signal the goroutine to exit via the done channel
	close(done)

	w.log.Debug().Msg("TLS certificate watcher stopped and state cleared")
}
// UseInotify returns the current value of the useInotify flag, i.e. whether
// fsnotify-based watching is considered active.
func (w *TlsConfigWatcher) UseInotify() bool {
	w.mu.RLock()
	defer w.mu.RUnlock()

	return w.useInotify
}
// disableUseInotify clears the useInotify flag under the write lock.
func (w *TlsConfigWatcher) disableUseInotify() {
	w.mu.Lock()
	defer w.mu.Unlock()

	w.useInotify = false
}
// loop is the watcher goroutine. It consumes fsnotify events for the watcher that was
// installed when the goroutine started, debounces them, and reloads the certificate when
// the debounce timer fires. It exits when the done channel is closed or when either
// fsnotify channel is closed.
func (w *TlsConfigWatcher) loop() {
	// Snapshot the watcher and done channel this goroutine is responsible for.
	w.mu.RLock()
	watcher := w.watcher
	done := w.done
	w.mu.RUnlock()

	// Safety net for the case where the goroutine exits on its own (e.g. channels closed
	// unexpectedly) without Stop() being called: clear the shared state so Start() can
	// be called again. The cleanup only runs when the shared state still refers to the
	// watcher captured above. If Stop() already cleared it, or Stop() followed by
	// Start() installed a newer watcher, the state is not ours to close and is left
	// untouched; otherwise a late-exiting loop would shut down its successor.
	defer func() {
		w.mu.Lock()

		if watcher != nil && w.watcher == watcher {
			if w.done != nil {
				close(w.done)
				w.done = nil
			}

			_ = watcher.Close()
			w.watcher = nil
			w.useInotify = false
		}

		// Note: w.stopOnce is reset by Start() when a new watcher is initialized,
		// so it is not touched here.
		w.mu.Unlock()

		w.log.Debug().Msg("TLS certificate watcher loop cleanup completed")
	}()

	if watcher == nil {
		w.log.Debug().Msg("TLS certificate watcher loop exited: watcher not initialized")

		return
	}

	for {
		select {
		case <-done:
			w.log.Debug().Msg("TLS certificate watcher loop exited: stop signal received")

			return
		case event, ok := <-watcher.Events:
			if !ok {
				w.log.Debug().Msg("TLS certificate watcher loop exited: events channel closed")

				return
			}

			if event.Op&(fsnotify.Write|fsnotify.Create|fsnotify.Remove|fsnotify.Rename|fsnotify.Chmod) != 0 {
				w.log.Debug().Str("file", event.Name).Str("op", event.Op.String()).
					Msg("certificate file change detected")

				// A removed or renamed file has to be re-added to the watcher; if that
				// keeps failing, fall back to stat-based polling.
				if event.Op&(fsnotify.Remove|fsnotify.Rename) != 0 {
					if !w.retryAddWatch(event.Name, watcher, done) {
						w.log.Warn().Str("file", event.Name).
							Msg("failed to re-add watch after retries, switching to stat-based polling")
						w.disableUseInotify()
					}
				}

				// The retries above can take a while; do not schedule a reload if a
				// stop was requested in the meantime.
				select {
				case <-done:
					w.log.Debug().Msg("TLS certificate watcher loop exited: stop signal received before reload")

					return
				default:
				}

				// Debounce multiple file events to coalesce cert and key changes into a single reload
				w.scheduleReload()
			}
		case <-w.getDebounceChannel():
			// Debounce timer expired, perform the reload
			w.debounceMutex.Lock()
			w.debounceTimer = nil
			w.debounceMutex.Unlock()

			select {
			case <-done:
				w.log.Debug().Msg("TLS certificate watcher loop exited: stop signal received before debounced reload")

				return
			default:
			}

			w.log.Debug().Str("cert", w.certPath).Str("key", w.keyPath).
				Msg("reloading TLS certificate after debounced file change")

			// On failure the previously loaded certificate stays in place.
			if err := w.ReloadCertificate(); err != nil {
				w.log.Error().Err(err).Msg("failed to reload certificate on file change")
			}
		case err, ok := <-watcher.Errors:
			if !ok {
				w.log.Debug().Msg("TLS certificate watcher loop exited: errors channel closed")

				return
			}

			w.log.Error().Err(err).Msg("failed to watch certificate files")
		}
	}
}
// getDebounceChannel returns the channel of the pending debounce timer, or nil when no
// timer is pending. A nil channel never becomes ready in a select, which disables that
// case until a reload is scheduled.
func (w *TlsConfigWatcher) getDebounceChannel() <-chan time.Time {
	w.debounceMutex.Lock()
	pending := w.debounceTimer
	w.debounceMutex.Unlock()

	if pending != nil {
		return pending.C
	}

	return nil
}
// scheduleReload arms the debounce timer for a reload. If no timer is pending a new one
// is created; otherwise the pending timer is restarted so the debounce window begins
// again from the latest event.
func (w *TlsConfigWatcher) scheduleReload() {
	w.debounceMutex.Lock()
	defer w.debounceMutex.Unlock()

	timer := w.debounceTimer

	// No reload pending: start a fresh timer. Only the timer's channel is needed (it is
	// selected on by the watcher loop), so a plain timer is used rather than a callback.
	if timer == nil {
		w.debounceTimer = time.NewTimer(tlsCertificateEventDebounceInterval)
		w.log.Debug().Str("interval", tlsCertificateEventDebounceInterval.String()).
			Msg("debounce timer started for file change events")

		return
	}

	// A reload is already pending: stop the timer, drain its channel if it had already
	// fired, and restart the window.
	if stopped := timer.Stop(); !stopped {
		select {
		case <-timer.C:
		default:
		}
	}

	timer.Reset(tlsCertificateEventDebounceInterval)
	w.log.Debug().Msg("debounce timer reset for additional file change event")
}
// retryAddWatch tries up to five times to add file back to watcher, waiting 50ms, 100ms,
// ... 250ms before successive attempts. It returns true as soon as an attempt succeeds,
// and false if all attempts fail or done is closed while waiting.
func (w *TlsConfigWatcher) retryAddWatch(file string, watcher *fsnotify.Watcher, done <-chan struct{}) bool {
	const maxAttempts = 5

	for try := 1; try <= maxAttempts; try++ {
		backoff := time.Duration(50*try) * time.Millisecond

		select {
		case <-done:
			return false
		case <-time.After(backoff):
		}

		addErr := watcher.Add(file)
		if addErr != nil {
			w.log.Debug().Str("file", file).Int("attempt", try).
				Msg("retrying watch add after failure")

			continue
		}

		w.log.Debug().Str("file", file).Int("attempt", try).
			Msg("re-added watch after file removal/rename")

		return true
	}

	return false
}
// GetCertificate is a callback used by tls.Config that dynamically loads TLS certificates.
// This allows certificates to be reloaded when they change on disk without restarting the
// server.
//
// If no certificate has been loaded yet, it is loaded on demand. If one is loaded and file
// watching is not active, the files are checked for changes via stat, at most once per
// tlsCertificateStatCheckInterval, and reloaded when a change is seen.
//
// It returns an error only when no certificate is available: a failed reload is reported
// as an error when nothing was loaded before, and otherwise the previously loaded
// certificate keeps being served. The hello argument is not used.
func (w *TlsConfigWatcher) GetCertificate(hello *tls.ClientHelloInfo) (*tls.Certificate, error) {
	w.mu.RLock()
	current := w.tlsCert
	watching := w.useInotify
	w.mu.RUnlock()

	// A missing certificate always needs a load.
	needsReload := current == nil

	// If file watching is not being used, perform stat-based fallback polling.
	// Rate limit the stat checks to avoid I/O overhead on every TLS handshake; the
	// compare-and-swap lets exactly one of several concurrent handshakes claim the
	// check for this interval.
	if current != nil && !watching {
		now := time.Now().UnixNano()
		last := w.tlsCertLastStatCheckTime.Load()

		if now-last >= int64(tlsCertificateStatCheckInterval) &&
			w.tlsCertLastStatCheckTime.CompareAndSwap(last, now) {
			needsReload = w.CheckCertificateModTime()
		}
	}

	if needsReload {
		// Use an atomic flag so only one goroutine performs the reload when several
		// concurrent requests need it.
		if w.tlsCertReloadInProgress.CompareAndSwap(false, true) {
			defer w.tlsCertReloadInProgress.Store(false)

			if err := w.ReloadCertificate(); err != nil {
				if current == nil {
					return nil, fmt.Errorf("failed to reload certificate: %w", err)
				}

				// The files may be mid-rotation (e.g. cert written, key not yet).
				// Failing the handshake would turn a transient condition into an
				// outage, so keep serving the certificate that is already loaded;
				// the recorded mod times were not updated, so a later check retries.
				w.log.Warn().Err(err).
					Msg("certificate reload failed, serving previously loaded certificate")
			}
		} else {
			// Another goroutine is reloading - wait for it to complete (up to 5 seconds)
			// to avoid returning ErrCertificateNotLoaded during concurrent initial loads
			deadline := time.Now().Add(5 * time.Second)
			for w.tlsCertReloadInProgress.Load() && time.Now().Before(deadline) {
				time.Sleep(10 * time.Millisecond)
			}
		}
	}

	// Re-read: a reload (ours or another goroutine's) may have replaced the certificate.
	w.mu.RLock()
	cert := w.tlsCert
	w.mu.RUnlock()

	if cert == nil {
		return nil, errors.ErrCertificateNotLoaded
	}

	return cert, nil
}
// ReloadCertificate loads the TLS certificate and key from disk and, on success, makes
// them the certificate returned by GetCertificate.
//
// It returns the tls.LoadX509KeyPair error if the pair cannot be loaded; in that case the
// previously loaded certificate and recorded mod times are left unchanged. Stat failures
// are logged but do not fail the reload; the corresponding recorded mod time is simply
// not updated.
func (w *TlsConfigWatcher) ReloadCertificate() error {
	// Stat BEFORE loading. If a file is replaced between the two steps, the recorded
	// mod time is then older than the content that was loaded, and the next stat-based
	// check triggers a harmless extra reload. In the opposite order the new mod time
	// would be recorded against the old content and the change would never be noticed
	// by stat-based checking.
	certInfo, certStatErr := tlsFileStat(w.certPath)
	keyInfo, keyStatErr := tlsFileStat(w.keyPath)

	cert, err := tls.LoadX509KeyPair(w.certPath, w.keyPath)
	if err != nil {
		w.log.Error().Err(err).Str("cert", w.certPath).Str("key", w.keyPath).
			Msg("failed to load certificate and key pair")

		return err
	}

	if certStatErr != nil {
		w.log.Warn().Err(certStatErr).Str("cert", w.certPath).
			Msg("failed to stat certificate file")
	}

	if keyStatErr != nil {
		w.log.Warn().Err(keyStatErr).Str("key", w.keyPath).
			Msg("failed to stat key file")
	}

	// When both stat calls fail neither recorded mod time is refreshed, so a later
	// stat-based check compares against stale values and may report a change (and
	// trigger a reload) even though the files did not change again.
	if certStatErr != nil && keyStatErr != nil {
		w.log.Warn().Msg("both cert and key stat failed during reload - mod times not updated, " +
			"next stat-based check may return false positives")
	}

	w.mu.Lock()
	w.tlsCert = &cert

	if certStatErr == nil && certInfo != nil {
		w.tlsCertModTime = certInfo.ModTime()
	}

	if keyStatErr == nil && keyInfo != nil {
		w.tlsKeyModTime = keyInfo.ModTime()
	}
	w.mu.Unlock()

	w.log.Debug().Str("cert", w.certPath).Str("key", w.keyPath).
		Msg("TLS certificate reloaded")

	return nil
}
// CheckCertificateModTime reports whether the certificate or key file's modification time
// differs from the one recorded at the last load. This is used as a fallback when inotify
// is not available.
//
// It returns false if either file cannot be stat'ed (the failure is logged), so a stat
// failure by itself never triggers a reload.
func (w *TlsConfigWatcher) CheckCertificateModTime() bool {
	certInfo, err := tlsFileStat(w.certPath)
	if err != nil {
		w.log.Error().Err(err).Str("cert", w.certPath).Msg("failed to stat certificate file")

		return false
	}

	keyInfo, err := tlsFileStat(w.keyPath)
	if err != nil {
		w.log.Error().Err(err).Str("key", w.keyPath).Msg("failed to stat key file")

		return false
	}

	w.mu.RLock()
	certModTime := w.tlsCertModTime
	keyModTime := w.tlsKeyModTime
	w.mu.RUnlock()

	// Treat ANY difference as a change, not only a newer timestamp: a file swapped in
	// by rename or by an mtime-preserving copy can carry a modification time older than
	// the one recorded for the file it replaced.
	if !certInfo.ModTime().Equal(certModTime) || !keyInfo.ModTime().Equal(keyModTime) {
		w.log.Debug().Msg("certificate or key file modification detected via stat")

		return true
	}

	return false
}
+1 -1
View File
@@ -15,7 +15,7 @@ tests=("pushpull" "pushpull_authn" "delete_images" "referrers" "metadata" "anony
"annotations" "detect_manifest_collision" "cve" "sync" "sync_docker" "sync_replica_cluster"
"scrub" "garbage_collect" "metrics" "metrics_minimal" "multiarch_index" "docker_compat" "redis_local" "redis_session_store"
"events_nats" "events_http" "events_nats_lint_failure" "events_http_lint_failure" "events_sink_failure" "events_config_decoding"
"fips140" "fips140_authn" "openid_claim_mapping" "upgrade" "upgrade_minimal")
"fips140" "fips140_authn" "openid_claim_mapping" "upgrade" "upgrade_minimal" "dynamic_tls")
for test in ${tests[*]}; do
${BATS} ${BATS_FLAGS} ${SCRIPTPATH}/${test}.bats > ${test}.log & pids+=($!)
+275
View File
@@ -0,0 +1,275 @@
# Note: Intended to be run as "make run-blackbox-tests" or "make run-blackbox-ci"
# Makefile target installs & checks all necessary tooling
# Extra tools that are not covered in Makefile target need to be added in verify_prerequisites()
load helpers_zot
load ../port_helper
# Check that the command-line tools this suite relies on are available.
# Returns 0 when curl, jq and openssl are all found; otherwise prints a message
# naming the first missing tool on fd 3 and returns 1.
function verify_prerequisites {
    local tool
    for tool in curl jq openssl; do
        if command -v "${tool}" >/dev/null 2>&1; then
            continue
        fi
        echo "you need to install ${tool} as a prerequisite to running the tests" >&3
        return 1
    done
    return 0
}
# Generate a self-signed certificate and an unencrypted RSA-2048 private key.
# Usage: generate_self_signed_cert <cert_path> <key_path> [common_name] [days]
#   cert_path    file the certificate is written to
#   key_path     file the private key is written to
#   common_name  CN placed in the subject (default: localhost)
#   days         validity period in days (default: 365)
# Returns the exit status of openssl.
function generate_self_signed_cert() {
    local cert_path=${1}
    local key_path=${2}
    local common_name=${3:-"localhost"}
    local days=${4:-365}

    # Every expansion is quoted so that no argument is word-split or globbed.
    openssl req -x509 -newkey rsa:2048 -keyout "${key_path}" -out "${cert_path}" \
        -days "${days}" -nodes \
        -subj "/C=US/ST=Test/L=Test/O=Zot/CN=${common_name}"
}
# Poll a shell command until it succeeds or the attempt budget is exhausted.
# Usage: wait_for_condition <max_attempts> <interval_seconds> "<command>"
# The command string is run with eval on every attempt; the function sleeps
# <interval_seconds> between attempts (not after the last one).
# Returns 0 as soon as the command succeeds, 1 once all attempts have failed.
# Progress messages are written to fd 3.
function wait_for_condition() {
    local max_attempts=${1}
    local interval=${2}
    local condition_cmd=${3}
    local attempt

    for ((attempt = 1; attempt <= max_attempts; attempt++)); do
        if eval "${condition_cmd}"; then
            echo "Condition met after $attempt attempts" >&3
            return 0
        fi

        # Sleeping after the final attempt would only delay the timeout report.
        if ((attempt < max_attempts)); then
            sleep "${interval}"
        fi
    done

    echo "Condition timed out after $max_attempts attempts" >&3
    return 1
}
# One-time fixture for this file: fetches the test image, generates a TLS
# certificate/key pair, writes a zot configuration that references them, starts
# zot and waits until the HTTPS endpoint answers.
# The chosen port is saved to ${BATS_FILE_TMPDIR}/zot.port so each test can read it back.
# get_free_port_for_service and zot_serve are not defined in this file; presumably
# they come from the helper files loaded at the top.
function setup_file() {
# Verify prerequisites are available
if ! verify_prerequisites; then
exit 1
fi
# Download test data to folder common for the entire suite, not just this file
skopeo --insecure-policy copy --format=oci docker://ghcr.io/project-zot/test-images/busybox:1.36 oci:${TEST_DATA_DIR}/busybox:1.36
# Setup zot server with TLS
local zot_root_dir=${BATS_FILE_TMPDIR}/zot
local zot_config_file=${BATS_FILE_TMPDIR}/zot_config.json
local zot_cert_file=${BATS_FILE_TMPDIR}/server.cert
local zot_key_file=${BATS_FILE_TMPDIR}/server.key
zot_port=$(get_free_port_for_service "zot")
# Persist the port: the tests read it from this file
echo ${zot_port} > ${BATS_FILE_TMPDIR}/zot.port
mkdir -p ${zot_root_dir}
# Generate initial TLS certificate
generate_self_signed_cert "${zot_cert_file}" "${zot_key_file}" "127.0.0.1" 365
# Create zot config with TLS enabled
cat > ${zot_config_file}<<EOF
{
"distSpecVersion":"1.1.1",
"storage":{
"dedupe": true,
"gc": true,
"gcDelay": "1h",
"gcInterval": "6h",
"rootDirectory": "${zot_root_dir}"
},
"http": {
"address": "127.0.0.1",
"port": "${zot_port}",
"tls": {
"cert": "${zot_cert_file}",
"key": "${zot_key_file}"
}
},
"log":{
"level":"debug",
"output": "${BATS_FILE_TMPDIR}/zot.log"
}
}
EOF
echo ${zot_root_dir} >&3
zot_serve ${ZOT_PATH} ${zot_config_file}
# Wait for server to be ready by polling for connectivity
# (up to 30 attempts, 0.2s apart; -k because the certificate is self-signed)
wait_for_condition 30 0.2 "curl -k --max-time 5 --connect-timeout 3 https://127.0.0.1:${zot_port}/v2/_catalog >/dev/null 2>&1"
}
# Per-test teardown: prints the zot log written to ${BATS_FILE_TMPDIR}/zot.log.
function teardown() {
# conditionally printing on failure is possible from teardown but not from teardown_file
# A missing or unreadable log must not fail the teardown, hence "2>/dev/null || true"
cat ${BATS_FILE_TMPDIR}/zot.log 2>/dev/null || true
}
# File-level teardown: calls zot_stop_all, which is not defined in this file
# (presumably provided by the loaded zot helpers) to stop the server started in setup_file.
function teardown_file() {
zot_stop_all
}
@test "TLS connection succeeds with self-signed certificate" {
# Smoke test: an HTTPS request to the catalog endpoint must succeed.
# The port is the one setup_file saved to zot.port.
zot_port=$(cat ${BATS_FILE_TMPDIR}/zot.port)
# Test with curl using insecure flag since we're using self-signed cert
run curl -k --max-time 5 --connect-timeout 3 https://127.0.0.1:${zot_port}/v2/_catalog
[ "$status" -eq 0 ]
}
@test "push image with TLS enabled" {
# Pushes the busybox image downloaded in setup_file to the registry over HTTPS.
zot_port=$(cat ${BATS_FILE_TMPDIR}/zot.port)
# Use skopeo to push image over HTTPS with insecure TLS verification
# (--dest-tls-verify=false because the server certificate is self-signed)
run skopeo --insecure-policy copy --dest-tls-verify=false \
oci:${TEST_DATA_DIR}/busybox:1.36 \
docker://127.0.0.1:${zot_port}/busybox:1.36
[ "$status" -eq 0 ]
}
@test "pull image with TLS enabled" {
# Pulls busybox:1.36 from the registry into a fresh OCI layout directory.
# The image is the one pushed by the "push image with TLS enabled" test, so this
# test depends on that one having run first.
zot_port=$(cat ${BATS_FILE_TMPDIR}/zot.port)
local temp_oci_dir=${BATS_FILE_TMPDIR}/busybox-pulled
mkdir -p ${temp_oci_dir}
# Pull the pushed image back
run skopeo --insecure-policy copy --src-tls-verify=false \
docker://127.0.0.1:${zot_port}/busybox:1.36 \
oci:${temp_oci_dir}
[ "$status" -eq 0 ]
# Verify OCI image was downloaded
[ -f "${temp_oci_dir}/oci-layout" ]
}
@test "dynamic certificate reload: verify server uses new certificate after update" {
    # Replaces the certificate/key on disk while the server is running and checks
    # that the fingerprint presented on new TLS connections changes, without the
    # server being restarted.
    zot_port=$(cat ${BATS_FILE_TMPDIR}/zot.port)
    local zot_cert_file=${BATS_FILE_TMPDIR}/server.cert
    local zot_key_file=${BATS_FILE_TMPDIR}/server.key

    # Get the certificate fingerprint before update (on disk, and as presented by the server)
    cert_fingerprint_before=$(openssl x509 -fingerprint -sha256 -noout -in "${zot_cert_file}" 2>/dev/null | cut -d'=' -f2)
    server_fingerprint_before=$(openssl s_client -connect 127.0.0.1:${zot_port} -servername 127.0.0.1 -showcerts </dev/null 2>/dev/null \
        | openssl x509 -fingerprint -sha256 -noout 2>/dev/null | cut -d'=' -f2)
    [ -n "${server_fingerprint_before}" ]

    # Keep fetching catalog to ensure server is responsive before cert update
    wait_for_condition 10 0.5 "curl -k --max-time 5 --connect-timeout 3 https://127.0.0.1:${zot_port}/v2/_catalog >/dev/null 2>&1"

    # Update the certificate with a new one
    # This simulates a real-world scenario where certificates are renewed
    generate_self_signed_cert "${zot_cert_file}" "${zot_key_file}" "127.0.0.1" 365

    # Wait until the certificate on disk differs from the original one.
    # The command substitution is escaped (\$(...)) so that it is executed by the
    # eval inside wait_for_condition on EVERY attempt; unescaped, it would be
    # expanded once while building the argument and the poll would test a constant.
    wait_for_condition 10 0.1 "[ \"\$(openssl x509 -fingerprint -sha256 -noout -in \"${zot_cert_file}\" 2>/dev/null | cut -d'=' -f2)\" != \"${cert_fingerprint_before}\" ]"

    # Poll the server until it presents a different fingerprint (same escaping as above).
    # A timeout here is tolerated: the loop below is the authoritative check.
    wait_for_condition 20 0.2 "[ \"\$(openssl s_client -connect 127.0.0.1:${zot_port} -servername 127.0.0.1 -showcerts </dev/null 2>/dev/null | openssl x509 -fingerprint -sha256 -noout 2>/dev/null | cut -d'=' -f2)\" != \"${server_fingerprint_before}\" ]" || true

    # Make several requests to ensure server picks up the new certificate
    # The server should automatically reload it through the GetCertificate callback
    server_fingerprint_after=""
    for i in {1..10}; do
        server_fingerprint_after=$(openssl s_client -connect 127.0.0.1:${zot_port} -servername 127.0.0.1 -showcerts </dev/null 2>/dev/null \
            | openssl x509 -fingerprint -sha256 -noout 2>/dev/null | cut -d'=' -f2)
        run curl -k --max-time 5 --connect-timeout 3 https://127.0.0.1:${zot_port}/v2/_catalog
        if [ "$status" -eq 0 ] && [ -n "${server_fingerprint_after}" ] && \
            [ "${server_fingerprint_before}" != "${server_fingerprint_after}" ]; then
            # Server is using the new certificate
            echo "Request $i succeeded with new certificate" >&3
            break
        fi
        if [ $i -lt 10 ]; then
            sleep 0.2
        fi
    done

    # Final assertions: a certificate was presented, it differs from the original,
    # and the last catalog request succeeded.
    [ -n "${server_fingerprint_after}" ]
    [ "${server_fingerprint_before}" != "${server_fingerprint_after}" ]
    [ "$status" -eq 0 ]
}
@test "TLS works with multiple concurrent connections after certificate reload" {
# Replaces the certificate, then issues five HTTPS requests in parallel and
# requires every one of them to succeed.
zot_port=$(cat ${BATS_FILE_TMPDIR}/zot.port)
local zot_cert_file=${BATS_FILE_TMPDIR}/server.cert
local zot_key_file=${BATS_FILE_TMPDIR}/server.key
# Regenerate certificate to trigger reload
generate_self_signed_cert "${zot_cert_file}" "${zot_key_file}" "127.0.0.1" 365
# Wait until the server answers an HTTPS request again
# (this checks availability only; it does not verify which certificate is served)
wait_for_condition 20 0.2 "curl -k --max-time 5 --connect-timeout 3 https://127.0.0.1:${zot_port}/v2/_catalog >/dev/null 2>&1"
# Test multiple concurrent requests
local failed=0
local pids=()
for i in {1..5}; do
(curl -k --max-time 5 --connect-timeout 3 https://127.0.0.1:${zot_port}/v2/_catalog > /dev/null 2>&1) &
pids+=($!)
done
# Wait for all background requests to complete, counting the ones that failed
for pid in "${pids[@]}"; do
if ! wait "$pid"; then
failed=$((failed + 1))
fi
done
# Every concurrent request must have succeeded
[ "$failed" -eq 0 ]
# One more sequential request confirms the server is still
# serving after the concurrent burst
run curl -k --max-time 5 --connect-timeout 3 https://127.0.0.1:${zot_port}/v2/_catalog
[ "$status" -eq 0 ]
}
@test "certificate reload doesn't require server restart" {
# Replaces the certificate three times and checks after each replacement that
# the original zot process is still alive and still answers HTTPS requests.
zot_port=$(cat ${BATS_FILE_TMPDIR}/zot.port)
local zot_cert_file=${BATS_FILE_TMPDIR}/server.cert
local zot_key_file=${BATS_FILE_TMPDIR}/server.key
# Get initial server PID (first field of zot.pid; the file is not written in
# this file - presumably zot_serve creates it)
local zot_pid=$(cat ${BATS_FILE_TMPDIR}/zot.pid | awk '{print $1}')
# Make a request to establish the server is running
run curl -k --max-time 5 --connect-timeout 3 https://127.0.0.1:${zot_port}/v2/_catalog
[ "$status" -eq 0 ]
# Verify server is still running with same PID
# (kill -0 sends no signal; it only tests that the process exists)
kill -0 ${zot_pid} 2>/dev/null
[ "$?" -eq 0 ]
# Update certificate multiple times
for iteration in {1..3}; do
generate_self_signed_cert "${zot_cert_file}" "${zot_key_file}" "127.0.0.1" 365
# Wait for the server to answer again; a timeout is tolerated here because
# the assertions below are the actual checks
wait_for_condition 20 0.2 "curl -k --max-time 5 --connect-timeout 3 https://127.0.0.1:${zot_port}/v2/_catalog >/dev/null 2>&1" || true
# Server should still be running with the same PID
kill -0 ${zot_pid} 2>/dev/null
[ "$?" -eq 0 ]
# Requests should still work
run curl -k --max-time 5 --connect-timeout 3 https://127.0.0.1:${zot_port}/v2/_catalog
[ "$status" -eq 0 ]
done
}
+6
View File
@@ -448,5 +448,11 @@
"begin": 11510,
"end": 11519
}
},
"blackbox/dynamic_tls.bats": {
"zot": {
"begin": 11520,
"end": 11529
}
}
}