mirror of
https://github.com/project-zot/zot.git
synced 2026-06-18 13:37:57 +08:00
Implement automatic TLS certificate reload feature
Co-authored-by: rchincha <45800463+rchincha@users.noreply.github.com>
This commit is contained in:
+13
-2
@@ -204,6 +204,15 @@ func (c *Controller) Run() error {
|
||||
|
||||
tlsConfig := c.Config.CopyTLSConfig()
|
||||
if tlsConfig != nil && tlsConfig.Key != "" && tlsConfig.Cert != "" {
|
||||
// Create certificate reloader for automatic TLS certificate updates
|
||||
certReloader, err := NewCertReloader(tlsConfig.Cert, tlsConfig.Key)
|
||||
if err != nil {
|
||||
c.Log.Error().Err(err).Str("cert", tlsConfig.Cert).Str("key", tlsConfig.Key).
|
||||
Msg("failed to load TLS certificates")
|
||||
|
||||
return err
|
||||
}
|
||||
|
||||
// These are the same as the cipher suites in defaultCipherSuitesFIPS for TLS 1.2
|
||||
// see https://cs.opensource.google/go/go/+/refs/tags/go1.24.9:src/crypto/tls/defaults.go;l=123
|
||||
// Note: Order doesn't matter - Go 1.17+ automatically orders cipher suites based on
|
||||
@@ -239,7 +248,8 @@ func (c *Controller) Run() error {
|
||||
CipherSuites: cipherSuites,
|
||||
CurvePreferences: curvePreferences,
|
||||
// PreferServerCipherSuites is ignored in Go 1.17+ - Go automatically orders cipher suites
|
||||
MinVersion: tls.VersionTLS12,
|
||||
MinVersion: tls.VersionTLS12,
|
||||
GetCertificate: certReloader.GetCertificateFunc(),
|
||||
}
|
||||
|
||||
if tlsConfig.CACert != "" {
|
||||
@@ -266,7 +276,8 @@ func (c *Controller) Run() error {
|
||||
|
||||
c.Healthz.Ready()
|
||||
|
||||
return server.ServeTLS(listener, tlsConfig.Cert, tlsConfig.Key)
|
||||
// Pass empty strings to ServeTLS - certificates will be loaded via GetCertificate callback
|
||||
return server.ServeTLS(listener, "", "")
|
||||
}
|
||||
|
||||
c.Healthz.Ready()
|
||||
|
||||
@@ -0,0 +1,110 @@
|
||||
package api
|
||||
|
||||
import (
|
||||
"crypto/tls"
|
||||
"os"
|
||||
"sync"
|
||||
"time"
|
||||
)
|
||||
|
||||
// CertReloader handles automatic reloading of TLS certificates without downtime.
|
||||
// It monitors certificate and key files for changes and reloads them dynamically
|
||||
// using a GetCertificate callback in tls.Config.
|
||||
type CertReloader struct {
|
||||
certMu sync.RWMutex
|
||||
cert *tls.Certificate
|
||||
certPath string
|
||||
keyPath string
|
||||
certMod time.Time
|
||||
keyMod time.Time
|
||||
}
|
||||
|
||||
// NewCertReloader creates a new certificate reloader and loads the initial certificate.
|
||||
func NewCertReloader(certPath, keyPath string) (*CertReloader, error) {
|
||||
reloader := &CertReloader{
|
||||
certPath: certPath,
|
||||
keyPath: keyPath,
|
||||
}
|
||||
|
||||
if err := reloader.reload(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return reloader, nil
|
||||
}
|
||||
|
||||
// reload loads the certificate and key from disk and updates the internal certificate.
|
||||
func (cr *CertReloader) reload() error {
|
||||
// Get file modification times
|
||||
certInfo, err := os.Stat(cr.certPath)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
keyInfo, err := os.Stat(cr.keyPath)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
certMod := certInfo.ModTime()
|
||||
keyMod := keyInfo.ModTime()
|
||||
|
||||
// Load the certificate
|
||||
newCert, err := tls.LoadX509KeyPair(cr.certPath, cr.keyPath)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Update the certificate and modification times
|
||||
cr.certMu.Lock()
|
||||
defer cr.certMu.Unlock()
|
||||
|
||||
cr.cert = &newCert
|
||||
cr.certMod = certMod
|
||||
cr.keyMod = keyMod
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// maybeReload checks if the certificate files have been modified and reloads them if necessary.
|
||||
func (cr *CertReloader) maybeReload() error {
|
||||
// Check cert file modification time
|
||||
certInfo, err := os.Stat(cr.certPath)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
keyInfo, err := os.Stat(cr.keyPath)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
certMod := certInfo.ModTime()
|
||||
keyMod := keyInfo.ModTime()
|
||||
|
||||
// Check if files have been modified
|
||||
cr.certMu.RLock()
|
||||
needsReload := certMod.After(cr.certMod) || keyMod.After(cr.keyMod)
|
||||
cr.certMu.RUnlock()
|
||||
|
||||
if needsReload {
|
||||
return cr.reload()
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// GetCertificateFunc returns a function that can be used as tls.Config.GetCertificate.
|
||||
// This function checks for certificate updates on each TLS handshake and reloads if necessary.
|
||||
func (cr *CertReloader) GetCertificateFunc() func(*tls.ClientHelloInfo) (*tls.Certificate, error) {
|
||||
return func(_ *tls.ClientHelloInfo) (*tls.Certificate, error) {
|
||||
// Try to reload the certificate if it has changed
|
||||
// Ignore errors during reload attempts - keep using the existing certificate
|
||||
_ = cr.maybeReload()
|
||||
|
||||
cr.certMu.RLock()
|
||||
defer cr.certMu.RUnlock()
|
||||
|
||||
return cr.cert, nil
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,197 @@
|
||||
package api_test
|
||||
|
||||
import (
|
||||
"crypto/tls"
|
||||
"crypto/x509"
|
||||
"net/http"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
. "github.com/smartystreets/goconvey/convey"
|
||||
"gopkg.in/resty.v1"
|
||||
|
||||
"zotregistry.dev/zot/v2/pkg/api"
|
||||
"zotregistry.dev/zot/v2/pkg/api/config"
|
||||
test "zotregistry.dev/zot/v2/pkg/test/common"
|
||||
tlsutils "zotregistry.dev/zot/v2/pkg/test/tls"
|
||||
)
|
||||
|
||||
func TestTLSCertReload(t *testing.T) {
|
||||
Convey("Test automatic TLS certificate reload", t, func() {
|
||||
// Create temporary directory for certificates
|
||||
tempDir := t.TempDir()
|
||||
|
||||
// Generate initial CA certificate
|
||||
caOpts := &tlsutils.CertificateOptions{
|
||||
CommonName: "Test CA",
|
||||
NotAfter: time.Now().AddDate(1, 0, 0),
|
||||
}
|
||||
caCertPEM, caKeyPEM, err := tlsutils.GenerateCACert(caOpts)
|
||||
So(err, ShouldBeNil)
|
||||
|
||||
// Generate initial server certificate
|
||||
serverCertPath := filepath.Join(tempDir, "server.cert")
|
||||
serverKeyPath := filepath.Join(tempDir, "server.key")
|
||||
serverOpts := &tlsutils.CertificateOptions{
|
||||
Hostname: "127.0.0.1",
|
||||
CommonName: "Server v1",
|
||||
NotAfter: time.Now().AddDate(1, 0, 0),
|
||||
}
|
||||
err = tlsutils.GenerateServerCertToFile(caCertPEM, caKeyPEM, serverCertPath, serverKeyPath, serverOpts)
|
||||
So(err, ShouldBeNil)
|
||||
|
||||
// Create config with TLS
|
||||
port := test.GetFreePort()
|
||||
httpsURL := test.GetSecureBaseURL(port)
|
||||
|
||||
conf := config.New()
|
||||
conf.HTTP.Address = "127.0.0.1"
|
||||
conf.HTTP.Port = port
|
||||
conf.HTTP.TLS = &config.TLSConfig{
|
||||
Cert: serverCertPath,
|
||||
Key: serverKeyPath,
|
||||
}
|
||||
conf.Storage.RootDirectory = t.TempDir()
|
||||
|
||||
ctlr := api.NewController(conf)
|
||||
ctlr.Config.Storage.RootDirectory = t.TempDir()
|
||||
|
||||
cm := test.NewControllerManager(ctlr)
|
||||
cm.StartAndWait(port)
|
||||
defer cm.StopServer()
|
||||
|
||||
// Create client with CA certificate
|
||||
caCertPool := x509.NewCertPool()
|
||||
caCertPool.AppendCertsFromPEM(caCertPEM)
|
||||
|
||||
httpClient := resty.New().
|
||||
SetTLSClientConfig(&tls.Config{
|
||||
RootCAs: caCertPool,
|
||||
MinVersion: tls.VersionTLS12,
|
||||
}).
|
||||
SetRedirectPolicy(resty.FlexibleRedirectPolicy(10))
|
||||
|
||||
// Verify initial connection works with HTTPS
|
||||
resp, err := httpClient.R().Get(httpsURL + "/v2/")
|
||||
So(err, ShouldBeNil)
|
||||
So(resp, ShouldNotBeNil)
|
||||
So(resp.StatusCode(), ShouldEqual, http.StatusOK)
|
||||
|
||||
// Wait a moment to ensure file modification time will be different
|
||||
time.Sleep(2 * time.Second)
|
||||
|
||||
// Generate new server certificate with different CommonName
|
||||
serverOpts2 := &tlsutils.CertificateOptions{
|
||||
Hostname: "127.0.0.1",
|
||||
CommonName: "Server v2",
|
||||
NotAfter: time.Now().AddDate(1, 0, 0),
|
||||
}
|
||||
err = tlsutils.GenerateServerCertToFile(caCertPEM, caKeyPEM, serverCertPath, serverKeyPath, serverOpts2)
|
||||
So(err, ShouldBeNil)
|
||||
|
||||
// Wait for certificate to be detected and reloaded
|
||||
time.Sleep(1 * time.Second)
|
||||
|
||||
// Verify connection still works with new certificate
|
||||
resp2, err := httpClient.R().Get(httpsURL + "/v2/")
|
||||
So(err, ShouldBeNil)
|
||||
So(resp2.StatusCode(), ShouldEqual, http.StatusOK)
|
||||
})
|
||||
}
|
||||
|
||||
func TestCertReloaderDirectly(t *testing.T) {
|
||||
Convey("Test CertReloader functionality", t, func() {
|
||||
tempDir := t.TempDir()
|
||||
|
||||
// Generate CA certificate
|
||||
caOpts := &tlsutils.CertificateOptions{
|
||||
CommonName: "Test CA",
|
||||
NotAfter: time.Now().AddDate(1, 0, 0),
|
||||
}
|
||||
caCertPEM, caKeyPEM, err := tlsutils.GenerateCACert(caOpts)
|
||||
So(err, ShouldBeNil)
|
||||
|
||||
// Generate initial server certificate
|
||||
certPath := filepath.Join(tempDir, "server.cert")
|
||||
keyPath := filepath.Join(tempDir, "server.key")
|
||||
serverOpts := &tlsutils.CertificateOptions{
|
||||
Hostname: "127.0.0.1",
|
||||
CommonName: "Initial Cert",
|
||||
NotAfter: time.Now().AddDate(1, 0, 0),
|
||||
}
|
||||
err = tlsutils.GenerateServerCertToFile(caCertPEM, caKeyPEM, certPath, keyPath, serverOpts)
|
||||
So(err, ShouldBeNil)
|
||||
|
||||
Convey("NewCertReloader should load initial certificate", func() {
|
||||
reloader, err := api.NewCertReloader(certPath, keyPath)
|
||||
So(err, ShouldBeNil)
|
||||
So(reloader, ShouldNotBeNil)
|
||||
|
||||
// Get certificate via callback
|
||||
getCert := reloader.GetCertificateFunc()
|
||||
cert, err := getCert(nil)
|
||||
So(err, ShouldBeNil)
|
||||
So(cert, ShouldNotBeNil)
|
||||
})
|
||||
|
||||
Convey("GetCertificateFunc should reload when certificate changes", func() {
|
||||
reloader, err := api.NewCertReloader(certPath, keyPath)
|
||||
So(err, ShouldBeNil)
|
||||
|
||||
getCert := reloader.GetCertificateFunc()
|
||||
initialCert, err := getCert(nil)
|
||||
So(err, ShouldBeNil)
|
||||
So(initialCert, ShouldNotBeNil)
|
||||
|
||||
// Wait to ensure modification time will be different
|
||||
time.Sleep(2 * time.Second)
|
||||
|
||||
// Generate new certificate
|
||||
newServerOpts := &tlsutils.CertificateOptions{
|
||||
Hostname: "127.0.0.1",
|
||||
CommonName: "Updated Cert",
|
||||
NotAfter: time.Now().AddDate(1, 0, 0),
|
||||
}
|
||||
err = tlsutils.GenerateServerCertToFile(caCertPEM, caKeyPEM, certPath, keyPath, newServerOpts)
|
||||
So(err, ShouldBeNil)
|
||||
|
||||
// Get certificate again - should reload automatically
|
||||
updatedCert, err := getCert(nil)
|
||||
So(err, ShouldBeNil)
|
||||
So(updatedCert, ShouldNotBeNil)
|
||||
|
||||
// Certificates should be different (different leaf certificates)
|
||||
initialLeaf, err := x509.ParseCertificate(initialCert.Certificate[0])
|
||||
So(err, ShouldBeNil)
|
||||
updatedLeaf, err := x509.ParseCertificate(updatedCert.Certificate[0])
|
||||
So(err, ShouldBeNil)
|
||||
|
||||
// Common names should be different
|
||||
So(initialLeaf.Subject.CommonName, ShouldEqual, "Initial Cert")
|
||||
So(updatedLeaf.Subject.CommonName, ShouldEqual, "Updated Cert")
|
||||
})
|
||||
|
||||
Convey("NewCertReloader should fail with invalid paths", func() {
|
||||
_, err := api.NewCertReloader("/nonexistent/cert.pem", "/nonexistent/key.pem")
|
||||
So(err, ShouldNotBeNil)
|
||||
})
|
||||
|
||||
Convey("GetCertificateFunc should handle missing files gracefully", func() {
|
||||
reloader, err := api.NewCertReloader(certPath, keyPath)
|
||||
So(err, ShouldBeNil)
|
||||
|
||||
getCert := reloader.GetCertificateFunc()
|
||||
|
||||
// Delete the certificate file
|
||||
err = os.Remove(certPath)
|
||||
So(err, ShouldBeNil)
|
||||
|
||||
// Should still return the old certificate (not fail)
|
||||
cert, err := getCert(nil)
|
||||
So(err, ShouldBeNil)
|
||||
So(cert, ShouldNotBeNil)
|
||||
})
|
||||
})
|
||||
}
|
||||
Reference in New Issue
Block a user