feat(tls): automatically reload tls certs

Fixes issue #3747

Currently, zot requires a restart whenever tls certs change, which can
occur whenever there are tls cert rotation etc.

This PR checks if the tls certs have be modified and if so reloads them
without restarting zot.

Signed-off-by: Ramkumar Chinchani <rchincha.dev@gmail.com>
This commit is contained in:
Ramkumar Chinchani
2026-02-01 12:32:03 -08:00
parent b9aad15ad0
commit 7ead92b82f
4 changed files with 847 additions and 2 deletions
+22 -2
View File
@@ -54,6 +54,7 @@ type Controller struct {
HTPasswd *HTPasswd
HTPasswdWatcher *HTPasswdWatcher
LDAPClient *LDAPClient
CertReloader *CertReloader
taskScheduler *scheduler.Scheduler
Healthz *common.Healthz
// runtime params
@@ -204,6 +205,18 @@ func (c *Controller) Run() error {
tlsConfig := c.Config.CopyTLSConfig()
if tlsConfig != nil && tlsConfig.Key != "" && tlsConfig.Cert != "" {
// Create certificate reloader for automatic TLS certificate updates
certReloader, err := NewCertReloader(tlsConfig.Cert, tlsConfig.Key, c.Log)
if err != nil {
c.Log.Error().Err(err).Str("cert", tlsConfig.Cert).Str("key", tlsConfig.Key).
Msg("failed to load TLS certificates")
return err
}
// Store the CertReloader so it can be closed during shutdown
c.CertReloader = certReloader
// These are the same as the cipher suites in defaultCipherSuitesFIPS for TLS 1.2
// see https://cs.opensource.google/go/go/+/refs/tags/go1.24.9:src/crypto/tls/defaults.go;l=123
// Note: Order doesn't matter - Go 1.17+ automatically orders cipher suites based on
@@ -239,7 +252,8 @@ func (c *Controller) Run() error {
CipherSuites: cipherSuites,
CurvePreferences: curvePreferences,
// PreferServerCipherSuites is ignored in Go 1.17+ - Go automatically orders cipher suites
MinVersion: tls.VersionTLS12,
MinVersion: tls.VersionTLS12,
GetCertificate: certReloader.GetCertificateFunc(),
}
if tlsConfig.CACert != "" {
@@ -266,7 +280,8 @@ func (c *Controller) Run() error {
c.Healthz.Ready()
return server.ServeTLS(listener, tlsConfig.Cert, tlsConfig.Key)
// Pass empty strings to ServeTLS - certificates will be loaded via GetCertificate callback
return server.ServeTLS(listener, "", "")
}
c.Healthz.Ready()
@@ -484,6 +499,11 @@ func (c *Controller) StopBackgroundTasks() {
if c.HTPasswdWatcher != nil {
_ = c.HTPasswdWatcher.Close()
}
// Close CertReloader to prevent resource leaks
if c.CertReloader != nil {
_ = c.CertReloader.Close()
}
}
func (c *Controller) StartBackgroundTasks() {
+280
View File
@@ -0,0 +1,280 @@
package api
import (
"crypto/tls"
"os"
"path/filepath"
"sync"
"time"
"github.com/fsnotify/fsnotify"
"zotregistry.dev/zot/v2/pkg/log"
)
const (
// certCheckCacheDuration is the minimum time between file stat checks when fsnotify is unavailable.
// This prevents excessive file system calls during high TLS handshake rates.
certCheckCacheDuration = 1 * time.Second
)
// CertReloader handles automatic reloading of TLS certificates without downtime.
// It monitors certificate and key files for changes and reloads them dynamically
// using a GetCertificate callback in tls.Config.
type CertReloader struct {
certMu sync.RWMutex
cert *tls.Certificate
certPath string
keyPath string
certMod time.Time
keyMod time.Time
log log.Logger
watcher *fsnotify.Watcher
reloadMu sync.Mutex // Prevents concurrent reload operations
lastCheck time.Time
checkCache time.Duration // Minimum time between file stat checks
stopWatcher chan struct{}
closeOnce sync.Once // Ensures Close() can be called multiple times safely
}
// NewCertReloader creates a new certificate reloader and loads the initial certificate.
// It starts an fsnotify watcher to monitor certificate file changes.
func NewCertReloader(certPath, keyPath string, logger log.Logger) (*CertReloader, error) {
reloader := &CertReloader{
certPath: certPath,
keyPath: keyPath,
log: logger,
checkCache: certCheckCacheDuration,
stopWatcher: make(chan struct{}),
}
if err := reloader.reload(); err != nil {
return nil, err
}
// Start fsnotify watcher in background
if err := reloader.startWatcher(); err != nil {
// Log warning but don't fail - we'll fall back to periodic checking
logger.Warn().Err(err).Msg("failed to start fsnotify watcher, falling back to periodic checking")
}
// NOTE: Do not add initialization that can fail after this point without ensuring
// the watcher is stopped (e.g., by calling Close on error), otherwise the
// watchLoop goroutine started by startWatcher could be leaked.
return reloader, nil
}
// Close stops the file watcher and releases resources.
// This method is safe to call multiple times.
func (cr *CertReloader) Close() error {
var err error
cr.closeOnce.Do(func() {
if cr.stopWatcher != nil {
close(cr.stopWatcher)
}
if cr.watcher != nil {
err = cr.watcher.Close()
}
})
return err
}
// startWatcher initializes the fsnotify watcher for certificate files.
func (cr *CertReloader) startWatcher() error {
watcher, err := fsnotify.NewWatcher()
if err != nil {
return err
}
cr.watcher = watcher
// Watch the directory containing the certificate files
// This is more reliable than watching files directly, especially for atomic file updates
certDir := filepath.Dir(cr.certPath)
keyDir := filepath.Dir(cr.keyPath)
if err := watcher.Add(certDir); err != nil {
watcher.Close()
return err
}
// If cert and key are in different directories, watch both
if certDir != keyDir {
if err := watcher.Add(keyDir); err != nil {
watcher.Close()
return err
}
}
// Start goroutine to handle file system events
go cr.watchLoop()
return nil
}
// watchLoop handles file system events from fsnotify.
func (cr *CertReloader) watchLoop() {
for {
select {
case <-cr.stopWatcher:
return
case event, ok := <-cr.watcher.Events:
if !ok {
return
}
// Check if the event is for our certificate or key files
if event.Name == cr.certPath || event.Name == cr.keyPath {
// Only process write and create events
if event.Op&(fsnotify.Write|fsnotify.Create) != 0 {
cr.log.Debug().Str("file", event.Name).Str("op", event.Op.String()).
Msg("certificate file change detected")
// Try to reload the certificate
cr.tryReload()
}
}
case err, ok := <-cr.watcher.Errors:
if !ok {
return
}
cr.log.Warn().Err(err).Msg("fsnotify watcher error")
}
}
}
// tryReload attempts to reload certificates with proper concurrency control.
func (cr *CertReloader) tryReload() {
// Use mutex to ensure only one reload happens at a time
// This prevents race condition where multiple goroutines detect changes simultaneously
cr.reloadMu.Lock()
defer cr.reloadMu.Unlock()
if err := cr.reload(); err != nil {
cr.log.Warn().Err(err).Str("cert", cr.certPath).Str("key", cr.keyPath).
Msg("failed to reload TLS certificates")
} else {
cr.log.Info().Str("cert", cr.certPath).Str("key", cr.keyPath).
Msg("TLS certificates reloaded successfully")
}
}
// reload loads the certificate and key from disk and updates the internal certificate.
func (cr *CertReloader) reload() error {
// Get file modification times
certInfo, err := os.Stat(cr.certPath)
if err != nil {
return err
}
keyInfo, err := os.Stat(cr.keyPath)
if err != nil {
return err
}
certMod := certInfo.ModTime()
keyMod := keyInfo.ModTime()
// Load the certificate
newCert, err := tls.LoadX509KeyPair(cr.certPath, cr.keyPath)
if err != nil {
return err
}
// Update the certificate and modification times
cr.certMu.Lock()
defer cr.certMu.Unlock()
cr.cert = &newCert
cr.certMod = certMod
cr.keyMod = keyMod
return nil
}
// maybeReload checks if the certificate files have been modified and reloads them if necessary.
// This is used as a fallback when fsnotify is not available or fails.
// Uses time-based caching to avoid excessive file system calls.
func (cr *CertReloader) maybeReload() error {
// Use write lock for both check and update to prevent race conditions
// While less efficient than RLock+Lock upgrade, this ensures only one goroutine
// updates lastCheck at a time, preventing multiple goroutines from bypassing
// the cache check simultaneously. Since we have a 1-second cache, this lock
// is acquired at most once per second, making the performance impact acceptable.
cr.certMu.Lock()
if time.Since(cr.lastCheck) < cr.checkCache {
// Recently checked, skip stat calls
cr.certMu.Unlock()
return nil
}
// Update last check time within the same critical section as the cache check
cr.lastCheck = time.Now()
cr.certMu.Unlock()
// Check cert file modification time
certInfo, err := os.Stat(cr.certPath)
if err != nil {
return err
}
keyInfo, err := os.Stat(cr.keyPath)
if err != nil {
return err
}
certMod := certInfo.ModTime()
keyMod := keyInfo.ModTime()
// Check if files have been modified
cr.certMu.RLock()
needsReload := certMod.After(cr.certMod) || keyMod.After(cr.keyMod)
cr.certMu.RUnlock()
if needsReload {
// Use reloadMu to prevent concurrent reload operations
cr.reloadMu.Lock()
defer cr.reloadMu.Unlock()
// Double-check after acquiring lock - another goroutine might have already reloaded
cr.certMu.RLock()
stillNeedsReload := certMod.After(cr.certMod) || keyMod.After(cr.keyMod)
cr.certMu.RUnlock()
if stillNeedsReload {
if err := cr.reload(); err != nil {
cr.log.Warn().Err(err).Str("cert", cr.certPath).Str("key", cr.keyPath).
Msg("failed to reload TLS certificates")
return err
}
cr.log.Info().Str("cert", cr.certPath).Str("key", cr.keyPath).
Msg("TLS certificates reloaded successfully")
}
}
return nil
}
// GetCertificateFunc returns a function that can be used as tls.Config.GetCertificate.
// This function checks for certificate updates on each TLS handshake and reloads if necessary.
// If fsnotify watcher is active, this only performs time-cached checks as a fallback.
func (cr *CertReloader) GetCertificateFunc() func(*tls.ClientHelloInfo) (*tls.Certificate, error) {
return func(_ *tls.ClientHelloInfo) (*tls.Certificate, error) {
// Try to reload the certificate if it has changed
// This is a fallback mechanism when fsnotify is not available
// Errors are logged but ignored to maintain availability with existing certificate
_ = cr.maybeReload()
cr.certMu.RLock()
defer cr.certMu.RUnlock()
return cr.cert, nil
}
}
+334
View File
@@ -0,0 +1,334 @@
package api_test
import (
"crypto/tls"
"crypto/x509"
"net/http"
"os"
"path/filepath"
"testing"
"time"
. "github.com/smartystreets/goconvey/convey"
"gopkg.in/resty.v1"
"zotregistry.dev/zot/v2/pkg/api"
"zotregistry.dev/zot/v2/pkg/api/config"
"zotregistry.dev/zot/v2/pkg/log"
test "zotregistry.dev/zot/v2/pkg/test/common"
tlsutils "zotregistry.dev/zot/v2/pkg/test/tls"
)
func TestTLSCertReload(t *testing.T) {
Convey("Test automatic TLS certificate reload", t, func() {
// Create temporary directory for certificates
tempDir := t.TempDir()
// Generate initial CA certificate
caOpts := &tlsutils.CertificateOptions{
CommonName: "Test CA",
NotAfter: time.Now().AddDate(1, 0, 0),
}
caCertPEM, caKeyPEM, err := tlsutils.GenerateCACert(caOpts)
So(err, ShouldBeNil)
// Generate initial server certificate
serverCertPath := filepath.Join(tempDir, "server.cert")
serverKeyPath := filepath.Join(tempDir, "server.key")
serverOpts := &tlsutils.CertificateOptions{
Hostname: "127.0.0.1",
CommonName: "Server v1",
NotAfter: time.Now().AddDate(1, 0, 0),
}
err = tlsutils.GenerateServerCertToFile(caCertPEM, caKeyPEM, serverCertPath, serverKeyPath, serverOpts)
So(err, ShouldBeNil)
// Create config with TLS
port := test.GetFreePort()
httpsURL := test.GetSecureBaseURL(port)
conf := config.New()
conf.HTTP.Address = "127.0.0.1"
conf.HTTP.Port = port
conf.HTTP.TLS = &config.TLSConfig{
Cert: serverCertPath,
Key: serverKeyPath,
}
conf.Storage.RootDirectory = t.TempDir()
ctlr := api.NewController(conf)
ctlr.Config.Storage.RootDirectory = t.TempDir()
cm := test.NewControllerManager(ctlr)
cm.StartAndWait(port)
defer cm.StopServer()
// Create client with CA certificate
caCertPool := x509.NewCertPool()
caCertPool.AppendCertsFromPEM(caCertPEM)
httpClient := resty.New().
SetTLSClientConfig(&tls.Config{
RootCAs: caCertPool,
MinVersion: tls.VersionTLS12,
}).
SetRedirectPolicy(resty.FlexibleRedirectPolicy(10))
// Verify initial connection works with HTTPS
resp, err := httpClient.R().Get(httpsURL + "/v2/")
So(err, ShouldBeNil)
So(resp, ShouldNotBeNil)
So(resp.StatusCode(), ShouldEqual, http.StatusOK)
// Wait a moment to ensure file modification time will be different
time.Sleep(2 * time.Second)
// Generate new server certificate with different CommonName
serverOpts2 := &tlsutils.CertificateOptions{
Hostname: "127.0.0.1",
CommonName: "Server v2",
NotAfter: time.Now().AddDate(1, 0, 0),
}
err = tlsutils.GenerateServerCertToFile(caCertPEM, caKeyPEM, serverCertPath, serverKeyPath, serverOpts2)
So(err, ShouldBeNil)
// Wait for certificate to be detected and reloaded
time.Sleep(1 * time.Second)
// Verify connection still works with new certificate
resp2, err := httpClient.R().Get(httpsURL + "/v2/")
So(err, ShouldBeNil)
So(resp2.StatusCode(), ShouldEqual, http.StatusOK)
})
}
func TestCertReloaderDirectly(t *testing.T) {
Convey("Test CertReloader functionality", t, func() {
tempDir := t.TempDir()
// Generate CA certificate
caOpts := &tlsutils.CertificateOptions{
CommonName: "Test CA",
NotAfter: time.Now().AddDate(1, 0, 0),
}
caCertPEM, caKeyPEM, err := tlsutils.GenerateCACert(caOpts)
So(err, ShouldBeNil)
// Generate initial server certificate
certPath := filepath.Join(tempDir, "server.cert")
keyPath := filepath.Join(tempDir, "server.key")
serverOpts := &tlsutils.CertificateOptions{
Hostname: "127.0.0.1",
CommonName: "Initial Cert",
NotAfter: time.Now().AddDate(1, 0, 0),
}
err = tlsutils.GenerateServerCertToFile(caCertPEM, caKeyPEM, certPath, keyPath, serverOpts)
So(err, ShouldBeNil)
Convey("NewCertReloader should load initial certificate", func() {
reloader, err := api.NewCertReloader(certPath, keyPath, log.NewTestLogger())
So(err, ShouldBeNil)
So(reloader, ShouldNotBeNil)
defer reloader.Close()
// Get certificate via callback
getCert := reloader.GetCertificateFunc()
cert, err := getCert(nil)
So(err, ShouldBeNil)
So(cert, ShouldNotBeNil)
})
Convey("GetCertificateFunc should reload when certificate changes", func() {
reloader, err := api.NewCertReloader(certPath, keyPath, log.NewTestLogger())
So(err, ShouldBeNil)
defer reloader.Close()
getCert := reloader.GetCertificateFunc()
initialCert, err := getCert(nil)
So(err, ShouldBeNil)
So(initialCert, ShouldNotBeNil)
// Wait to ensure modification time will be different
time.Sleep(2 * time.Second)
// Generate new certificate
newServerOpts := &tlsutils.CertificateOptions{
Hostname: "127.0.0.1",
CommonName: "Updated Cert",
NotAfter: time.Now().AddDate(1, 0, 0),
}
err = tlsutils.GenerateServerCertToFile(caCertPEM, caKeyPEM, certPath, keyPath, newServerOpts)
So(err, ShouldBeNil)
// Get certificate again - should reload automatically
updatedCert, err := getCert(nil)
So(err, ShouldBeNil)
So(updatedCert, ShouldNotBeNil)
// Certificates should be different (different leaf certificates)
initialLeaf, err := x509.ParseCertificate(initialCert.Certificate[0])
So(err, ShouldBeNil)
updatedLeaf, err := x509.ParseCertificate(updatedCert.Certificate[0])
So(err, ShouldBeNil)
// Common names should be different
So(initialLeaf.Subject.CommonName, ShouldEqual, "Initial Cert")
So(updatedLeaf.Subject.CommonName, ShouldEqual, "Updated Cert")
})
Convey("NewCertReloader should fail with invalid paths", func() {
_, err := api.NewCertReloader("/nonexistent/cert.pem", "/nonexistent/key.pem", log.NewTestLogger())
So(err, ShouldNotBeNil)
})
Convey("GetCertificateFunc should handle missing files gracefully", func() {
reloader, err := api.NewCertReloader(certPath, keyPath, log.NewTestLogger())
So(err, ShouldBeNil)
defer reloader.Close()
getCert := reloader.GetCertificateFunc()
// Delete the certificate file
err = os.Remove(certPath)
So(err, ShouldBeNil)
// Should still return the old certificate (not fail)
cert, err := getCert(nil)
So(err, ShouldBeNil)
So(cert, ShouldNotBeNil)
})
Convey("GetCertificateFunc should handle only cert file modification", func() {
reloader, err := api.NewCertReloader(certPath, keyPath, log.NewTestLogger())
So(err, ShouldBeNil)
defer reloader.Close()
getCert := reloader.GetCertificateFunc()
initialCert, err := getCert(nil)
So(err, ShouldBeNil)
So(initialCert, ShouldNotBeNil)
// Wait to ensure modification time will be different
time.Sleep(2 * time.Second)
// Modify only the cert file (touch it to update mtime)
// Generate new cert with same key
newServerOpts := &tlsutils.CertificateOptions{
Hostname: "127.0.0.1",
CommonName: "Updated Cert Only",
NotAfter: time.Now().AddDate(1, 0, 0),
}
// Read the existing key
keyData, err := os.ReadFile(keyPath)
So(err, ShouldBeNil)
// Generate new cert using the existing key
err = tlsutils.GenerateServerCertToFile(caCertPEM, caKeyPEM, certPath, keyPath, newServerOpts)
So(err, ShouldBeNil)
// Restore original key
err = os.WriteFile(keyPath, keyData, 0o600)
So(err, ShouldBeNil)
// Get certificate again - should reload
updatedCert, err := getCert(nil)
So(err, ShouldBeNil)
So(updatedCert, ShouldNotBeNil)
})
Convey("GetCertificateFunc should handle concurrent access", func() {
reloader, err := api.NewCertReloader(certPath, keyPath, log.NewTestLogger())
So(err, ShouldBeNil)
defer reloader.Close()
getCert := reloader.GetCertificateFunc()
// Launch multiple goroutines to access certificate concurrently
done := make(chan error, 10)
for range 10 {
go func() {
var lastErr error
for range 100 {
cert, err := getCert(nil)
if err != nil || cert == nil {
lastErr = err
break
}
}
done <- lastErr
}()
}
// Wait for all goroutines to complete and check for errors
for range 10 {
err := <-done
So(err, ShouldBeNil)
}
})
Convey("GetCertificateFunc should not reload if files haven't changed", func() {
reloader, err := api.NewCertReloader(certPath, keyPath, log.NewTestLogger())
So(err, ShouldBeNil)
defer reloader.Close()
getCert := reloader.GetCertificateFunc()
// Get certificate multiple times
cert1, err := getCert(nil)
So(err, ShouldBeNil)
So(cert1, ShouldNotBeNil)
cert2, err := getCert(nil)
So(err, ShouldBeNil)
So(cert2, ShouldNotBeNil)
cert3, err := getCert(nil)
So(err, ShouldBeNil)
So(cert3, ShouldNotBeNil)
// All should return the same certificate instance (pointer equality)
So(cert1, ShouldEqual, cert2)
So(cert2, ShouldEqual, cert3)
})
Convey("GetCertificateFunc should reload when key file changes", func() {
reloader, err := api.NewCertReloader(certPath, keyPath, log.NewTestLogger())
So(err, ShouldBeNil)
defer reloader.Close()
getCert := reloader.GetCertificateFunc()
initialCert, err := getCert(nil)
So(err, ShouldBeNil)
So(initialCert, ShouldNotBeNil)
// Wait to ensure modification time will be different
time.Sleep(2 * time.Second)
// Generate completely new cert and key
newServerOpts := &tlsutils.CertificateOptions{
Hostname: "127.0.0.1",
CommonName: "New Key Cert",
NotAfter: time.Now().AddDate(1, 0, 0),
}
err = tlsutils.GenerateServerCertToFile(caCertPEM, caKeyPEM, certPath, keyPath, newServerOpts)
So(err, ShouldBeNil)
// Get certificate again - should reload due to key change
updatedCert, err := getCert(nil)
So(err, ShouldBeNil)
So(updatedCert, ShouldNotBeNil)
// Verify certificates are different
initialLeaf, err := x509.ParseCertificate(initialCert.Certificate[0])
So(err, ShouldBeNil)
updatedLeaf, err := x509.ParseCertificate(updatedCert.Certificate[0])
So(err, ShouldBeNil)
So(initialLeaf.Subject.CommonName, ShouldEqual, "Initial Cert")
So(updatedLeaf.Subject.CommonName, ShouldEqual, "New Key Cert")
})
})
}