From 3d17ba6dbd7edcfdc573f73d7125285f77259b62 Mon Sep 17 00:00:00 2001 From: "copilot-swe-agent[bot]" <198982749+Copilot@users.noreply.github.com> Date: Fri, 30 Jan 2026 05:45:36 +0000 Subject: [PATCH] Implement automatic TLS certificate reload feature Co-authored-by: rchincha <45800463+rchincha@users.noreply.github.com> --- pkg/api/controller.go | 15 ++- pkg/api/tlscert.go | 110 ++++++++++++++++++++++ pkg/api/tlscert_test.go | 197 ++++++++++++++++++++++++++++++++++++++++ 3 files changed, 320 insertions(+), 2 deletions(-) create mode 100644 pkg/api/tlscert.go create mode 100644 pkg/api/tlscert_test.go diff --git a/pkg/api/controller.go b/pkg/api/controller.go index c20a206e..85ba59bf 100644 --- a/pkg/api/controller.go +++ b/pkg/api/controller.go @@ -204,6 +204,15 @@ func (c *Controller) Run() error { tlsConfig := c.Config.CopyTLSConfig() if tlsConfig != nil && tlsConfig.Key != "" && tlsConfig.Cert != "" { + // Create certificate reloader for automatic TLS certificate updates + certReloader, err := NewCertReloader(tlsConfig.Cert, tlsConfig.Key) + if err != nil { + c.Log.Error().Err(err).Str("cert", tlsConfig.Cert).Str("key", tlsConfig.Key). + Msg("failed to load TLS certificates") + + return err + } + // These are the same as the cipher suites in defaultCipherSuitesFIPS for TLS 1.2 // see https://cs.opensource.google/go/go/+/refs/tags/go1.24.9:src/crypto/tls/defaults.go;l=123 // Note: Order doesn't matter - Go 1.17+ automatically orders cipher suites based on @@ -239,7 +248,8 @@ func (c *Controller) Run() error { CipherSuites: cipherSuites, CurvePreferences: curvePreferences, // PreferServerCipherSuites is ignored in Go 1.17+ - Go automatically orders cipher suites - MinVersion: tls.VersionTLS12, + MinVersion: tls.VersionTLS12, + GetCertificate: certReloader.GetCertificateFunc(), } if tlsConfig.CACert != "" { @@ -266,7 +276,8 @@ func (c *Controller) Run() error { c.Healthz.Ready() - return server.ServeTLS(listener, tlsConfig.Cert, tlsConfig.Key) + // Pass empty strings to ServeTLS - certificates will be loaded via GetCertificate callback + return server.ServeTLS(listener, "", "") } c.Healthz.Ready() diff --git a/pkg/api/tlscert.go b/pkg/api/tlscert.go new file mode 100644 index 00000000..ebdac3e9 --- /dev/null +++ b/pkg/api/tlscert.go @@ -0,0 +1,110 @@ +package api + +import ( + "crypto/tls" + "os" + "sync" + "time" +) + +// CertReloader handles automatic reloading of TLS certificates without downtime. +// It monitors certificate and key files for changes and reloads them dynamically +// using a GetCertificate callback in tls.Config. +type CertReloader struct { + certMu sync.RWMutex + cert *tls.Certificate + certPath string + keyPath string + certMod time.Time + keyMod time.Time +} + +// NewCertReloader creates a new certificate reloader and loads the initial certificate. +func NewCertReloader(certPath, keyPath string) (*CertReloader, error) { + reloader := &CertReloader{ + certPath: certPath, + keyPath: keyPath, + } + + if err := reloader.reload(); err != nil { + return nil, err + } + + return reloader, nil +} + +// reload loads the certificate and key from disk and updates the internal certificate. +func (cr *CertReloader) reload() error { + // Get file modification times + certInfo, err := os.Stat(cr.certPath) + if err != nil { + return err + } + + keyInfo, err := os.Stat(cr.keyPath) + if err != nil { + return err + } + + certMod := certInfo.ModTime() + keyMod := keyInfo.ModTime() + + // Load the certificate + newCert, err := tls.LoadX509KeyPair(cr.certPath, cr.keyPath) + if err != nil { + return err + } + + // Update the certificate and modification times + cr.certMu.Lock() + defer cr.certMu.Unlock() + + cr.cert = &newCert + cr.certMod = certMod + cr.keyMod = keyMod + + return nil +} + +// maybeReload checks if the certificate files have been modified and reloads them if necessary. +func (cr *CertReloader) maybeReload() error { + // Check cert file modification time + certInfo, err := os.Stat(cr.certPath) + if err != nil { + return err + } + + keyInfo, err := os.Stat(cr.keyPath) + if err != nil { + return err + } + + certMod := certInfo.ModTime() + keyMod := keyInfo.ModTime() + + // Check if files have been modified + cr.certMu.RLock() + needsReload := certMod.After(cr.certMod) || keyMod.After(cr.keyMod) + cr.certMu.RUnlock() + + if needsReload { + return cr.reload() + } + + return nil +} + +// GetCertificateFunc returns a function that can be used as tls.Config.GetCertificate. +// This function checks for certificate updates on each TLS handshake and reloads if necessary. +func (cr *CertReloader) GetCertificateFunc() func(*tls.ClientHelloInfo) (*tls.Certificate, error) { + return func(_ *tls.ClientHelloInfo) (*tls.Certificate, error) { + // Try to reload the certificate if it has changed + // Ignore errors during reload attempts - keep using the existing certificate + _ = cr.maybeReload() + + cr.certMu.RLock() + defer cr.certMu.RUnlock() + + return cr.cert, nil + } +} diff --git a/pkg/api/tlscert_test.go b/pkg/api/tlscert_test.go new file mode 100644 index 00000000..0a739e9d --- /dev/null +++ b/pkg/api/tlscert_test.go @@ -0,0 +1,197 @@ +package api_test + +import ( + "crypto/tls" + "crypto/x509" + "net/http" + "os" + "path/filepath" + "testing" + "time" + + . "github.com/smartystreets/goconvey/convey" + "gopkg.in/resty.v1" + + "zotregistry.dev/zot/v2/pkg/api" + "zotregistry.dev/zot/v2/pkg/api/config" + test "zotregistry.dev/zot/v2/pkg/test/common" + tlsutils "zotregistry.dev/zot/v2/pkg/test/tls" +) + +func TestTLSCertReload(t *testing.T) { + Convey("Test automatic TLS certificate reload", t, func() { + // Create temporary directory for certificates + tempDir := t.TempDir() + + // Generate initial CA certificate + caOpts := &tlsutils.CertificateOptions{ + CommonName: "Test CA", + NotAfter: time.Now().AddDate(1, 0, 0), + } + caCertPEM, caKeyPEM, err := tlsutils.GenerateCACert(caOpts) + So(err, ShouldBeNil) + + // Generate initial server certificate + serverCertPath := filepath.Join(tempDir, "server.cert") + serverKeyPath := filepath.Join(tempDir, "server.key") + serverOpts := &tlsutils.CertificateOptions{ + Hostname: "127.0.0.1", + CommonName: "Server v1", + NotAfter: time.Now().AddDate(1, 0, 0), + } + err = tlsutils.GenerateServerCertToFile(caCertPEM, caKeyPEM, serverCertPath, serverKeyPath, serverOpts) + So(err, ShouldBeNil) + + // Create config with TLS + port := test.GetFreePort() + httpsURL := test.GetSecureBaseURL(port) + + conf := config.New() + conf.HTTP.Address = "127.0.0.1" + conf.HTTP.Port = port + conf.HTTP.TLS = &config.TLSConfig{ + Cert: serverCertPath, + Key: serverKeyPath, + } + conf.Storage.RootDirectory = t.TempDir() + + ctlr := api.NewController(conf) + ctlr.Config.Storage.RootDirectory = t.TempDir() + + cm := test.NewControllerManager(ctlr) + cm.StartAndWait(port) + defer cm.StopServer() + + // Create client with CA certificate + caCertPool := x509.NewCertPool() + caCertPool.AppendCertsFromPEM(caCertPEM) + + httpClient := resty.New(). + SetTLSClientConfig(&tls.Config{ + RootCAs: caCertPool, + MinVersion: tls.VersionTLS12, + }). + SetRedirectPolicy(resty.FlexibleRedirectPolicy(10)) + + // Verify initial connection works with HTTPS + resp, err := httpClient.R().Get(httpsURL + "/v2/") + So(err, ShouldBeNil) + So(resp, ShouldNotBeNil) + So(resp.StatusCode(), ShouldEqual, http.StatusOK) + + // Wait a moment to ensure file modification time will be different + time.Sleep(2 * time.Second) + + // Generate new server certificate with different CommonName + serverOpts2 := &tlsutils.CertificateOptions{ + Hostname: "127.0.0.1", + CommonName: "Server v2", + NotAfter: time.Now().AddDate(1, 0, 0), + } + err = tlsutils.GenerateServerCertToFile(caCertPEM, caKeyPEM, serverCertPath, serverKeyPath, serverOpts2) + So(err, ShouldBeNil) + + // Wait for certificate to be detected and reloaded + time.Sleep(1 * time.Second) + + // Verify connection still works with new certificate + resp2, err := httpClient.R().Get(httpsURL + "/v2/") + So(err, ShouldBeNil) + So(resp2.StatusCode(), ShouldEqual, http.StatusOK) + }) +} + +func TestCertReloaderDirectly(t *testing.T) { + Convey("Test CertReloader functionality", t, func() { + tempDir := t.TempDir() + + // Generate CA certificate + caOpts := &tlsutils.CertificateOptions{ + CommonName: "Test CA", + NotAfter: time.Now().AddDate(1, 0, 0), + } + caCertPEM, caKeyPEM, err := tlsutils.GenerateCACert(caOpts) + So(err, ShouldBeNil) + + // Generate initial server certificate + certPath := filepath.Join(tempDir, "server.cert") + keyPath := filepath.Join(tempDir, "server.key") + serverOpts := &tlsutils.CertificateOptions{ + Hostname: "127.0.0.1", + CommonName: "Initial Cert", + NotAfter: time.Now().AddDate(1, 0, 0), + } + err = tlsutils.GenerateServerCertToFile(caCertPEM, caKeyPEM, certPath, keyPath, serverOpts) + So(err, ShouldBeNil) + + Convey("NewCertReloader should load initial certificate", func() { + reloader, err := api.NewCertReloader(certPath, keyPath) + So(err, ShouldBeNil) + So(reloader, ShouldNotBeNil) + + // Get certificate via callback + getCert := reloader.GetCertificateFunc() + cert, err := getCert(nil) + So(err, ShouldBeNil) + So(cert, ShouldNotBeNil) + }) + + Convey("GetCertificateFunc should reload when certificate changes", func() { + reloader, err := api.NewCertReloader(certPath, keyPath) + So(err, ShouldBeNil) + + getCert := reloader.GetCertificateFunc() + initialCert, err := getCert(nil) + So(err, ShouldBeNil) + So(initialCert, ShouldNotBeNil) + + // Wait to ensure modification time will be different + time.Sleep(2 * time.Second) + + // Generate new certificate + newServerOpts := &tlsutils.CertificateOptions{ + Hostname: "127.0.0.1", + CommonName: "Updated Cert", + NotAfter: time.Now().AddDate(1, 0, 0), + } + err = tlsutils.GenerateServerCertToFile(caCertPEM, caKeyPEM, certPath, keyPath, newServerOpts) + So(err, ShouldBeNil) + + // Get certificate again - should reload automatically + updatedCert, err := getCert(nil) + So(err, ShouldBeNil) + So(updatedCert, ShouldNotBeNil) + + // Certificates should be different (different leaf certificates) + initialLeaf, err := x509.ParseCertificate(initialCert.Certificate[0]) + So(err, ShouldBeNil) + updatedLeaf, err := x509.ParseCertificate(updatedCert.Certificate[0]) + So(err, ShouldBeNil) + + // Common names should be different + So(initialLeaf.Subject.CommonName, ShouldEqual, "Initial Cert") + So(updatedLeaf.Subject.CommonName, ShouldEqual, "Updated Cert") + }) + + Convey("NewCertReloader should fail with invalid paths", func() { + _, err := api.NewCertReloader("/nonexistent/cert.pem", "/nonexistent/key.pem") + So(err, ShouldNotBeNil) + }) + + Convey("GetCertificateFunc should handle missing files gracefully", func() { + reloader, err := api.NewCertReloader(certPath, keyPath) + So(err, ShouldBeNil) + + getCert := reloader.GetCertificateFunc() + + // Delete the certificate file + err = os.Remove(certPath) + So(err, ShouldBeNil) + + // Should still return the old certificate (not fail) + cert, err := getCert(nil) + So(err, ShouldBeNil) + So(cert, ShouldNotBeNil) + }) + }) +}