Files
zot/pkg/api/tlscert_test.go
T
Ramkumar Chinchani 7ead92b82f feat(tls): automatically reload tls certs
Fixes issue #3747

Currently, zot requires a restart whenever its TLS certificates change, which happens
whenever the certificates are rotated.

This PR checks whether the TLS certificates have been modified and, if so, reloads them
without restarting zot.

Signed-off-by: Ramkumar Chinchani <rchincha.dev@gmail.com>
2026-02-01 14:48:44 -08:00

335 lines
9.9 KiB
Go

package api_test
import (
"crypto/tls"
"crypto/x509"
"net/http"
"os"
"path/filepath"
"testing"
"time"
. "github.com/smartystreets/goconvey/convey"
"gopkg.in/resty.v1"
"zotregistry.dev/zot/v2/pkg/api"
"zotregistry.dev/zot/v2/pkg/api/config"
"zotregistry.dev/zot/v2/pkg/log"
test "zotregistry.dev/zot/v2/pkg/test/common"
tlsutils "zotregistry.dev/zot/v2/pkg/test/tls"
)
// TestTLSCertReload verifies end to end that a running controller serves a
// rotated server certificate without being restarted: after the cert and key
// files are rewritten on disk, new TLS handshakes must present the new
// certificate while HTTPS requests keep succeeding.
func TestTLSCertReload(t *testing.T) {
	Convey("Test automatic TLS certificate reload", t, func() {
		// Create temporary directory for certificates
		tempDir := t.TempDir()

		// Generate initial CA certificate
		caOpts := &tlsutils.CertificateOptions{
			CommonName: "Test CA",
			NotAfter:   time.Now().AddDate(1, 0, 0),
		}
		caCertPEM, caKeyPEM, err := tlsutils.GenerateCACert(caOpts)
		So(err, ShouldBeNil)

		// Generate initial server certificate
		serverCertPath := filepath.Join(tempDir, "server.cert")
		serverKeyPath := filepath.Join(tempDir, "server.key")
		serverOpts := &tlsutils.CertificateOptions{
			Hostname:   "127.0.0.1",
			CommonName: "Server v1",
			NotAfter:   time.Now().AddDate(1, 0, 0),
		}
		err = tlsutils.GenerateServerCertToFile(caCertPEM, caKeyPEM, serverCertPath, serverKeyPath, serverOpts)
		So(err, ShouldBeNil)

		// Create config with TLS
		port := test.GetFreePort()
		httpsURL := test.GetSecureBaseURL(port)
		conf := config.New()
		conf.HTTP.Address = "127.0.0.1"
		conf.HTTP.Port = port
		conf.HTTP.TLS = &config.TLSConfig{
			Cert: serverCertPath,
			Key:  serverKeyPath,
		}
		conf.Storage.RootDirectory = t.TempDir()

		ctlr := api.NewController(conf)
		cm := test.NewControllerManager(ctlr)
		cm.StartAndWait(port)
		defer cm.StopServer()

		// Trust only the test CA; fail loudly if its PEM could not be parsed.
		caCertPool := x509.NewCertPool()
		So(caCertPool.AppendCertsFromPEM(caCertPEM), ShouldBeTrue)

		tlsClientConfig := &tls.Config{
			RootCAs:    caCertPool,
			MinVersion: tls.VersionTLS12,
		}

		httpClient := resty.New().
			SetTLSClientConfig(tlsClientConfig).
			SetRedirectPolicy(resty.FlexibleRedirectPolicy(10))

		// servedCommonName performs a brand-new TLS handshake and returns the
		// CommonName of the leaf certificate the server presented. A fresh
		// connection is needed because the HTTP client pools keep-alive
		// connections, and a reused connection keeps the certificate from its
		// original handshake, so it cannot observe a rotation.
		serverAddr := conf.HTTP.Address + ":" + port
		servedCommonName := func() string {
			conn, dialErr := tls.Dial("tcp", serverAddr, tlsClientConfig.Clone())
			So(dialErr, ShouldBeNil)
			if dialErr != nil {
				return ""
			}
			defer conn.Close()

			peerCerts := conn.ConnectionState().PeerCertificates
			So(peerCerts, ShouldNotBeEmpty)
			if len(peerCerts) == 0 {
				return ""
			}

			return peerCerts[0].Subject.CommonName
		}

		// Verify the initial HTTPS connection works and serves the v1 certificate
		resp, err := httpClient.R().Get(httpsURL + "/v2/")
		So(err, ShouldBeNil)
		So(resp, ShouldNotBeNil)
		So(resp.StatusCode(), ShouldEqual, http.StatusOK)
		So(servedCommonName(), ShouldEqual, "Server v1")

		// Wait so the rewritten files get a distinguishable modification time,
		// even on filesystems with coarse mtime resolution.
		time.Sleep(2 * time.Second)

		// Rotate: generate a new server certificate with a different CommonName
		serverOpts2 := &tlsutils.CertificateOptions{
			Hostname:   "127.0.0.1",
			CommonName: "Server v2",
			NotAfter:   time.Now().AddDate(1, 0, 0),
		}
		err = tlsutils.GenerateServerCertToFile(caCertPEM, caKeyPEM, serverCertPath, serverKeyPath, serverOpts2)
		So(err, ShouldBeNil)

		// Give the server time to detect and reload the certificate
		time.Sleep(1 * time.Second)

		// New handshakes must present the rotated certificate...
		So(servedCommonName(), ShouldEqual, "Server v2")

		// ...and HTTPS requests must keep working.
		resp2, err := httpClient.R().Get(httpsURL + "/v2/")
		So(err, ShouldBeNil)
		So(resp2, ShouldNotBeNil)
		So(resp2.StatusCode(), ShouldEqual, http.StatusOK)
	})
}
// TestCertReloaderDirectly exercises api.CertReloader without a running server:
// initial load, reload after the files on disk change, error handling for bad
// paths and missing files, and concurrent use of the GetCertificate callback.
func TestCertReloaderDirectly(t *testing.T) {
	Convey("Test CertReloader functionality", t, func() {
		// goconvey re-runs this setup for every leaf Convey, so each subtest
		// starts with a fresh directory holding the "Initial Cert" pair.
		tempDir := t.TempDir()

		// Generate CA certificate
		caOpts := &tlsutils.CertificateOptions{
			CommonName: "Test CA",
			NotAfter:   time.Now().AddDate(1, 0, 0),
		}
		caCertPEM, caKeyPEM, err := tlsutils.GenerateCACert(caOpts)
		So(err, ShouldBeNil)

		// Generate initial server certificate
		certPath := filepath.Join(tempDir, "server.cert")
		keyPath := filepath.Join(tempDir, "server.key")
		serverOpts := &tlsutils.CertificateOptions{
			Hostname:   "127.0.0.1",
			CommonName: "Initial Cert",
			NotAfter:   time.Now().AddDate(1, 0, 0),
		}
		err = tlsutils.GenerateServerCertToFile(caCertPEM, caKeyPEM, certPath, keyPath, serverOpts)
		So(err, ShouldBeNil)

		// leafCommonName returns the Subject CommonName of cert's leaf (first)
		// certificate, so subtests can assert which certificate is being served.
		leafCommonName := func(cert *tls.Certificate) string {
			So(cert, ShouldNotBeNil)
			if cert == nil || len(cert.Certificate) == 0 {
				return ""
			}

			leaf, parseErr := x509.ParseCertificate(cert.Certificate[0])
			So(parseErr, ShouldBeNil)
			if parseErr != nil {
				return ""
			}

			return leaf.Subject.CommonName
		}

		Convey("NewCertReloader should load initial certificate", func() {
			reloader, err := api.NewCertReloader(certPath, keyPath, log.NewTestLogger())
			So(err, ShouldBeNil)
			So(reloader, ShouldNotBeNil)
			defer reloader.Close()

			// Get certificate via callback
			getCert := reloader.GetCertificateFunc()
			cert, err := getCert(nil)
			So(err, ShouldBeNil)
			So(leafCommonName(cert), ShouldEqual, "Initial Cert")
		})

		Convey("GetCertificateFunc should reload when certificate changes", func() {
			reloader, err := api.NewCertReloader(certPath, keyPath, log.NewTestLogger())
			So(err, ShouldBeNil)
			defer reloader.Close()

			getCert := reloader.GetCertificateFunc()
			initialCert, err := getCert(nil)
			So(err, ShouldBeNil)
			So(leafCommonName(initialCert), ShouldEqual, "Initial Cert")

			// Wait to ensure modification time will be different
			time.Sleep(2 * time.Second)

			// Generate new certificate
			newServerOpts := &tlsutils.CertificateOptions{
				Hostname:   "127.0.0.1",
				CommonName: "Updated Cert",
				NotAfter:   time.Now().AddDate(1, 0, 0),
			}
			err = tlsutils.GenerateServerCertToFile(caCertPEM, caKeyPEM, certPath, keyPath, newServerOpts)
			So(err, ShouldBeNil)

			// Get certificate again - should reload automatically
			updatedCert, err := getCert(nil)
			So(err, ShouldBeNil)
			So(leafCommonName(updatedCert), ShouldEqual, "Updated Cert")
		})

		Convey("NewCertReloader should fail with invalid paths", func() {
			_, err := api.NewCertReloader("/nonexistent/cert.pem", "/nonexistent/key.pem", log.NewTestLogger())
			So(err, ShouldNotBeNil)
		})

		Convey("GetCertificateFunc should handle missing files gracefully", func() {
			reloader, err := api.NewCertReloader(certPath, keyPath, log.NewTestLogger())
			So(err, ShouldBeNil)
			defer reloader.Close()

			getCert := reloader.GetCertificateFunc()

			// Delete the certificate file
			err = os.Remove(certPath)
			So(err, ShouldBeNil)

			// Should keep returning the previously loaded certificate (not fail)
			cert, err := getCert(nil)
			So(err, ShouldBeNil)
			So(leafCommonName(cert), ShouldEqual, "Initial Cert")
		})

		Convey("GetCertificateFunc should handle only cert file modification", func() {
			reloader, err := api.NewCertReloader(certPath, keyPath, log.NewTestLogger())
			So(err, ShouldBeNil)
			defer reloader.Close()

			getCert := reloader.GetCertificateFunc()
			initialCert, err := getCert(nil)
			So(err, ShouldBeNil)
			So(leafCommonName(initialCert), ShouldEqual, "Initial Cert")

			// Bump only the cert file's mtime; the key file is left untouched, so
			// the pair on disk stays valid. Regenerating the certificate is not an
			// option here: GenerateServerCertToFile also rewrites the key file, and
			// restoring the old key afterwards leaves a mismatched cert/key pair.
			// An explicit future timestamp makes the change visible regardless of
			// the filesystem's mtime resolution, so no sleep is needed.
			future := time.Now().Add(time.Hour)
			err = os.Chtimes(certPath, future, future)
			So(err, ShouldBeNil)

			// Get certificate again - any reload must succeed with the same pair
			updatedCert, err := getCert(nil)
			So(err, ShouldBeNil)
			So(leafCommonName(updatedCert), ShouldEqual, "Initial Cert")
		})

		Convey("GetCertificateFunc should handle concurrent access", func() {
			reloader, err := api.NewCertReloader(certPath, keyPath, log.NewTestLogger())
			So(err, ShouldBeNil)
			defer reloader.Close()

			getCert := reloader.GetCertificateFunc()

			const (
				workers    = 10
				iterations = 100
			)

			// result carries each worker's outcome back to the test goroutine,
			// where So may be called. gotCert is tracked separately because a nil
			// certificate returned with a nil error is also a failure, and a bare
			// error value cannot represent it.
			type result struct {
				err     error
				gotCert bool
			}

			// Buffered to the worker count so workers never block, even if an
			// assertion below halts the subtest early.
			results := make(chan result, workers)
			for range workers {
				go func() {
					for range iterations {
						cert, err := getCert(nil)
						if err != nil || cert == nil {
							results <- result{err: err, gotCert: cert != nil}

							return
						}
					}
					results <- result{gotCert: true}
				}()
			}

			// Wait for all goroutines to complete and check for failures
			for range workers {
				res := <-results
				So(res.err, ShouldBeNil)
				So(res.gotCert, ShouldBeTrue)
			}
		})

		Convey("GetCertificateFunc should not reload if files haven't changed", func() {
			reloader, err := api.NewCertReloader(certPath, keyPath, log.NewTestLogger())
			So(err, ShouldBeNil)
			defer reloader.Close()

			getCert := reloader.GetCertificateFunc()

			// Get certificate multiple times
			cert1, err := getCert(nil)
			So(err, ShouldBeNil)
			So(cert1, ShouldNotBeNil)

			cert2, err := getCert(nil)
			So(err, ShouldBeNil)
			So(cert2, ShouldNotBeNil)

			cert3, err := getCert(nil)
			So(err, ShouldBeNil)
			So(cert3, ShouldNotBeNil)

			// All calls must return the same cached instance. ShouldPointTo checks
			// pointer identity; ShouldEqual may fall back to deep equality, which
			// would also accept a freshly reloaded copy of identical files.
			So(cert1, ShouldPointTo, cert2)
			So(cert2, ShouldPointTo, cert3)
		})

		Convey("GetCertificateFunc should reload when key file changes", func() {
			reloader, err := api.NewCertReloader(certPath, keyPath, log.NewTestLogger())
			So(err, ShouldBeNil)
			defer reloader.Close()

			getCert := reloader.GetCertificateFunc()
			initialCert, err := getCert(nil)
			So(err, ShouldBeNil)
			So(leafCommonName(initialCert), ShouldEqual, "Initial Cert")

			// Wait to ensure modification time will be different
			time.Sleep(2 * time.Second)

			// Generate completely new cert and key (both files are rewritten)
			newServerOpts := &tlsutils.CertificateOptions{
				Hostname:   "127.0.0.1",
				CommonName: "New Key Cert",
				NotAfter:   time.Now().AddDate(1, 0, 0),
			}
			err = tlsutils.GenerateServerCertToFile(caCertPEM, caKeyPEM, certPath, keyPath, newServerOpts)
			So(err, ShouldBeNil)

			// Get certificate again - should reload due to key change
			updatedCert, err := getCert(nil)
			So(err, ShouldBeNil)
			So(leafCommonName(updatedCert), ShouldEqual, "New Key Cert")
		})
	})
}