mirror of
https://github.com/project-zot/zot.git
synced 2026-06-18 05:28:07 +08:00
5f5c8ed586
- Define certCheckCacheDuration constant for better maintainability - Fix bash test syntax in tls_cert_reload.bats for command existence checks - Fix function call syntax without command substitution Co-authored-by: rchincha <45800463+rchincha@users.noreply.github.com>
267 lines
7.0 KiB
Go
267 lines
7.0 KiB
Go
package api
|
|
|
|
import (
|
|
"crypto/tls"
|
|
"os"
|
|
"path/filepath"
|
|
"sync"
|
|
"time"
|
|
|
|
"github.com/fsnotify/fsnotify"
|
|
|
|
"zotregistry.dev/zot/v2/pkg/log"
|
|
)
|
|
|
|
const (
|
|
// certCheckCacheDuration is the minimum time between file stat checks when fsnotify is unavailable.
|
|
// This prevents excessive file system calls during high TLS handshake rates.
|
|
certCheckCacheDuration = 1 * time.Second
|
|
)
|
|
|
|
// CertReloader handles automatic reloading of TLS certificates without downtime.
|
|
// It monitors certificate and key files for changes and reloads them dynamically
|
|
// using a GetCertificate callback in tls.Config.
|
|
type CertReloader struct {
|
|
certMu sync.RWMutex
|
|
cert *tls.Certificate
|
|
certPath string
|
|
keyPath string
|
|
certMod time.Time
|
|
keyMod time.Time
|
|
log log.Logger
|
|
watcher *fsnotify.Watcher
|
|
reloadMu sync.Mutex // Prevents concurrent reload operations
|
|
lastCheck time.Time
|
|
checkCache time.Duration // Minimum time between file stat checks
|
|
stopWatcher chan struct{}
|
|
}
|
|
|
|
// NewCertReloader creates a new certificate reloader and loads the initial certificate.
|
|
// It starts an fsnotify watcher to monitor certificate file changes.
|
|
func NewCertReloader(certPath, keyPath string, logger log.Logger) (*CertReloader, error) {
|
|
reloader := &CertReloader{
|
|
certPath: certPath,
|
|
keyPath: keyPath,
|
|
log: logger,
|
|
checkCache: certCheckCacheDuration,
|
|
stopWatcher: make(chan struct{}),
|
|
}
|
|
|
|
if err := reloader.reload(); err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
// Start fsnotify watcher in background
|
|
if err := reloader.startWatcher(); err != nil {
|
|
// Log warning but don't fail - we'll fall back to periodic checking
|
|
logger.Warn().Err(err).Msg("failed to start fsnotify watcher, falling back to periodic checking")
|
|
}
|
|
|
|
return reloader, nil
|
|
}
|
|
|
|
// Close stops the file watcher and releases resources.
|
|
func (cr *CertReloader) Close() error {
|
|
if cr.stopWatcher != nil {
|
|
close(cr.stopWatcher)
|
|
}
|
|
|
|
if cr.watcher != nil {
|
|
return cr.watcher.Close()
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
// startWatcher initializes the fsnotify watcher for certificate files.
|
|
func (cr *CertReloader) startWatcher() error {
|
|
watcher, err := fsnotify.NewWatcher()
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
cr.watcher = watcher
|
|
|
|
// Watch the directory containing the certificate files
|
|
// This is more reliable than watching files directly, especially for atomic file updates
|
|
certDir := filepath.Dir(cr.certPath)
|
|
keyDir := filepath.Dir(cr.keyPath)
|
|
|
|
if err := watcher.Add(certDir); err != nil {
|
|
return err
|
|
}
|
|
|
|
// If cert and key are in different directories, watch both
|
|
if certDir != keyDir {
|
|
if err := watcher.Add(keyDir); err != nil {
|
|
return err
|
|
}
|
|
}
|
|
|
|
// Start goroutine to handle file system events
|
|
go cr.watchLoop()
|
|
|
|
return nil
|
|
}
|
|
|
|
// watchLoop handles file system events from fsnotify.
|
|
func (cr *CertReloader) watchLoop() {
|
|
for {
|
|
select {
|
|
case <-cr.stopWatcher:
|
|
return
|
|
case event, ok := <-cr.watcher.Events:
|
|
if !ok {
|
|
return
|
|
}
|
|
|
|
// Check if the event is for our certificate or key files
|
|
if event.Name == cr.certPath || event.Name == cr.keyPath {
|
|
// Only process write and create events
|
|
if event.Op&(fsnotify.Write|fsnotify.Create) != 0 {
|
|
cr.log.Debug().Str("file", event.Name).Str("op", event.Op.String()).
|
|
Msg("certificate file change detected")
|
|
|
|
// Try to reload the certificate
|
|
cr.tryReload()
|
|
}
|
|
}
|
|
case err, ok := <-cr.watcher.Errors:
|
|
if !ok {
|
|
return
|
|
}
|
|
cr.log.Warn().Err(err).Msg("fsnotify watcher error")
|
|
}
|
|
}
|
|
}
|
|
|
|
// tryReload attempts to reload certificates with proper concurrency control.
|
|
func (cr *CertReloader) tryReload() {
|
|
// Use mutex to ensure only one reload happens at a time
|
|
// This prevents race condition where multiple goroutines detect changes simultaneously
|
|
cr.reloadMu.Lock()
|
|
defer cr.reloadMu.Unlock()
|
|
|
|
if err := cr.reload(); err != nil {
|
|
cr.log.Warn().Err(err).Str("cert", cr.certPath).Str("key", cr.keyPath).
|
|
Msg("failed to reload TLS certificates")
|
|
} else {
|
|
cr.log.Info().Str("cert", cr.certPath).Str("key", cr.keyPath).
|
|
Msg("TLS certificates reloaded successfully")
|
|
}
|
|
}
|
|
|
|
// reload loads the certificate and key from disk and updates the internal certificate.
|
|
func (cr *CertReloader) reload() error {
|
|
// Get file modification times
|
|
certInfo, err := os.Stat(cr.certPath)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
keyInfo, err := os.Stat(cr.keyPath)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
certMod := certInfo.ModTime()
|
|
keyMod := keyInfo.ModTime()
|
|
|
|
// Load the certificate
|
|
newCert, err := tls.LoadX509KeyPair(cr.certPath, cr.keyPath)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
// Update the certificate and modification times
|
|
cr.certMu.Lock()
|
|
defer cr.certMu.Unlock()
|
|
|
|
cr.cert = &newCert
|
|
cr.certMod = certMod
|
|
cr.keyMod = keyMod
|
|
|
|
return nil
|
|
}
|
|
|
|
// maybeReload checks if the certificate files have been modified and reloads them if necessary.
|
|
// This is used as a fallback when fsnotify is not available or fails.
|
|
// Uses time-based caching to avoid excessive file system calls.
|
|
func (cr *CertReloader) maybeReload() error {
|
|
// Use time-based cache to reduce frequency of stat calls
|
|
cr.certMu.RLock()
|
|
if time.Since(cr.lastCheck) < cr.checkCache {
|
|
// Recently checked, skip stat calls
|
|
cr.certMu.RUnlock()
|
|
|
|
return nil
|
|
}
|
|
cr.certMu.RUnlock()
|
|
|
|
// Update last check time
|
|
cr.certMu.Lock()
|
|
cr.lastCheck = time.Now()
|
|
cr.certMu.Unlock()
|
|
|
|
// Check cert file modification time
|
|
certInfo, err := os.Stat(cr.certPath)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
keyInfo, err := os.Stat(cr.keyPath)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
certMod := certInfo.ModTime()
|
|
keyMod := keyInfo.ModTime()
|
|
|
|
// Check if files have been modified
|
|
cr.certMu.RLock()
|
|
needsReload := certMod.After(cr.certMod) || keyMod.After(cr.keyMod)
|
|
cr.certMu.RUnlock()
|
|
|
|
if needsReload {
|
|
// Use reloadMu to prevent concurrent reload operations
|
|
cr.reloadMu.Lock()
|
|
defer cr.reloadMu.Unlock()
|
|
|
|
// Double-check after acquiring lock - another goroutine might have already reloaded
|
|
cr.certMu.RLock()
|
|
stillNeedsReload := certMod.After(cr.certMod) || keyMod.After(cr.keyMod)
|
|
cr.certMu.RUnlock()
|
|
|
|
if stillNeedsReload {
|
|
if err := cr.reload(); err != nil {
|
|
cr.log.Warn().Err(err).Str("cert", cr.certPath).Str("key", cr.keyPath).
|
|
Msg("failed to reload TLS certificates")
|
|
|
|
return err
|
|
}
|
|
|
|
cr.log.Info().Str("cert", cr.certPath).Str("key", cr.keyPath).
|
|
Msg("TLS certificates reloaded successfully")
|
|
}
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
// GetCertificateFunc returns a function that can be used as tls.Config.GetCertificate.
|
|
// This function checks for certificate updates on each TLS handshake and reloads if necessary.
|
|
// If fsnotify watcher is active, this only performs time-cached checks as a fallback.
|
|
func (cr *CertReloader) GetCertificateFunc() func(*tls.ClientHelloInfo) (*tls.Certificate, error) {
|
|
return func(_ *tls.ClientHelloInfo) (*tls.Certificate, error) {
|
|
// Try to reload the certificate if it has changed
|
|
// This is a fallback mechanism when fsnotify is not available
|
|
// Errors are logged but ignored to maintain availability with existing certificate
|
|
_ = cr.maybeReload()
|
|
|
|
cr.certMu.RLock()
|
|
defer cr.certMu.RUnlock()
|
|
|
|
return cr.cert, nil
|
|
}
|
|
}
|