feat(tls): automatically reload tls certs

Fixes issue #3747

Currently, zot requires a restart whenever tls certs change, which can
occur whenever there are tls cert rotation etc.

This PR checks if the tls certs have be modified and if so reloads them
without restarting zot.

Signed-off-by: Ramkumar Chinchani <rchincha.dev@gmail.com>
This commit is contained in:
Ramkumar Chinchani
2026-02-01 00:03:07 -08:00
parent b905528b6c
commit b94f3bafee
4 changed files with 817 additions and 2 deletions
+13 -2
View File
@@ -204,6 +204,15 @@ func (c *Controller) Run() error {
tlsConfig := c.Config.CopyTLSConfig()
if tlsConfig != nil && tlsConfig.Key != "" && tlsConfig.Cert != "" {
// Create certificate reloader for automatic TLS certificate updates
certReloader, err := NewCertReloader(tlsConfig.Cert, tlsConfig.Key, c.Log)
if err != nil {
c.Log.Error().Err(err).Str("cert", tlsConfig.Cert).Str("key", tlsConfig.Key).
Msg("failed to load TLS certificates")
return err
}
// These are the same as the cipher suites in defaultCipherSuitesFIPS for TLS 1.2
// see https://cs.opensource.google/go/go/+/refs/tags/go1.24.9:src/crypto/tls/defaults.go;l=123
// Note: Order doesn't matter - Go 1.17+ automatically orders cipher suites based on
@@ -239,7 +248,8 @@ func (c *Controller) Run() error {
CipherSuites: cipherSuites,
CurvePreferences: curvePreferences,
// PreferServerCipherSuites is ignored in Go 1.17+ - Go automatically orders cipher suites
MinVersion: tls.VersionTLS12,
MinVersion: tls.VersionTLS12,
GetCertificate: certReloader.GetCertificateFunc(),
}
if tlsConfig.CACert != "" {
@@ -266,7 +276,8 @@ func (c *Controller) Run() error {
c.Healthz.Ready()
return server.ServeTLS(listener, tlsConfig.Cert, tlsConfig.Key)
// Pass empty strings to ServeTLS - certificates will be loaded via GetCertificate callback
return server.ServeTLS(listener, "", "")
}
c.Healthz.Ready()
+266
View File
@@ -0,0 +1,266 @@
package api
import (
"crypto/tls"
"os"
"path/filepath"
"sync"
"time"
"github.com/fsnotify/fsnotify"
"zotregistry.dev/zot/v2/pkg/log"
)
const (
// certCheckCacheDuration is the minimum time between file stat checks when fsnotify is unavailable.
// This prevents excessive file system calls during high TLS handshake rates.
certCheckCacheDuration = 1 * time.Second
)
// CertReloader handles automatic reloading of TLS certificates without downtime.
// It monitors certificate and key files for changes and reloads them dynamically
// using a GetCertificate callback in tls.Config.
type CertReloader struct {
certMu sync.RWMutex
cert *tls.Certificate
certPath string
keyPath string
certMod time.Time
keyMod time.Time
log log.Logger
watcher *fsnotify.Watcher
reloadMu sync.Mutex // Prevents concurrent reload operations
lastCheck time.Time
checkCache time.Duration // Minimum time between file stat checks
stopWatcher chan struct{}
}
// NewCertReloader creates a new certificate reloader and loads the initial certificate.
// It starts an fsnotify watcher to monitor certificate file changes.
func NewCertReloader(certPath, keyPath string, logger log.Logger) (*CertReloader, error) {
reloader := &CertReloader{
certPath: certPath,
keyPath: keyPath,
log: logger,
checkCache: certCheckCacheDuration,
stopWatcher: make(chan struct{}),
}
if err := reloader.reload(); err != nil {
return nil, err
}
// Start fsnotify watcher in background
if err := reloader.startWatcher(); err != nil {
// Log warning but don't fail - we'll fall back to periodic checking
logger.Warn().Err(err).Msg("failed to start fsnotify watcher, falling back to periodic checking")
}
return reloader, nil
}
// Close stops the file watcher and releases resources.
func (cr *CertReloader) Close() error {
if cr.stopWatcher != nil {
close(cr.stopWatcher)
}
if cr.watcher != nil {
return cr.watcher.Close()
}
return nil
}
// startWatcher initializes the fsnotify watcher for certificate files.
func (cr *CertReloader) startWatcher() error {
watcher, err := fsnotify.NewWatcher()
if err != nil {
return err
}
cr.watcher = watcher
// Watch the directory containing the certificate files
// This is more reliable than watching files directly, especially for atomic file updates
certDir := filepath.Dir(cr.certPath)
keyDir := filepath.Dir(cr.keyPath)
if err := watcher.Add(certDir); err != nil {
return err
}
// If cert and key are in different directories, watch both
if certDir != keyDir {
if err := watcher.Add(keyDir); err != nil {
return err
}
}
// Start goroutine to handle file system events
go cr.watchLoop()
return nil
}
// watchLoop handles file system events from fsnotify.
func (cr *CertReloader) watchLoop() {
for {
select {
case <-cr.stopWatcher:
return
case event, ok := <-cr.watcher.Events:
if !ok {
return
}
// Check if the event is for our certificate or key files
if event.Name == cr.certPath || event.Name == cr.keyPath {
// Only process write and create events
if event.Op&(fsnotify.Write|fsnotify.Create) != 0 {
cr.log.Debug().Str("file", event.Name).Str("op", event.Op.String()).
Msg("certificate file change detected")
// Try to reload the certificate
cr.tryReload()
}
}
case err, ok := <-cr.watcher.Errors:
if !ok {
return
}
cr.log.Warn().Err(err).Msg("fsnotify watcher error")
}
}
}
// tryReload attempts to reload certificates with proper concurrency control.
func (cr *CertReloader) tryReload() {
// Use mutex to ensure only one reload happens at a time
// This prevents race condition where multiple goroutines detect changes simultaneously
cr.reloadMu.Lock()
defer cr.reloadMu.Unlock()
if err := cr.reload(); err != nil {
cr.log.Warn().Err(err).Str("cert", cr.certPath).Str("key", cr.keyPath).
Msg("failed to reload TLS certificates")
} else {
cr.log.Info().Str("cert", cr.certPath).Str("key", cr.keyPath).
Msg("TLS certificates reloaded successfully")
}
}
// reload loads the certificate and key from disk and updates the internal certificate.
func (cr *CertReloader) reload() error {
// Get file modification times
certInfo, err := os.Stat(cr.certPath)
if err != nil {
return err
}
keyInfo, err := os.Stat(cr.keyPath)
if err != nil {
return err
}
certMod := certInfo.ModTime()
keyMod := keyInfo.ModTime()
// Load the certificate
newCert, err := tls.LoadX509KeyPair(cr.certPath, cr.keyPath)
if err != nil {
return err
}
// Update the certificate and modification times
cr.certMu.Lock()
defer cr.certMu.Unlock()
cr.cert = &newCert
cr.certMod = certMod
cr.keyMod = keyMod
return nil
}
// maybeReload checks if the certificate files have been modified and reloads them if necessary.
// This is used as a fallback when fsnotify is not available or fails.
// Uses time-based caching to avoid excessive file system calls.
func (cr *CertReloader) maybeReload() error {
// Use time-based cache to reduce frequency of stat calls
cr.certMu.RLock()
if time.Since(cr.lastCheck) < cr.checkCache {
// Recently checked, skip stat calls
cr.certMu.RUnlock()
return nil
}
cr.certMu.RUnlock()
// Update last check time
cr.certMu.Lock()
cr.lastCheck = time.Now()
cr.certMu.Unlock()
// Check cert file modification time
certInfo, err := os.Stat(cr.certPath)
if err != nil {
return err
}
keyInfo, err := os.Stat(cr.keyPath)
if err != nil {
return err
}
certMod := certInfo.ModTime()
keyMod := keyInfo.ModTime()
// Check if files have been modified
cr.certMu.RLock()
needsReload := certMod.After(cr.certMod) || keyMod.After(cr.keyMod)
cr.certMu.RUnlock()
if needsReload {
// Use reloadMu to prevent concurrent reload operations
cr.reloadMu.Lock()
defer cr.reloadMu.Unlock()
// Double-check after acquiring lock - another goroutine might have already reloaded
cr.certMu.RLock()
stillNeedsReload := certMod.After(cr.certMod) || keyMod.After(cr.keyMod)
cr.certMu.RUnlock()
if stillNeedsReload {
if err := cr.reload(); err != nil {
cr.log.Warn().Err(err).Str("cert", cr.certPath).Str("key", cr.keyPath).
Msg("failed to reload TLS certificates")
return err
}
cr.log.Info().Str("cert", cr.certPath).Str("key", cr.keyPath).
Msg("TLS certificates reloaded successfully")
}
}
return nil
}
// GetCertificateFunc returns a function that can be used as tls.Config.GetCertificate.
// This function checks for certificate updates on each TLS handshake and reloads if necessary.
// If fsnotify watcher is active, this only performs time-cached checks as a fallback.
func (cr *CertReloader) GetCertificateFunc() func(*tls.ClientHelloInfo) (*tls.Certificate, error) {
return func(_ *tls.ClientHelloInfo) (*tls.Certificate, error) {
// Try to reload the certificate if it has changed
// This is a fallback mechanism when fsnotify is not available
// Errors are logged but ignored to maintain availability with existing certificate
_ = cr.maybeReload()
cr.certMu.RLock()
defer cr.certMu.RUnlock()
return cr.cert, nil
}
}
+327
View File
@@ -0,0 +1,327 @@
package api_test
import (
"crypto/tls"
"crypto/x509"
"net/http"
"os"
"path/filepath"
"testing"
"time"
. "github.com/smartystreets/goconvey/convey"
"gopkg.in/resty.v1"
"zotregistry.dev/zot/v2/pkg/api"
"zotregistry.dev/zot/v2/pkg/api/config"
"zotregistry.dev/zot/v2/pkg/log"
test "zotregistry.dev/zot/v2/pkg/test/common"
tlsutils "zotregistry.dev/zot/v2/pkg/test/tls"
)
func TestTLSCertReload(t *testing.T) {
Convey("Test automatic TLS certificate reload", t, func() {
// Create temporary directory for certificates
tempDir := t.TempDir()
// Generate initial CA certificate
caOpts := &tlsutils.CertificateOptions{
CommonName: "Test CA",
NotAfter: time.Now().AddDate(1, 0, 0),
}
caCertPEM, caKeyPEM, err := tlsutils.GenerateCACert(caOpts)
So(err, ShouldBeNil)
// Generate initial server certificate
serverCertPath := filepath.Join(tempDir, "server.cert")
serverKeyPath := filepath.Join(tempDir, "server.key")
serverOpts := &tlsutils.CertificateOptions{
Hostname: "127.0.0.1",
CommonName: "Server v1",
NotAfter: time.Now().AddDate(1, 0, 0),
}
err = tlsutils.GenerateServerCertToFile(caCertPEM, caKeyPEM, serverCertPath, serverKeyPath, serverOpts)
So(err, ShouldBeNil)
// Create config with TLS
port := test.GetFreePort()
httpsURL := test.GetSecureBaseURL(port)
conf := config.New()
conf.HTTP.Address = "127.0.0.1"
conf.HTTP.Port = port
conf.HTTP.TLS = &config.TLSConfig{
Cert: serverCertPath,
Key: serverKeyPath,
}
conf.Storage.RootDirectory = t.TempDir()
ctlr := api.NewController(conf)
ctlr.Config.Storage.RootDirectory = t.TempDir()
cm := test.NewControllerManager(ctlr)
cm.StartAndWait(port)
defer cm.StopServer()
// Create client with CA certificate
caCertPool := x509.NewCertPool()
caCertPool.AppendCertsFromPEM(caCertPEM)
httpClient := resty.New().
SetTLSClientConfig(&tls.Config{
RootCAs: caCertPool,
MinVersion: tls.VersionTLS12,
}).
SetRedirectPolicy(resty.FlexibleRedirectPolicy(10))
// Verify initial connection works with HTTPS
resp, err := httpClient.R().Get(httpsURL + "/v2/")
So(err, ShouldBeNil)
So(resp, ShouldNotBeNil)
So(resp.StatusCode(), ShouldEqual, http.StatusOK)
// Wait a moment to ensure file modification time will be different
time.Sleep(2 * time.Second)
// Generate new server certificate with different CommonName
serverOpts2 := &tlsutils.CertificateOptions{
Hostname: "127.0.0.1",
CommonName: "Server v2",
NotAfter: time.Now().AddDate(1, 0, 0),
}
err = tlsutils.GenerateServerCertToFile(caCertPEM, caKeyPEM, serverCertPath, serverKeyPath, serverOpts2)
So(err, ShouldBeNil)
// Wait for certificate to be detected and reloaded
time.Sleep(1 * time.Second)
// Verify connection still works with new certificate
resp2, err := httpClient.R().Get(httpsURL + "/v2/")
So(err, ShouldBeNil)
So(resp2.StatusCode(), ShouldEqual, http.StatusOK)
})
}
func TestCertReloaderDirectly(t *testing.T) {
Convey("Test CertReloader functionality", t, func() {
tempDir := t.TempDir()
// Generate CA certificate
caOpts := &tlsutils.CertificateOptions{
CommonName: "Test CA",
NotAfter: time.Now().AddDate(1, 0, 0),
}
caCertPEM, caKeyPEM, err := tlsutils.GenerateCACert(caOpts)
So(err, ShouldBeNil)
// Generate initial server certificate
certPath := filepath.Join(tempDir, "server.cert")
keyPath := filepath.Join(tempDir, "server.key")
serverOpts := &tlsutils.CertificateOptions{
Hostname: "127.0.0.1",
CommonName: "Initial Cert",
NotAfter: time.Now().AddDate(1, 0, 0),
}
err = tlsutils.GenerateServerCertToFile(caCertPEM, caKeyPEM, certPath, keyPath, serverOpts)
So(err, ShouldBeNil)
Convey("NewCertReloader should load initial certificate", func() {
reloader, err := api.NewCertReloader(certPath, keyPath, log.NewTestLogger())
So(err, ShouldBeNil)
So(reloader, ShouldNotBeNil)
// Get certificate via callback
getCert := reloader.GetCertificateFunc()
cert, err := getCert(nil)
So(err, ShouldBeNil)
So(cert, ShouldNotBeNil)
})
Convey("GetCertificateFunc should reload when certificate changes", func() {
reloader, err := api.NewCertReloader(certPath, keyPath, log.NewTestLogger())
So(err, ShouldBeNil)
getCert := reloader.GetCertificateFunc()
initialCert, err := getCert(nil)
So(err, ShouldBeNil)
So(initialCert, ShouldNotBeNil)
// Wait to ensure modification time will be different
time.Sleep(2 * time.Second)
// Generate new certificate
newServerOpts := &tlsutils.CertificateOptions{
Hostname: "127.0.0.1",
CommonName: "Updated Cert",
NotAfter: time.Now().AddDate(1, 0, 0),
}
err = tlsutils.GenerateServerCertToFile(caCertPEM, caKeyPEM, certPath, keyPath, newServerOpts)
So(err, ShouldBeNil)
// Get certificate again - should reload automatically
updatedCert, err := getCert(nil)
So(err, ShouldBeNil)
So(updatedCert, ShouldNotBeNil)
// Certificates should be different (different leaf certificates)
initialLeaf, err := x509.ParseCertificate(initialCert.Certificate[0])
So(err, ShouldBeNil)
updatedLeaf, err := x509.ParseCertificate(updatedCert.Certificate[0])
So(err, ShouldBeNil)
// Common names should be different
So(initialLeaf.Subject.CommonName, ShouldEqual, "Initial Cert")
So(updatedLeaf.Subject.CommonName, ShouldEqual, "Updated Cert")
})
Convey("NewCertReloader should fail with invalid paths", func() {
_, err := api.NewCertReloader("/nonexistent/cert.pem", "/nonexistent/key.pem", log.NewTestLogger())
So(err, ShouldNotBeNil)
})
Convey("GetCertificateFunc should handle missing files gracefully", func() {
reloader, err := api.NewCertReloader(certPath, keyPath, log.NewTestLogger())
So(err, ShouldBeNil)
getCert := reloader.GetCertificateFunc()
// Delete the certificate file
err = os.Remove(certPath)
So(err, ShouldBeNil)
// Should still return the old certificate (not fail)
cert, err := getCert(nil)
So(err, ShouldBeNil)
So(cert, ShouldNotBeNil)
})
Convey("GetCertificateFunc should handle only cert file modification", func() {
reloader, err := api.NewCertReloader(certPath, keyPath, log.NewTestLogger())
So(err, ShouldBeNil)
getCert := reloader.GetCertificateFunc()
initialCert, err := getCert(nil)
So(err, ShouldBeNil)
So(initialCert, ShouldNotBeNil)
// Wait to ensure modification time will be different
time.Sleep(2 * time.Second)
// Modify only the cert file (touch it to update mtime)
// Generate new cert with same key
newServerOpts := &tlsutils.CertificateOptions{
Hostname: "127.0.0.1",
CommonName: "Updated Cert Only",
NotAfter: time.Now().AddDate(1, 0, 0),
}
// Read the existing key
keyData, err := os.ReadFile(keyPath)
So(err, ShouldBeNil)
// Generate new cert using the existing key
err = tlsutils.GenerateServerCertToFile(caCertPEM, caKeyPEM, certPath, keyPath, newServerOpts)
So(err, ShouldBeNil)
// Restore original key
err = os.WriteFile(keyPath, keyData, 0o600)
So(err, ShouldBeNil)
// Get certificate again - should reload
updatedCert, err := getCert(nil)
So(err, ShouldBeNil)
So(updatedCert, ShouldNotBeNil)
})
Convey("GetCertificateFunc should handle concurrent access", func() {
reloader, err := api.NewCertReloader(certPath, keyPath, log.NewTestLogger())
So(err, ShouldBeNil)
getCert := reloader.GetCertificateFunc()
// Launch multiple goroutines to access certificate concurrently
done := make(chan error, 10)
for range 10 {
go func() {
var lastErr error
for range 100 {
cert, err := getCert(nil)
if err != nil || cert == nil {
lastErr = err
break
}
}
done <- lastErr
}()
}
// Wait for all goroutines to complete and check for errors
for range 10 {
err := <-done
So(err, ShouldBeNil)
}
})
Convey("GetCertificateFunc should not reload if files haven't changed", func() {
reloader, err := api.NewCertReloader(certPath, keyPath, log.NewTestLogger())
So(err, ShouldBeNil)
getCert := reloader.GetCertificateFunc()
// Get certificate multiple times
cert1, err := getCert(nil)
So(err, ShouldBeNil)
So(cert1, ShouldNotBeNil)
cert2, err := getCert(nil)
So(err, ShouldBeNil)
So(cert2, ShouldNotBeNil)
cert3, err := getCert(nil)
So(err, ShouldBeNil)
So(cert3, ShouldNotBeNil)
// All should return the same certificate instance (pointer equality)
So(cert1, ShouldEqual, cert2)
So(cert2, ShouldEqual, cert3)
})
Convey("GetCertificateFunc should reload when key file changes", func() {
reloader, err := api.NewCertReloader(certPath, keyPath, log.NewTestLogger())
So(err, ShouldBeNil)
getCert := reloader.GetCertificateFunc()
initialCert, err := getCert(nil)
So(err, ShouldBeNil)
So(initialCert, ShouldNotBeNil)
// Wait to ensure modification time will be different
time.Sleep(2 * time.Second)
// Generate completely new cert and key
newServerOpts := &tlsutils.CertificateOptions{
Hostname: "127.0.0.1",
CommonName: "New Key Cert",
NotAfter: time.Now().AddDate(1, 0, 0),
}
err = tlsutils.GenerateServerCertToFile(caCertPEM, caKeyPEM, certPath, keyPath, newServerOpts)
So(err, ShouldBeNil)
// Get certificate again - should reload due to key change
updatedCert, err := getCert(nil)
So(err, ShouldBeNil)
So(updatedCert, ShouldNotBeNil)
// Verify certificates are different
initialLeaf, err := x509.ParseCertificate(initialCert.Certificate[0])
So(err, ShouldBeNil)
updatedLeaf, err := x509.ParseCertificate(updatedCert.Certificate[0])
So(err, ShouldBeNil)
So(initialLeaf.Subject.CommonName, ShouldEqual, "Initial Cert")
So(updatedLeaf.Subject.CommonName, ShouldEqual, "New Key Cert")
})
})
}