mirror of
https://github.com/project-zot/zot.git
synced 2026-06-18 05:28:07 +08:00
feat(tls): automatically reload tls certs
Fixes issue #3747 Currently, zot requires a restart whenever tls certs change, which can occur whenever there are tls cert rotation etc. This PR checks if the tls certs have be modified and if so reloads them without restarting zot. Signed-off-by: Ramkumar Chinchani <rchincha.dev@gmail.com>
This commit is contained in:
+22
-2
@@ -54,6 +54,7 @@ type Controller struct {
|
||||
HTPasswd *HTPasswd
|
||||
HTPasswdWatcher *HTPasswdWatcher
|
||||
LDAPClient *LDAPClient
|
||||
CertReloader *CertReloader
|
||||
taskScheduler *scheduler.Scheduler
|
||||
Healthz *common.Healthz
|
||||
// runtime params
|
||||
@@ -204,6 +205,18 @@ func (c *Controller) Run() error {
|
||||
|
||||
tlsConfig := c.Config.CopyTLSConfig()
|
||||
if tlsConfig != nil && tlsConfig.Key != "" && tlsConfig.Cert != "" {
|
||||
// Create certificate reloader for automatic TLS certificate updates
|
||||
certReloader, err := NewCertReloader(tlsConfig.Cert, tlsConfig.Key, c.Log)
|
||||
if err != nil {
|
||||
c.Log.Error().Err(err).Str("cert", tlsConfig.Cert).Str("key", tlsConfig.Key).
|
||||
Msg("failed to load TLS certificates")
|
||||
|
||||
return err
|
||||
}
|
||||
|
||||
// Store the CertReloader so it can be closed during shutdown
|
||||
c.CertReloader = certReloader
|
||||
|
||||
// These are the same as the cipher suites in defaultCipherSuitesFIPS for TLS 1.2
|
||||
// see https://cs.opensource.google/go/go/+/refs/tags/go1.24.9:src/crypto/tls/defaults.go;l=123
|
||||
// Note: Order doesn't matter - Go 1.17+ automatically orders cipher suites based on
|
||||
@@ -239,7 +252,8 @@ func (c *Controller) Run() error {
|
||||
CipherSuites: cipherSuites,
|
||||
CurvePreferences: curvePreferences,
|
||||
// PreferServerCipherSuites is ignored in Go 1.17+ - Go automatically orders cipher suites
|
||||
MinVersion: tls.VersionTLS12,
|
||||
MinVersion: tls.VersionTLS12,
|
||||
GetCertificate: certReloader.GetCertificateFunc(),
|
||||
}
|
||||
|
||||
if tlsConfig.CACert != "" {
|
||||
@@ -266,7 +280,8 @@ func (c *Controller) Run() error {
|
||||
|
||||
c.Healthz.Ready()
|
||||
|
||||
return server.ServeTLS(listener, tlsConfig.Cert, tlsConfig.Key)
|
||||
// Pass empty strings to ServeTLS - certificates will be loaded via GetCertificate callback
|
||||
return server.ServeTLS(listener, "", "")
|
||||
}
|
||||
|
||||
c.Healthz.Ready()
|
||||
@@ -484,6 +499,11 @@ func (c *Controller) StopBackgroundTasks() {
|
||||
if c.HTPasswdWatcher != nil {
|
||||
_ = c.HTPasswdWatcher.Close()
|
||||
}
|
||||
|
||||
// Close CertReloader to prevent resource leaks
|
||||
if c.CertReloader != nil {
|
||||
_ = c.CertReloader.Close()
|
||||
}
|
||||
}
|
||||
|
||||
func (c *Controller) StartBackgroundTasks() {
|
||||
|
||||
@@ -0,0 +1,280 @@
|
||||
package api
|
||||
|
||||
import (
|
||||
"crypto/tls"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/fsnotify/fsnotify"
|
||||
|
||||
"zotregistry.dev/zot/v2/pkg/log"
|
||||
)
|
||||
|
||||
const (
|
||||
// certCheckCacheDuration is the minimum time between file stat checks when fsnotify is unavailable.
|
||||
// This prevents excessive file system calls during high TLS handshake rates.
|
||||
certCheckCacheDuration = 1 * time.Second
|
||||
)
|
||||
|
||||
// CertReloader handles automatic reloading of TLS certificates without downtime.
|
||||
// It monitors certificate and key files for changes and reloads them dynamically
|
||||
// using a GetCertificate callback in tls.Config.
|
||||
type CertReloader struct {
|
||||
certMu sync.RWMutex
|
||||
cert *tls.Certificate
|
||||
certPath string
|
||||
keyPath string
|
||||
certMod time.Time
|
||||
keyMod time.Time
|
||||
log log.Logger
|
||||
watcher *fsnotify.Watcher
|
||||
reloadMu sync.Mutex // Prevents concurrent reload operations
|
||||
lastCheck time.Time
|
||||
checkCache time.Duration // Minimum time between file stat checks
|
||||
stopWatcher chan struct{}
|
||||
closeOnce sync.Once // Ensures Close() can be called multiple times safely
|
||||
}
|
||||
|
||||
// NewCertReloader creates a new certificate reloader and loads the initial certificate.
|
||||
// It starts an fsnotify watcher to monitor certificate file changes.
|
||||
func NewCertReloader(certPath, keyPath string, logger log.Logger) (*CertReloader, error) {
|
||||
reloader := &CertReloader{
|
||||
certPath: certPath,
|
||||
keyPath: keyPath,
|
||||
log: logger,
|
||||
checkCache: certCheckCacheDuration,
|
||||
stopWatcher: make(chan struct{}),
|
||||
}
|
||||
|
||||
if err := reloader.reload(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Start fsnotify watcher in background
|
||||
if err := reloader.startWatcher(); err != nil {
|
||||
// Log warning but don't fail - we'll fall back to periodic checking
|
||||
logger.Warn().Err(err).Msg("failed to start fsnotify watcher, falling back to periodic checking")
|
||||
}
|
||||
|
||||
// NOTE: Do not add initialization that can fail after this point without ensuring
|
||||
// the watcher is stopped (e.g., by calling Close on error), otherwise the
|
||||
// watchLoop goroutine started by startWatcher could be leaked.
|
||||
|
||||
return reloader, nil
|
||||
}
|
||||
|
||||
// Close stops the file watcher and releases resources.
|
||||
// This method is safe to call multiple times.
|
||||
func (cr *CertReloader) Close() error {
|
||||
var err error
|
||||
cr.closeOnce.Do(func() {
|
||||
if cr.stopWatcher != nil {
|
||||
close(cr.stopWatcher)
|
||||
}
|
||||
|
||||
if cr.watcher != nil {
|
||||
err = cr.watcher.Close()
|
||||
}
|
||||
})
|
||||
|
||||
return err
|
||||
}
|
||||
|
||||
// startWatcher initializes the fsnotify watcher for certificate files.
|
||||
func (cr *CertReloader) startWatcher() error {
|
||||
watcher, err := fsnotify.NewWatcher()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
cr.watcher = watcher
|
||||
|
||||
// Watch the directory containing the certificate files
|
||||
// This is more reliable than watching files directly, especially for atomic file updates
|
||||
certDir := filepath.Dir(cr.certPath)
|
||||
keyDir := filepath.Dir(cr.keyPath)
|
||||
|
||||
if err := watcher.Add(certDir); err != nil {
|
||||
watcher.Close()
|
||||
|
||||
return err
|
||||
}
|
||||
|
||||
// If cert and key are in different directories, watch both
|
||||
if certDir != keyDir {
|
||||
if err := watcher.Add(keyDir); err != nil {
|
||||
watcher.Close()
|
||||
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
// Start goroutine to handle file system events
|
||||
go cr.watchLoop()
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// watchLoop handles file system events from fsnotify.
|
||||
func (cr *CertReloader) watchLoop() {
|
||||
for {
|
||||
select {
|
||||
case <-cr.stopWatcher:
|
||||
return
|
||||
case event, ok := <-cr.watcher.Events:
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
|
||||
// Check if the event is for our certificate or key files
|
||||
if event.Name == cr.certPath || event.Name == cr.keyPath {
|
||||
// Only process write and create events
|
||||
if event.Op&(fsnotify.Write|fsnotify.Create) != 0 {
|
||||
cr.log.Debug().Str("file", event.Name).Str("op", event.Op.String()).
|
||||
Msg("certificate file change detected")
|
||||
|
||||
// Try to reload the certificate
|
||||
cr.tryReload()
|
||||
}
|
||||
}
|
||||
case err, ok := <-cr.watcher.Errors:
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
cr.log.Warn().Err(err).Msg("fsnotify watcher error")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// tryReload attempts to reload certificates with proper concurrency control.
|
||||
func (cr *CertReloader) tryReload() {
|
||||
// Use mutex to ensure only one reload happens at a time
|
||||
// This prevents race condition where multiple goroutines detect changes simultaneously
|
||||
cr.reloadMu.Lock()
|
||||
defer cr.reloadMu.Unlock()
|
||||
|
||||
if err := cr.reload(); err != nil {
|
||||
cr.log.Warn().Err(err).Str("cert", cr.certPath).Str("key", cr.keyPath).
|
||||
Msg("failed to reload TLS certificates")
|
||||
} else {
|
||||
cr.log.Info().Str("cert", cr.certPath).Str("key", cr.keyPath).
|
||||
Msg("TLS certificates reloaded successfully")
|
||||
}
|
||||
}
|
||||
|
||||
// reload loads the certificate and key from disk and updates the internal certificate.
|
||||
func (cr *CertReloader) reload() error {
|
||||
// Get file modification times
|
||||
certInfo, err := os.Stat(cr.certPath)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
keyInfo, err := os.Stat(cr.keyPath)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
certMod := certInfo.ModTime()
|
||||
keyMod := keyInfo.ModTime()
|
||||
|
||||
// Load the certificate
|
||||
newCert, err := tls.LoadX509KeyPair(cr.certPath, cr.keyPath)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Update the certificate and modification times
|
||||
cr.certMu.Lock()
|
||||
defer cr.certMu.Unlock()
|
||||
|
||||
cr.cert = &newCert
|
||||
cr.certMod = certMod
|
||||
cr.keyMod = keyMod
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// maybeReload checks if the certificate files have been modified and reloads them if necessary.
|
||||
// This is used as a fallback when fsnotify is not available or fails.
|
||||
// Uses time-based caching to avoid excessive file system calls.
|
||||
func (cr *CertReloader) maybeReload() error {
|
||||
// Use write lock for both check and update to prevent race conditions
|
||||
// While less efficient than RLock+Lock upgrade, this ensures only one goroutine
|
||||
// updates lastCheck at a time, preventing multiple goroutines from bypassing
|
||||
// the cache check simultaneously. Since we have a 1-second cache, this lock
|
||||
// is acquired at most once per second, making the performance impact acceptable.
|
||||
cr.certMu.Lock()
|
||||
if time.Since(cr.lastCheck) < cr.checkCache {
|
||||
// Recently checked, skip stat calls
|
||||
cr.certMu.Unlock()
|
||||
|
||||
return nil
|
||||
}
|
||||
// Update last check time within the same critical section as the cache check
|
||||
cr.lastCheck = time.Now()
|
||||
cr.certMu.Unlock()
|
||||
|
||||
// Check cert file modification time
|
||||
certInfo, err := os.Stat(cr.certPath)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
keyInfo, err := os.Stat(cr.keyPath)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
certMod := certInfo.ModTime()
|
||||
keyMod := keyInfo.ModTime()
|
||||
|
||||
// Check if files have been modified
|
||||
cr.certMu.RLock()
|
||||
needsReload := certMod.After(cr.certMod) || keyMod.After(cr.keyMod)
|
||||
cr.certMu.RUnlock()
|
||||
|
||||
if needsReload {
|
||||
// Use reloadMu to prevent concurrent reload operations
|
||||
cr.reloadMu.Lock()
|
||||
defer cr.reloadMu.Unlock()
|
||||
|
||||
// Double-check after acquiring lock - another goroutine might have already reloaded
|
||||
cr.certMu.RLock()
|
||||
stillNeedsReload := certMod.After(cr.certMod) || keyMod.After(cr.keyMod)
|
||||
cr.certMu.RUnlock()
|
||||
|
||||
if stillNeedsReload {
|
||||
if err := cr.reload(); err != nil {
|
||||
cr.log.Warn().Err(err).Str("cert", cr.certPath).Str("key", cr.keyPath).
|
||||
Msg("failed to reload TLS certificates")
|
||||
|
||||
return err
|
||||
}
|
||||
|
||||
cr.log.Info().Str("cert", cr.certPath).Str("key", cr.keyPath).
|
||||
Msg("TLS certificates reloaded successfully")
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// GetCertificateFunc returns a function that can be used as tls.Config.GetCertificate.
|
||||
// This function checks for certificate updates on each TLS handshake and reloads if necessary.
|
||||
// If fsnotify watcher is active, this only performs time-cached checks as a fallback.
|
||||
func (cr *CertReloader) GetCertificateFunc() func(*tls.ClientHelloInfo) (*tls.Certificate, error) {
|
||||
return func(_ *tls.ClientHelloInfo) (*tls.Certificate, error) {
|
||||
// Try to reload the certificate if it has changed
|
||||
// This is a fallback mechanism when fsnotify is not available
|
||||
// Errors are logged but ignored to maintain availability with existing certificate
|
||||
_ = cr.maybeReload()
|
||||
|
||||
cr.certMu.RLock()
|
||||
defer cr.certMu.RUnlock()
|
||||
|
||||
return cr.cert, nil
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,334 @@
|
||||
package api_test
|
||||
|
||||
import (
|
||||
"crypto/tls"
|
||||
"crypto/x509"
|
||||
"net/http"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
. "github.com/smartystreets/goconvey/convey"
|
||||
"gopkg.in/resty.v1"
|
||||
|
||||
"zotregistry.dev/zot/v2/pkg/api"
|
||||
"zotregistry.dev/zot/v2/pkg/api/config"
|
||||
"zotregistry.dev/zot/v2/pkg/log"
|
||||
test "zotregistry.dev/zot/v2/pkg/test/common"
|
||||
tlsutils "zotregistry.dev/zot/v2/pkg/test/tls"
|
||||
)
|
||||
|
||||
func TestTLSCertReload(t *testing.T) {
|
||||
Convey("Test automatic TLS certificate reload", t, func() {
|
||||
// Create temporary directory for certificates
|
||||
tempDir := t.TempDir()
|
||||
|
||||
// Generate initial CA certificate
|
||||
caOpts := &tlsutils.CertificateOptions{
|
||||
CommonName: "Test CA",
|
||||
NotAfter: time.Now().AddDate(1, 0, 0),
|
||||
}
|
||||
caCertPEM, caKeyPEM, err := tlsutils.GenerateCACert(caOpts)
|
||||
So(err, ShouldBeNil)
|
||||
|
||||
// Generate initial server certificate
|
||||
serverCertPath := filepath.Join(tempDir, "server.cert")
|
||||
serverKeyPath := filepath.Join(tempDir, "server.key")
|
||||
serverOpts := &tlsutils.CertificateOptions{
|
||||
Hostname: "127.0.0.1",
|
||||
CommonName: "Server v1",
|
||||
NotAfter: time.Now().AddDate(1, 0, 0),
|
||||
}
|
||||
err = tlsutils.GenerateServerCertToFile(caCertPEM, caKeyPEM, serverCertPath, serverKeyPath, serverOpts)
|
||||
So(err, ShouldBeNil)
|
||||
|
||||
// Create config with TLS
|
||||
port := test.GetFreePort()
|
||||
httpsURL := test.GetSecureBaseURL(port)
|
||||
|
||||
conf := config.New()
|
||||
conf.HTTP.Address = "127.0.0.1"
|
||||
conf.HTTP.Port = port
|
||||
conf.HTTP.TLS = &config.TLSConfig{
|
||||
Cert: serverCertPath,
|
||||
Key: serverKeyPath,
|
||||
}
|
||||
conf.Storage.RootDirectory = t.TempDir()
|
||||
|
||||
ctlr := api.NewController(conf)
|
||||
ctlr.Config.Storage.RootDirectory = t.TempDir()
|
||||
|
||||
cm := test.NewControllerManager(ctlr)
|
||||
cm.StartAndWait(port)
|
||||
defer cm.StopServer()
|
||||
|
||||
// Create client with CA certificate
|
||||
caCertPool := x509.NewCertPool()
|
||||
caCertPool.AppendCertsFromPEM(caCertPEM)
|
||||
|
||||
httpClient := resty.New().
|
||||
SetTLSClientConfig(&tls.Config{
|
||||
RootCAs: caCertPool,
|
||||
MinVersion: tls.VersionTLS12,
|
||||
}).
|
||||
SetRedirectPolicy(resty.FlexibleRedirectPolicy(10))
|
||||
|
||||
// Verify initial connection works with HTTPS
|
||||
resp, err := httpClient.R().Get(httpsURL + "/v2/")
|
||||
So(err, ShouldBeNil)
|
||||
So(resp, ShouldNotBeNil)
|
||||
So(resp.StatusCode(), ShouldEqual, http.StatusOK)
|
||||
|
||||
// Wait a moment to ensure file modification time will be different
|
||||
time.Sleep(2 * time.Second)
|
||||
|
||||
// Generate new server certificate with different CommonName
|
||||
serverOpts2 := &tlsutils.CertificateOptions{
|
||||
Hostname: "127.0.0.1",
|
||||
CommonName: "Server v2",
|
||||
NotAfter: time.Now().AddDate(1, 0, 0),
|
||||
}
|
||||
err = tlsutils.GenerateServerCertToFile(caCertPEM, caKeyPEM, serverCertPath, serverKeyPath, serverOpts2)
|
||||
So(err, ShouldBeNil)
|
||||
|
||||
// Wait for certificate to be detected and reloaded
|
||||
time.Sleep(1 * time.Second)
|
||||
|
||||
// Verify connection still works with new certificate
|
||||
resp2, err := httpClient.R().Get(httpsURL + "/v2/")
|
||||
So(err, ShouldBeNil)
|
||||
So(resp2.StatusCode(), ShouldEqual, http.StatusOK)
|
||||
})
|
||||
}
|
||||
|
||||
func TestCertReloaderDirectly(t *testing.T) {
|
||||
Convey("Test CertReloader functionality", t, func() {
|
||||
tempDir := t.TempDir()
|
||||
|
||||
// Generate CA certificate
|
||||
caOpts := &tlsutils.CertificateOptions{
|
||||
CommonName: "Test CA",
|
||||
NotAfter: time.Now().AddDate(1, 0, 0),
|
||||
}
|
||||
caCertPEM, caKeyPEM, err := tlsutils.GenerateCACert(caOpts)
|
||||
So(err, ShouldBeNil)
|
||||
|
||||
// Generate initial server certificate
|
||||
certPath := filepath.Join(tempDir, "server.cert")
|
||||
keyPath := filepath.Join(tempDir, "server.key")
|
||||
serverOpts := &tlsutils.CertificateOptions{
|
||||
Hostname: "127.0.0.1",
|
||||
CommonName: "Initial Cert",
|
||||
NotAfter: time.Now().AddDate(1, 0, 0),
|
||||
}
|
||||
err = tlsutils.GenerateServerCertToFile(caCertPEM, caKeyPEM, certPath, keyPath, serverOpts)
|
||||
So(err, ShouldBeNil)
|
||||
|
||||
Convey("NewCertReloader should load initial certificate", func() {
|
||||
reloader, err := api.NewCertReloader(certPath, keyPath, log.NewTestLogger())
|
||||
So(err, ShouldBeNil)
|
||||
So(reloader, ShouldNotBeNil)
|
||||
defer reloader.Close()
|
||||
|
||||
// Get certificate via callback
|
||||
getCert := reloader.GetCertificateFunc()
|
||||
cert, err := getCert(nil)
|
||||
So(err, ShouldBeNil)
|
||||
So(cert, ShouldNotBeNil)
|
||||
})
|
||||
|
||||
Convey("GetCertificateFunc should reload when certificate changes", func() {
|
||||
reloader, err := api.NewCertReloader(certPath, keyPath, log.NewTestLogger())
|
||||
So(err, ShouldBeNil)
|
||||
defer reloader.Close()
|
||||
|
||||
getCert := reloader.GetCertificateFunc()
|
||||
initialCert, err := getCert(nil)
|
||||
So(err, ShouldBeNil)
|
||||
So(initialCert, ShouldNotBeNil)
|
||||
|
||||
// Wait to ensure modification time will be different
|
||||
time.Sleep(2 * time.Second)
|
||||
|
||||
// Generate new certificate
|
||||
newServerOpts := &tlsutils.CertificateOptions{
|
||||
Hostname: "127.0.0.1",
|
||||
CommonName: "Updated Cert",
|
||||
NotAfter: time.Now().AddDate(1, 0, 0),
|
||||
}
|
||||
err = tlsutils.GenerateServerCertToFile(caCertPEM, caKeyPEM, certPath, keyPath, newServerOpts)
|
||||
So(err, ShouldBeNil)
|
||||
|
||||
// Get certificate again - should reload automatically
|
||||
updatedCert, err := getCert(nil)
|
||||
So(err, ShouldBeNil)
|
||||
So(updatedCert, ShouldNotBeNil)
|
||||
|
||||
// Certificates should be different (different leaf certificates)
|
||||
initialLeaf, err := x509.ParseCertificate(initialCert.Certificate[0])
|
||||
So(err, ShouldBeNil)
|
||||
updatedLeaf, err := x509.ParseCertificate(updatedCert.Certificate[0])
|
||||
So(err, ShouldBeNil)
|
||||
|
||||
// Common names should be different
|
||||
So(initialLeaf.Subject.CommonName, ShouldEqual, "Initial Cert")
|
||||
So(updatedLeaf.Subject.CommonName, ShouldEqual, "Updated Cert")
|
||||
})
|
||||
|
||||
Convey("NewCertReloader should fail with invalid paths", func() {
|
||||
_, err := api.NewCertReloader("/nonexistent/cert.pem", "/nonexistent/key.pem", log.NewTestLogger())
|
||||
So(err, ShouldNotBeNil)
|
||||
})
|
||||
|
||||
Convey("GetCertificateFunc should handle missing files gracefully", func() {
|
||||
reloader, err := api.NewCertReloader(certPath, keyPath, log.NewTestLogger())
|
||||
So(err, ShouldBeNil)
|
||||
defer reloader.Close()
|
||||
|
||||
getCert := reloader.GetCertificateFunc()
|
||||
|
||||
// Delete the certificate file
|
||||
err = os.Remove(certPath)
|
||||
So(err, ShouldBeNil)
|
||||
|
||||
// Should still return the old certificate (not fail)
|
||||
cert, err := getCert(nil)
|
||||
So(err, ShouldBeNil)
|
||||
So(cert, ShouldNotBeNil)
|
||||
})
|
||||
|
||||
Convey("GetCertificateFunc should handle only cert file modification", func() {
|
||||
reloader, err := api.NewCertReloader(certPath, keyPath, log.NewTestLogger())
|
||||
So(err, ShouldBeNil)
|
||||
defer reloader.Close()
|
||||
|
||||
getCert := reloader.GetCertificateFunc()
|
||||
initialCert, err := getCert(nil)
|
||||
So(err, ShouldBeNil)
|
||||
So(initialCert, ShouldNotBeNil)
|
||||
|
||||
// Wait to ensure modification time will be different
|
||||
time.Sleep(2 * time.Second)
|
||||
|
||||
// Modify only the cert file (touch it to update mtime)
|
||||
// Generate new cert with same key
|
||||
newServerOpts := &tlsutils.CertificateOptions{
|
||||
Hostname: "127.0.0.1",
|
||||
CommonName: "Updated Cert Only",
|
||||
NotAfter: time.Now().AddDate(1, 0, 0),
|
||||
}
|
||||
// Read the existing key
|
||||
keyData, err := os.ReadFile(keyPath)
|
||||
So(err, ShouldBeNil)
|
||||
|
||||
// Generate new cert using the existing key
|
||||
err = tlsutils.GenerateServerCertToFile(caCertPEM, caKeyPEM, certPath, keyPath, newServerOpts)
|
||||
So(err, ShouldBeNil)
|
||||
|
||||
// Restore original key
|
||||
err = os.WriteFile(keyPath, keyData, 0o600)
|
||||
So(err, ShouldBeNil)
|
||||
|
||||
// Get certificate again - should reload
|
||||
updatedCert, err := getCert(nil)
|
||||
So(err, ShouldBeNil)
|
||||
So(updatedCert, ShouldNotBeNil)
|
||||
})
|
||||
|
||||
Convey("GetCertificateFunc should handle concurrent access", func() {
|
||||
reloader, err := api.NewCertReloader(certPath, keyPath, log.NewTestLogger())
|
||||
So(err, ShouldBeNil)
|
||||
defer reloader.Close()
|
||||
|
||||
getCert := reloader.GetCertificateFunc()
|
||||
|
||||
// Launch multiple goroutines to access certificate concurrently
|
||||
done := make(chan error, 10)
|
||||
|
||||
for range 10 {
|
||||
go func() {
|
||||
var lastErr error
|
||||
|
||||
for range 100 {
|
||||
cert, err := getCert(nil)
|
||||
if err != nil || cert == nil {
|
||||
lastErr = err
|
||||
|
||||
break
|
||||
}
|
||||
}
|
||||
done <- lastErr
|
||||
}()
|
||||
}
|
||||
|
||||
// Wait for all goroutines to complete and check for errors
|
||||
for range 10 {
|
||||
err := <-done
|
||||
So(err, ShouldBeNil)
|
||||
}
|
||||
})
|
||||
|
||||
Convey("GetCertificateFunc should not reload if files haven't changed", func() {
|
||||
reloader, err := api.NewCertReloader(certPath, keyPath, log.NewTestLogger())
|
||||
So(err, ShouldBeNil)
|
||||
defer reloader.Close()
|
||||
|
||||
getCert := reloader.GetCertificateFunc()
|
||||
|
||||
// Get certificate multiple times
|
||||
cert1, err := getCert(nil)
|
||||
So(err, ShouldBeNil)
|
||||
So(cert1, ShouldNotBeNil)
|
||||
|
||||
cert2, err := getCert(nil)
|
||||
So(err, ShouldBeNil)
|
||||
So(cert2, ShouldNotBeNil)
|
||||
|
||||
cert3, err := getCert(nil)
|
||||
So(err, ShouldBeNil)
|
||||
So(cert3, ShouldNotBeNil)
|
||||
|
||||
// All should return the same certificate instance (pointer equality)
|
||||
So(cert1, ShouldEqual, cert2)
|
||||
So(cert2, ShouldEqual, cert3)
|
||||
})
|
||||
|
||||
Convey("GetCertificateFunc should reload when key file changes", func() {
|
||||
reloader, err := api.NewCertReloader(certPath, keyPath, log.NewTestLogger())
|
||||
So(err, ShouldBeNil)
|
||||
defer reloader.Close()
|
||||
|
||||
getCert := reloader.GetCertificateFunc()
|
||||
initialCert, err := getCert(nil)
|
||||
So(err, ShouldBeNil)
|
||||
So(initialCert, ShouldNotBeNil)
|
||||
|
||||
// Wait to ensure modification time will be different
|
||||
time.Sleep(2 * time.Second)
|
||||
|
||||
// Generate completely new cert and key
|
||||
newServerOpts := &tlsutils.CertificateOptions{
|
||||
Hostname: "127.0.0.1",
|
||||
CommonName: "New Key Cert",
|
||||
NotAfter: time.Now().AddDate(1, 0, 0),
|
||||
}
|
||||
err = tlsutils.GenerateServerCertToFile(caCertPEM, caKeyPEM, certPath, keyPath, newServerOpts)
|
||||
So(err, ShouldBeNil)
|
||||
|
||||
// Get certificate again - should reload due to key change
|
||||
updatedCert, err := getCert(nil)
|
||||
So(err, ShouldBeNil)
|
||||
So(updatedCert, ShouldNotBeNil)
|
||||
|
||||
// Verify certificates are different
|
||||
initialLeaf, err := x509.ParseCertificate(initialCert.Certificate[0])
|
||||
So(err, ShouldBeNil)
|
||||
updatedLeaf, err := x509.ParseCertificate(updatedCert.Certificate[0])
|
||||
So(err, ShouldBeNil)
|
||||
|
||||
So(initialLeaf.Subject.CommonName, ShouldEqual, "Initial Cert")
|
||||
So(updatedLeaf.Subject.CommonName, ShouldEqual, "New Key Cert")
|
||||
})
|
||||
})
|
||||
}
|
||||
Reference in New Issue
Block a user