mirror of
https://github.com/project-zot/zot.git
synced 2026-06-18 13:37:57 +08:00
feat(tls): implement dynamic TLS certificate reloading with file watching (#3792)
Fixes issue #3747 Signed-off-by: Ramkumar Chinchani <rchincha.dev@gmail.com>
This commit is contained in:
committed by
GitHub
parent
2c110d2c20
commit
47659c11b2
@@ -204,4 +204,6 @@ var (
|
||||
ErrOIDCEmptyValidationMsg = errors.New("validation error message is empty")
|
||||
ErrOIDCValidationFailed = errors.New("OIDC claim validation failed")
|
||||
ErrOIDCAudienceMismatch = errors.New("token audience does not match any of the expected audiences")
|
||||
ErrCertificateNotLoaded = errors.New("tls certificate not yet loaded")
|
||||
ErrCertificateWatcherAlreadyRunning = errors.New("certificate watcher is already running")
|
||||
)
|
||||
|
||||
+44
-1
@@ -12,6 +12,7 @@ import (
|
||||
"os"
|
||||
"strconv"
|
||||
"strings"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
"github.com/gorilla/mux"
|
||||
@@ -58,6 +59,8 @@ type Controller struct {
|
||||
Healthz *common.Healthz
|
||||
// runtime params
|
||||
chosenPort int // kernel-chosen port
|
||||
// TLS certificate management
|
||||
TlsWatcher atomic.Pointer[TlsConfigWatcher]
|
||||
}
|
||||
|
||||
func NewController(appConfig *config.Config) *Controller {
|
||||
@@ -242,6 +245,12 @@ func (c *Controller) Run() error {
|
||||
MinVersion: tls.VersionTLS12,
|
||||
}
|
||||
|
||||
// Load CA certificate for mTLS client verification
|
||||
// Note: CA certificate is loaded statically. Unlike the server certificate which is reloaded
|
||||
// dynamically through the file watcher, CA certificate changes require a server restart.
|
||||
// If dynamic CA certificate rotation is needed in the future (e.g., for rotating mTLS client certificates),
|
||||
// the TlsConfigWatcher can be extended to also monitor and reload the CA certificate and update
|
||||
// server.TLSConfig.ClientCAs accordingly.
|
||||
if tlsConfig.CACert != "" {
|
||||
caCert, err := os.ReadFile(tlsConfig.CACert)
|
||||
if err != nil {
|
||||
@@ -264,9 +273,37 @@ func (c *Controller) Run() error {
|
||||
server.TLSConfig.ClientCAs = caCertPool
|
||||
}
|
||||
|
||||
// Store TLS config paths in watcher for dynamic reloading
|
||||
tlsCertPath := tlsConfig.Cert
|
||||
tlsKeyPath := tlsConfig.Key
|
||||
|
||||
// Create and start certificate watcher
|
||||
// Note: The watcher is stored in c.TlsWatcher before calling Start() so it can be properly
|
||||
// cleaned up during Shutdown even if Start fails. The Stop() method gracefully handles
|
||||
// the case where Start() was never called or failed by checking if the done channel is nil.
|
||||
watcher := NewTlsConfigWatcher(tlsCertPath, tlsKeyPath, c.Log)
|
||||
c.TlsWatcher.Store(watcher)
|
||||
defer watcher.Stop()
|
||||
|
||||
// Load initial certificate
|
||||
if err := watcher.ReloadCertificate(); err != nil {
|
||||
c.Log.Error().Err(err).Msg("failed to load initial certificate")
|
||||
|
||||
return err
|
||||
}
|
||||
|
||||
// Start file watching for certificate reloading
|
||||
if err := watcher.Start(); err != nil {
|
||||
c.Log.Warn().Err(err).Msg("failed to start certificate watcher, will use fallback polling")
|
||||
}
|
||||
|
||||
// Set GetCertificate callback for dynamic certificate reloading
|
||||
server.TLSConfig.GetCertificate = watcher.GetCertificate
|
||||
|
||||
c.Healthz.Ready()
|
||||
|
||||
return server.ServeTLS(listener, tlsConfig.Cert, tlsConfig.Key)
|
||||
// Pass empty strings to ServeTLS since GetCertificate handles certificate loading
|
||||
return server.ServeTLS(listener, "", "")
|
||||
}
|
||||
|
||||
c.Healthz.Ready()
|
||||
@@ -455,6 +492,12 @@ func (c *Controller) LoadNewConfig(newConfig *config.Config) {
|
||||
}
|
||||
|
||||
func (c *Controller) Shutdown() {
|
||||
// Stop certificate watcher if it's running
|
||||
watcher := c.TlsWatcher.Load()
|
||||
if watcher != nil {
|
||||
watcher.Stop()
|
||||
}
|
||||
|
||||
// stop all background tasks
|
||||
c.StopBackgroundTasks()
|
||||
|
||||
|
||||
@@ -0,0 +1,405 @@
|
||||
//go:build sync && scrub && metrics && search && lint && userprefs && mgmt && imagetrust && ui
|
||||
|
||||
package api
|
||||
|
||||
import (
|
||||
goerrors "errors"
|
||||
"os"
|
||||
"path"
|
||||
"sync"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"zotregistry.dev/zot/v2/pkg/log"
|
||||
tlsutils "zotregistry.dev/zot/v2/pkg/test/tls"
|
||||
)
|
||||
|
||||
var errGetCertificateFailed = goerrors.New("GetCertificate failed")
|
||||
|
||||
func TestReloadCertificateStatFailureKeepsModTimes(t *testing.T) {
|
||||
logger := log.NewLogger("debug", "")
|
||||
tempDir := t.TempDir()
|
||||
|
||||
certPath := path.Join(tempDir, "cert.pem")
|
||||
keyPath := path.Join(tempDir, "key.pem")
|
||||
|
||||
caOpts := &tlsutils.CertificateOptions{
|
||||
CommonName: "*",
|
||||
NotAfter: time.Now().AddDate(10, 0, 0),
|
||||
KeyType: tlsutils.KeyTypeECDSA,
|
||||
}
|
||||
caCertPEM, caKeyPEM, err := tlsutils.GenerateCACert(caOpts)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to generate CA cert: %v", err)
|
||||
}
|
||||
|
||||
serverOpts := &tlsutils.CertificateOptions{
|
||||
Hostname: "127.0.0.1",
|
||||
CommonName: "*",
|
||||
OrganizationalUnit: "TestServer",
|
||||
NotAfter: time.Now().AddDate(10, 0, 0),
|
||||
KeyType: tlsutils.KeyTypeECDSA,
|
||||
}
|
||||
if err := tlsutils.GenerateServerCertToFile(caCertPEM, caKeyPEM, certPath, keyPath, serverOpts); err != nil {
|
||||
t.Fatalf("failed to generate server cert: %v", err)
|
||||
}
|
||||
|
||||
watcher := NewTlsConfigWatcher(certPath, keyPath, logger)
|
||||
|
||||
if err := watcher.ReloadCertificate(); err != nil {
|
||||
t.Fatalf("failed to load initial certificate: %v", err)
|
||||
}
|
||||
|
||||
watcher.mu.RLock()
|
||||
initialCertModTime := watcher.tlsCertModTime
|
||||
initialKeyModTime := watcher.tlsKeyModTime
|
||||
watcher.mu.RUnlock()
|
||||
|
||||
if initialCertModTime.IsZero() || initialKeyModTime.IsZero() {
|
||||
t.Fatal("expected initial mod times to be set")
|
||||
}
|
||||
|
||||
originalStat := tlsFileStat
|
||||
tlsFileStat = func(string) (os.FileInfo, error) {
|
||||
return nil, os.ErrNotExist
|
||||
}
|
||||
t.Cleanup(func() {
|
||||
tlsFileStat = originalStat
|
||||
})
|
||||
|
||||
time.Sleep(10 * time.Millisecond)
|
||||
if err := watcher.ReloadCertificate(); err != nil {
|
||||
t.Fatalf("unexpected reload error when stat fails: %v", err)
|
||||
}
|
||||
|
||||
watcher.mu.RLock()
|
||||
updatedCertModTime := watcher.tlsCertModTime
|
||||
updatedKeyModTime := watcher.tlsKeyModTime
|
||||
watcher.mu.RUnlock()
|
||||
|
||||
if !updatedCertModTime.Equal(initialCertModTime) {
|
||||
t.Fatal("certificate mod time changed despite stat failure")
|
||||
}
|
||||
if !updatedKeyModTime.Equal(initialKeyModTime) {
|
||||
t.Fatal("key mod time changed despite stat failure")
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetCertificateReloadConcurrency(t *testing.T) {
|
||||
logger := log.NewLogger("debug", "")
|
||||
tempDir := t.TempDir()
|
||||
|
||||
certPath := path.Join(tempDir, "cert.pem")
|
||||
keyPath := path.Join(tempDir, "key.pem")
|
||||
|
||||
caOpts := &tlsutils.CertificateOptions{
|
||||
CommonName: "*",
|
||||
NotAfter: time.Now().AddDate(10, 0, 0),
|
||||
KeyType: tlsutils.KeyTypeECDSA,
|
||||
}
|
||||
caCertPEM, caKeyPEM, err := tlsutils.GenerateCACert(caOpts)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to generate CA cert: %v", err)
|
||||
}
|
||||
|
||||
serverOpts := &tlsutils.CertificateOptions{
|
||||
Hostname: "127.0.0.1",
|
||||
CommonName: "*",
|
||||
OrganizationalUnit: "TestServer",
|
||||
NotAfter: time.Now().AddDate(10, 0, 0),
|
||||
KeyType: tlsutils.KeyTypeECDSA,
|
||||
}
|
||||
if err := tlsutils.GenerateServerCertToFile(caCertPEM, caKeyPEM, certPath, keyPath, serverOpts); err != nil {
|
||||
t.Fatalf("failed to generate server cert: %v", err)
|
||||
}
|
||||
|
||||
watcher := NewTlsConfigWatcher(certPath, keyPath, logger)
|
||||
|
||||
if err := watcher.ReloadCertificate(); err != nil {
|
||||
t.Fatalf("failed to load initial certificate: %v", err)
|
||||
}
|
||||
|
||||
var wg sync.WaitGroup
|
||||
errorCh := make(chan error, 32)
|
||||
|
||||
reloadWorker := func(iterations int) {
|
||||
defer wg.Done()
|
||||
for i := 0; i < iterations; i++ {
|
||||
if err := watcher.ReloadCertificate(); err != nil {
|
||||
errorCh <- err
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
getWorker := func(iterations int) {
|
||||
defer wg.Done()
|
||||
for i := 0; i < iterations; i++ {
|
||||
cert, err := watcher.GetCertificate(nil)
|
||||
if err != nil || cert == nil {
|
||||
errorCh <- err
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
wg.Add(3)
|
||||
go reloadWorker(50)
|
||||
go getWorker(100)
|
||||
go getWorker(100)
|
||||
|
||||
wg.Wait()
|
||||
close(errorCh)
|
||||
|
||||
for err := range errorCh {
|
||||
if err != nil {
|
||||
t.Fatalf("concurrent TLS operations failed: %v", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetCertificateInitialLoadConcurrency(t *testing.T) {
|
||||
logger := log.NewLogger("debug", "")
|
||||
tempDir := t.TempDir()
|
||||
|
||||
certPath := path.Join(tempDir, "cert.pem")
|
||||
keyPath := path.Join(tempDir, "key.pem")
|
||||
|
||||
caOpts := &tlsutils.CertificateOptions{
|
||||
CommonName: "*",
|
||||
NotAfter: time.Now().AddDate(10, 0, 0),
|
||||
KeyType: tlsutils.KeyTypeECDSA,
|
||||
}
|
||||
caCertPEM, caKeyPEM, err := tlsutils.GenerateCACert(caOpts)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to generate CA cert: %v", err)
|
||||
}
|
||||
|
||||
serverOpts := &tlsutils.CertificateOptions{
|
||||
Hostname: "127.0.0.1",
|
||||
CommonName: "*",
|
||||
OrganizationalUnit: "TestServer",
|
||||
NotAfter: time.Now().AddDate(10, 0, 0),
|
||||
KeyType: tlsutils.KeyTypeECDSA,
|
||||
}
|
||||
if err := tlsutils.GenerateServerCertToFile(caCertPEM, caKeyPEM, certPath, keyPath, serverOpts); err != nil {
|
||||
t.Fatalf("failed to generate server cert: %v", err)
|
||||
}
|
||||
|
||||
watcher := NewTlsConfigWatcher(certPath, keyPath, logger)
|
||||
|
||||
var wg sync.WaitGroup
|
||||
results := make(chan error, 16)
|
||||
|
||||
worker := func(iterations int) {
|
||||
defer wg.Done()
|
||||
for i := 0; i < iterations; i++ {
|
||||
cert, err := watcher.GetCertificate(nil)
|
||||
if err == nil && cert == nil {
|
||||
results <- errGetCertificateFailed
|
||||
return
|
||||
}
|
||||
if err != nil {
|
||||
results <- err
|
||||
return
|
||||
}
|
||||
}
|
||||
results <- nil
|
||||
}
|
||||
|
||||
for range 8 {
|
||||
wg.Add(1)
|
||||
go worker(10)
|
||||
}
|
||||
|
||||
wg.Wait()
|
||||
close(results)
|
||||
|
||||
for err := range results {
|
||||
if err != nil {
|
||||
t.Fatalf("concurrent initial TLS load failed: %v", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestStartCertificateWatcherAddFailureFallsBack(t *testing.T) {
|
||||
logger := log.NewLogger("debug", "")
|
||||
tempDir := t.TempDir()
|
||||
|
||||
keyPath := path.Join(tempDir, "key.pem")
|
||||
if err := os.WriteFile(keyPath, []byte("key"), 0o600); err != nil {
|
||||
t.Fatalf("failed to write key file: %v", err)
|
||||
}
|
||||
|
||||
watcher := NewTlsConfigWatcher(path.Join(tempDir, "missing-cert.pem"), keyPath, logger)
|
||||
|
||||
if err := watcher.Start(); err == nil {
|
||||
t.Fatal("expected watcher setup to fail for missing certificate file")
|
||||
}
|
||||
if watcher.UseInotify() {
|
||||
t.Fatal("expected file watching to be disabled after watcher add failure")
|
||||
}
|
||||
}
|
||||
|
||||
func TestCertificateWatcherHandlesAtomicRename(t *testing.T) {
|
||||
logger := log.NewLogger("debug", "")
|
||||
tempDir := t.TempDir()
|
||||
|
||||
certPath := path.Join(tempDir, "cert.pem")
|
||||
keyPath := path.Join(tempDir, "key.pem")
|
||||
|
||||
// Generate initial certificates
|
||||
caOpts := &tlsutils.CertificateOptions{
|
||||
CommonName: "*",
|
||||
NotAfter: time.Now().AddDate(10, 0, 0),
|
||||
KeyType: tlsutils.KeyTypeECDSA,
|
||||
}
|
||||
caCertPEM, caKeyPEM, err := tlsutils.GenerateCACert(caOpts)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to generate CA cert: %v", err)
|
||||
}
|
||||
|
||||
serverOpts := &tlsutils.CertificateOptions{
|
||||
Hostname: "127.0.0.1",
|
||||
CommonName: "*",
|
||||
OrganizationalUnit: "TestServer-Initial",
|
||||
NotAfter: time.Now().AddDate(10, 0, 0),
|
||||
KeyType: tlsutils.KeyTypeECDSA,
|
||||
}
|
||||
if err := tlsutils.GenerateServerCertToFile(caCertPEM, caKeyPEM, certPath, keyPath, serverOpts); err != nil {
|
||||
t.Fatalf("failed to generate initial server cert: %v", err)
|
||||
}
|
||||
|
||||
watcher := NewTlsConfigWatcher(certPath, keyPath, logger)
|
||||
|
||||
if err := watcher.ReloadCertificate(); err != nil {
|
||||
t.Fatalf("failed to load initial certificate: %v", err)
|
||||
}
|
||||
|
||||
if err := watcher.Start(); err != nil {
|
||||
t.Fatalf("failed to start certificate watcher: %v", err)
|
||||
}
|
||||
defer watcher.Stop()
|
||||
|
||||
watcher.mu.RLock()
|
||||
initialModTime := watcher.tlsCertModTime
|
||||
watcher.mu.RUnlock()
|
||||
|
||||
// Simulate atomic certificate replacement via rename (common pattern in k8s, cert-manager, etc.)
|
||||
// Sleep long enough to ensure filesystem records different timestamps (at least 1 second
|
||||
// to account for systems with second-level timestamp precision)
|
||||
time.Sleep(1100 * time.Millisecond)
|
||||
|
||||
newServerOpts := &tlsutils.CertificateOptions{
|
||||
Hostname: "127.0.0.2",
|
||||
CommonName: "*",
|
||||
OrganizationalUnit: "TestServer-Rotated",
|
||||
NotAfter: time.Now().AddDate(10, 0, 0),
|
||||
KeyType: tlsutils.KeyTypeECDSA,
|
||||
}
|
||||
|
||||
// Write to temp files first
|
||||
tempCertPath := certPath + ".tmp"
|
||||
tempKeyPath := keyPath + ".tmp"
|
||||
if err := tlsutils.GenerateServerCertToFile(caCertPEM, caKeyPEM, tempCertPath, tempKeyPath, newServerOpts); err != nil {
|
||||
t.Fatalf("failed to generate temp server cert: %v", err)
|
||||
}
|
||||
|
||||
// Atomic rename (this is how most certificate rotation works)
|
||||
if err := os.Rename(tempCertPath, certPath); err != nil {
|
||||
t.Fatalf("failed to rename cert: %v", err)
|
||||
}
|
||||
if err := os.Rename(tempKeyPath, keyPath); err != nil {
|
||||
t.Fatalf("failed to rename key: %v", err)
|
||||
}
|
||||
|
||||
// Poll until the certificate is reloaded or timeout is reached.
|
||||
// This accounts for fsnotify delivery time plus the 150ms debounce interval,
|
||||
// which can exceed fixed sleeps on slower systems.
|
||||
deadline := time.Now().Add(2 * time.Second)
|
||||
var newModTime time.Time
|
||||
reloaded := false
|
||||
for time.Now().Before(deadline) {
|
||||
watcher.mu.RLock()
|
||||
newModTime = watcher.tlsCertModTime
|
||||
watcher.mu.RUnlock()
|
||||
|
||||
if newModTime.After(initialModTime) {
|
||||
reloaded = true
|
||||
break
|
||||
}
|
||||
time.Sleep(50 * time.Millisecond)
|
||||
}
|
||||
|
||||
if !reloaded {
|
||||
t.Fatalf("expected certificate to be reloaded after atomic rename within 2s, initial: %v, new: %v",
|
||||
initialModTime, newModTime)
|
||||
}
|
||||
}
|
||||
|
||||
func TestCertificateWatcherCanRestart(t *testing.T) {
|
||||
logger := log.NewLogger("debug", "")
|
||||
tempDir := t.TempDir()
|
||||
|
||||
certPath := path.Join(tempDir, "cert.pem")
|
||||
keyPath := path.Join(tempDir, "key.pem")
|
||||
|
||||
caOpts := &tlsutils.CertificateOptions{
|
||||
CommonName: "*",
|
||||
NotAfter: time.Now().AddDate(10, 0, 0),
|
||||
KeyType: tlsutils.KeyTypeECDSA,
|
||||
}
|
||||
caCertPEM, caKeyPEM, err := tlsutils.GenerateCACert(caOpts)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to generate CA cert: %v", err)
|
||||
}
|
||||
|
||||
serverOpts := &tlsutils.CertificateOptions{
|
||||
Hostname: "127.0.0.1",
|
||||
CommonName: "*",
|
||||
OrganizationalUnit: "TestServer",
|
||||
NotAfter: time.Now().AddDate(10, 0, 0),
|
||||
KeyType: tlsutils.KeyTypeECDSA,
|
||||
}
|
||||
if err := tlsutils.GenerateServerCertToFile(caCertPEM, caKeyPEM, certPath, keyPath, serverOpts); err != nil {
|
||||
t.Fatalf("failed to generate server cert: %v", err)
|
||||
}
|
||||
|
||||
watcher := NewTlsConfigWatcher(certPath, keyPath, logger)
|
||||
|
||||
// Load initial certificate
|
||||
if err := watcher.ReloadCertificate(); err != nil {
|
||||
t.Fatalf("failed to load initial certificate: %v", err)
|
||||
}
|
||||
|
||||
// Start the watcher
|
||||
if err := watcher.Start(); err != nil {
|
||||
t.Fatalf("failed to start certificate watcher: %v", err)
|
||||
}
|
||||
|
||||
// Verify UseInotify returns true while running
|
||||
if !watcher.UseInotify() {
|
||||
t.Fatalf("expected UseInotify() to be true after Start()")
|
||||
}
|
||||
|
||||
// Stop the watcher
|
||||
watcher.Stop()
|
||||
time.Sleep(100 * time.Millisecond) // Allow cleanup to complete
|
||||
|
||||
// Verify UseInotify returns false after Stop
|
||||
if watcher.UseInotify() {
|
||||
t.Fatalf("expected UseInotify() to be false after Stop()")
|
||||
}
|
||||
|
||||
// Verify we can call Start() again (should not return error)
|
||||
if err := watcher.Start(); err != nil {
|
||||
t.Fatalf("failed to restart certificate watcher: %v", err)
|
||||
}
|
||||
|
||||
// Verify UseInotify is true again
|
||||
if !watcher.UseInotify() {
|
||||
t.Fatalf("expected UseInotify() to be true after restart")
|
||||
}
|
||||
|
||||
watcher.Stop()
|
||||
}
|
||||
@@ -13607,3 +13607,163 @@ func readTagsFromStorage(rootDir, repoName string, digest godigest.Digest) ([]st
|
||||
|
||||
return result, nil
|
||||
}
|
||||
|
||||
var errGetCertificateFailed = goerrors.New("GetCertificate failed")
|
||||
|
||||
func TestDynamicTLSCertificateReloading(t *testing.T) {
|
||||
Convey("Test dynamic TLS certificate reloading", t, func() {
|
||||
logger := log.NewLogger("debug", "")
|
||||
tempDir := t.TempDir()
|
||||
|
||||
// Generate initial certificate and key
|
||||
certPath := path.Join(tempDir, "cert.pem")
|
||||
keyPath := path.Join(tempDir, "key.pem")
|
||||
|
||||
caOpts := &tlsutils.CertificateOptions{
|
||||
CommonName: "*",
|
||||
NotAfter: time.Now().AddDate(10, 0, 0),
|
||||
KeyType: tlsutils.KeyTypeECDSA,
|
||||
}
|
||||
caCertPEM, caKeyPEM, err := tlsutils.GenerateCACert(caOpts)
|
||||
So(err, ShouldBeNil)
|
||||
|
||||
serverOpts := &tlsutils.CertificateOptions{
|
||||
Hostname: "127.0.0.1",
|
||||
CommonName: "*",
|
||||
OrganizationalUnit: "TestServer",
|
||||
NotAfter: time.Now().AddDate(10, 0, 0),
|
||||
KeyType: tlsutils.KeyTypeECDSA,
|
||||
}
|
||||
err = tlsutils.GenerateServerCertToFile(caCertPEM, caKeyPEM, certPath, keyPath, serverOpts)
|
||||
So(err, ShouldBeNil)
|
||||
|
||||
// Create a watcher and set certificate paths
|
||||
watcher := api.NewTlsConfigWatcher(certPath, keyPath, logger)
|
||||
|
||||
// Test 1: Load initial certificate
|
||||
Convey("Load initial certificate successfully", func() {
|
||||
err := watcher.ReloadCertificate()
|
||||
So(err, ShouldBeNil)
|
||||
|
||||
cert, err := watcher.GetCertificate(nil)
|
||||
So(err, ShouldBeNil)
|
||||
So(cert, ShouldNotBeNil)
|
||||
})
|
||||
|
||||
// Test 2: GetCertificate returns the loaded certificate
|
||||
Convey("GetCertificate returns the loaded certificate", func() {
|
||||
err := watcher.ReloadCertificate()
|
||||
So(err, ShouldBeNil)
|
||||
|
||||
cert, err := watcher.GetCertificate(nil)
|
||||
So(err, ShouldBeNil)
|
||||
So(cert, ShouldNotBeNil)
|
||||
|
||||
certAgain, err := watcher.GetCertificate(nil)
|
||||
So(err, ShouldBeNil)
|
||||
So(certAgain, ShouldEqual, cert)
|
||||
})
|
||||
|
||||
// Test 3: Certificate change detection via stat-based fallback
|
||||
Convey("Detect certificate change and reload via stat fallback", func() {
|
||||
// Load initial certificate
|
||||
err := watcher.ReloadCertificate()
|
||||
So(err, ShouldBeNil)
|
||||
|
||||
oldCert, err := watcher.GetCertificate(nil)
|
||||
So(err, ShouldBeNil)
|
||||
So(oldCert, ShouldNotBeNil)
|
||||
|
||||
oldLeaf, err := x509.ParseCertificate(oldCert.Certificate[0])
|
||||
So(err, ShouldBeNil)
|
||||
So(oldLeaf, ShouldNotBeNil)
|
||||
|
||||
// Wait long enough to ensure different modification time on coarse timestamp filesystems
|
||||
time.Sleep(1100 * time.Millisecond)
|
||||
|
||||
// Generate a new certificate
|
||||
newServerOpts := &tlsutils.CertificateOptions{
|
||||
Hostname: "127.0.0.2",
|
||||
CommonName: "*",
|
||||
OrganizationalUnit: "TestServer-New",
|
||||
NotAfter: time.Now().AddDate(10, 0, 0),
|
||||
KeyType: tlsutils.KeyTypeECDSA,
|
||||
}
|
||||
err = tlsutils.GenerateServerCertToFile(caCertPEM, caKeyPEM, certPath, keyPath, newServerOpts)
|
||||
So(err, ShouldBeNil)
|
||||
|
||||
// Call GetCertificate which should trigger reload via stat-based checking
|
||||
cert, err := watcher.GetCertificate(nil)
|
||||
So(err, ShouldBeNil)
|
||||
So(cert, ShouldNotBeNil)
|
||||
|
||||
newLeaf, err := x509.ParseCertificate(cert.Certificate[0])
|
||||
So(err, ShouldBeNil)
|
||||
So(newLeaf, ShouldNotBeNil)
|
||||
|
||||
// Verify that the certificate content was updated
|
||||
So(newLeaf.Subject.OrganizationalUnit, ShouldNotResemble, oldLeaf.Subject.OrganizationalUnit)
|
||||
})
|
||||
|
||||
// Test 4: checkCertificateModTime correctly detects changes
|
||||
Convey("checkCertificateModTime correctly detects file modifications", func() {
|
||||
// Load initial certificate
|
||||
err := watcher.ReloadCertificate()
|
||||
So(err, ShouldBeNil)
|
||||
|
||||
// No changes yet, should return false
|
||||
needsReload := watcher.CheckCertificateModTime()
|
||||
So(needsReload, ShouldBeFalse)
|
||||
|
||||
// Wait and modify certificate file
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
newServerOpts := &tlsutils.CertificateOptions{
|
||||
Hostname: "test.example.com",
|
||||
CommonName: "*",
|
||||
OrganizationalUnit: "TestServer-Changed",
|
||||
NotAfter: time.Now().AddDate(10, 0, 0),
|
||||
KeyType: tlsutils.KeyTypeECDSA,
|
||||
}
|
||||
err = tlsutils.GenerateServerCertToFile(caCertPEM, caKeyPEM, certPath, keyPath, newServerOpts)
|
||||
So(err, ShouldBeNil)
|
||||
|
||||
// Now should detect the change
|
||||
needsReload = watcher.CheckCertificateModTime()
|
||||
So(needsReload, ShouldBeTrue)
|
||||
})
|
||||
|
||||
// Test 5: LoadX509KeyPair error handling
|
||||
Convey("Handle certificate loading errors gracefully", func() {
|
||||
badWatcher := api.NewTlsConfigWatcher("/nonexistent/cert.pem", "/nonexistent/key.pem", logger)
|
||||
|
||||
err := badWatcher.ReloadCertificate()
|
||||
So(err, ShouldNotBeNil)
|
||||
})
|
||||
|
||||
// Test 6: Concurrent GetCertificate calls with certificate reloading
|
||||
Convey("Handle concurrent GetCertificate calls safely", func() {
|
||||
err := watcher.ReloadCertificate()
|
||||
So(err, ShouldBeNil)
|
||||
|
||||
// Simulate concurrent calls
|
||||
done := make(chan error, 5)
|
||||
|
||||
for range 5 {
|
||||
go func() {
|
||||
cert, err := watcher.GetCertificate(nil)
|
||||
if err != nil || cert == nil {
|
||||
done <- errGetCertificateFailed
|
||||
} else {
|
||||
done <- nil
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
// Wait for all goroutines
|
||||
for range 5 {
|
||||
err := <-done
|
||||
So(err, ShouldBeNil)
|
||||
}
|
||||
})
|
||||
})
|
||||
}
|
||||
|
||||
+7
-3
@@ -588,7 +588,7 @@ func getReferrers(ctx context.Context, routeHandler *RouteHandler,
|
||||
imgStore storageTypes.ImageStore, name string, digest godigest.Digest,
|
||||
artifactTypes []string,
|
||||
) (ispec.Index, error) {
|
||||
if isSyncOnDemandEnabled(*routeHandler.c) {
|
||||
if isSyncOnDemandEnabled(routeHandler.c) {
|
||||
routeHandler.c.Log.Info().Str("repository", name).Str("reference", digest.String()).
|
||||
Msg("trying to get updated referrers by syncing on demand")
|
||||
|
||||
@@ -2167,7 +2167,7 @@ func (rh *RouteHandler) getImageStore(name string) storageTypes.ImageStore {
|
||||
func getImageManifest(ctx context.Context, routeHandler *RouteHandler, imgStore storageTypes.ImageStore, name,
|
||||
reference string,
|
||||
) ([]byte, godigest.Digest, string, error) {
|
||||
syncEnabled := isSyncOnDemandEnabled(*routeHandler.c)
|
||||
syncEnabled := isSyncOnDemandEnabled(routeHandler.c)
|
||||
|
||||
_, digestErr := godigest.Parse(reference)
|
||||
if digestErr == nil {
|
||||
@@ -2400,7 +2400,11 @@ func getBlobUploadLocation(url *url.URL, name string, digest godigest.Digest) st
|
||||
return url.String()
|
||||
}
|
||||
|
||||
func isSyncOnDemandEnabled(ctlr Controller) bool {
|
||||
func isSyncOnDemandEnabled(ctlr *Controller) bool {
|
||||
if ctlr == nil {
|
||||
return false
|
||||
}
|
||||
|
||||
extensionsConfig := ctlr.Config.CopyExtensionsConfig()
|
||||
if extensionsConfig.IsSyncEnabled() &&
|
||||
fmt.Sprintf("%v", ctlr.SyncOnDemand) != fmt.Sprintf("%v", nil) {
|
||||
|
||||
@@ -0,0 +1,472 @@
|
||||
package api
|
||||
|
||||
import (
|
||||
"crypto/tls"
|
||||
"fmt"
|
||||
"os"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
"github.com/fsnotify/fsnotify"
|
||||
|
||||
"zotregistry.dev/zot/v2/errors"
|
||||
"zotregistry.dev/zot/v2/pkg/log"
|
||||
)
|
||||
|
||||
// Timing parameters for TLS certificate change detection.
const (
	// tlsCertificateEventDebounceInterval is how long the watcher waits after a file change event
	// before reloading, so that a burst of events (for example the cert and the key both being
	// modified) is coalesced into one reload instead of several redundant ones.
	tlsCertificateEventDebounceInterval = 150 * time.Millisecond
	// tlsCertificateStatCheckInterval is the rate limit for stat-based certificate change detection.
	tlsCertificateStatCheckInterval = time.Second
)

// tlsFileStat is the stat function used for the certificate files. It is a variable rather than a
// direct os.Stat call so that tests can substitute a failing implementation.
var tlsFileStat = os.Stat //nolint:gochecknoglobals // test hook for os.Stat
|
||||
|
||||
// TlsConfigWatcher watches TLS cert/key files and reloads certificates on change.
|
||||
type TlsConfigWatcher struct {
|
||||
mu sync.RWMutex
|
||||
watcher *fsnotify.Watcher
|
||||
done chan struct{}
|
||||
stopOnce sync.Once
|
||||
useInotify bool
|
||||
certPath string
|
||||
keyPath string
|
||||
log log.Logger
|
||||
debounceTimer *time.Timer
|
||||
debounceMutex sync.Mutex
|
||||
// Certificate management fields
|
||||
tlsCert *tls.Certificate
|
||||
tlsCertModTime time.Time
|
||||
tlsKeyModTime time.Time
|
||||
tlsCertReloadInProgress atomic.Bool
|
||||
tlsCertLastStatCheckTime atomic.Int64 // Unix timestamp in nanoseconds for last stat check
|
||||
}
|
||||
|
||||
// NewTlsConfigWatcher creates a TLS config watcher for the given cert and key paths.
|
||||
func NewTlsConfigWatcher(certPath, keyPath string, logger log.Logger) *TlsConfigWatcher {
|
||||
return &TlsConfigWatcher{
|
||||
certPath: certPath,
|
||||
keyPath: keyPath,
|
||||
log: logger,
|
||||
}
|
||||
}
|
||||
|
||||
// Start begins watching the certificate and key files for changes.
|
||||
// Returns an error if the watcher is already running. Safe to call multiple times if Stop is called
|
||||
// between attempts.
|
||||
func (w *TlsConfigWatcher) Start() error {
|
||||
// Check if watcher is already running to prevent duplicate goroutines and resource leaks
|
||||
w.mu.RLock()
|
||||
if w.done != nil {
|
||||
w.mu.RUnlock()
|
||||
|
||||
return errors.ErrCertificateWatcherAlreadyRunning
|
||||
}
|
||||
w.mu.RUnlock()
|
||||
|
||||
watcher, err := fsnotify.NewWatcher()
|
||||
if err != nil {
|
||||
w.log.Error().Err(err).Msg("failed to create fsnotify watcher")
|
||||
w.disableUseInotify()
|
||||
|
||||
return err
|
||||
}
|
||||
|
||||
w.log.Debug().Str("cert", w.certPath).Str("key", w.keyPath).
|
||||
Msg("starting TLS certificate watcher")
|
||||
|
||||
if err := watcher.Add(w.certPath); err != nil {
|
||||
w.log.Warn().Err(err).Str("cert", w.certPath).Msg("failed to watch certificate file")
|
||||
_ = watcher.Close()
|
||||
w.disableUseInotify()
|
||||
|
||||
return err
|
||||
}
|
||||
|
||||
if err := watcher.Add(w.keyPath); err != nil {
|
||||
w.log.Warn().Err(err).Str("key", w.keyPath).Msg("failed to watch key file")
|
||||
_ = watcher.Close()
|
||||
w.disableUseInotify()
|
||||
|
||||
return err
|
||||
}
|
||||
|
||||
w.mu.Lock()
|
||||
w.watcher = watcher
|
||||
w.useInotify = true
|
||||
w.done = make(chan struct{})
|
||||
w.stopOnce = sync.Once{}
|
||||
w.mu.Unlock()
|
||||
|
||||
go w.loop()
|
||||
|
||||
w.log.Info().Msg("TLS certificate watcher started using fsnotify")
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Stop signals the watcher to stop and returns once the signal is sent.
|
||||
// Safe to call even if Start() was never called or failed - it will return early if the watcher
|
||||
// goroutine is not running.
|
||||
func (w *TlsConfigWatcher) Stop() {
|
||||
w.mu.RLock()
|
||||
done := w.done
|
||||
w.mu.RUnlock()
|
||||
|
||||
if done == nil {
|
||||
w.log.Debug().Msg("TLS certificate watcher stop requested with no active watcher")
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
w.stopOnce.Do(func() {
|
||||
// Clean up any pending debounce timer
|
||||
w.debounceMutex.Lock()
|
||||
if w.debounceTimer != nil {
|
||||
w.debounceTimer.Stop()
|
||||
w.debounceTimer = nil
|
||||
}
|
||||
w.debounceMutex.Unlock()
|
||||
|
||||
// Atomically capture and reset watcher state
|
||||
w.mu.Lock()
|
||||
capturedWatcher := w.watcher
|
||||
capturedDone := w.done
|
||||
w.done = nil
|
||||
w.watcher = nil
|
||||
w.useInotify = false
|
||||
w.mu.Unlock()
|
||||
|
||||
// Close fsnotify watcher to terminate the goroutine promptly
|
||||
if capturedWatcher != nil {
|
||||
_ = capturedWatcher.Close()
|
||||
}
|
||||
|
||||
// Signal the goroutine to exit via the done channel
|
||||
if capturedDone != nil {
|
||||
close(capturedDone)
|
||||
}
|
||||
|
||||
w.log.Debug().Msg("TLS certificate watcher stopped and state cleared")
|
||||
})
|
||||
}
|
||||
|
||||
// UseInotify reports whether file watching is active.
|
||||
func (w *TlsConfigWatcher) UseInotify() bool {
|
||||
w.mu.RLock()
|
||||
useInotify := w.useInotify
|
||||
w.mu.RUnlock()
|
||||
|
||||
return useInotify
|
||||
}
|
||||
|
||||
func (w *TlsConfigWatcher) disableUseInotify() {
|
||||
w.mu.Lock()
|
||||
w.useInotify = false
|
||||
w.mu.Unlock()
|
||||
}
|
||||
|
||||
func (w *TlsConfigWatcher) loop() {
|
||||
// Clear watcher state when loop exits so Start() can be called again.
|
||||
// This serves as a safety net for cases where the goroutine exits naturally
|
||||
// (e.g., channels closed unexpectedly) without Stop() being called.
|
||||
defer func() {
|
||||
w.mu.Lock()
|
||||
// Close done channel if not already closed by Stop()
|
||||
if w.done != nil {
|
||||
close(w.done)
|
||||
w.done = nil
|
||||
}
|
||||
if w.watcher != nil {
|
||||
_ = w.watcher.Close()
|
||||
w.watcher = nil
|
||||
}
|
||||
w.useInotify = false
|
||||
// Note: w.stopOnce is reset by Start() when a new watcher is initialized,
|
||||
// so we don't reset it here to avoid races with Stop() executing its callback
|
||||
w.mu.Unlock()
|
||||
w.log.Debug().Msg("TLS certificate watcher loop cleanup completed")
|
||||
}()
|
||||
|
||||
w.mu.RLock()
|
||||
watcher := w.watcher
|
||||
done := w.done
|
||||
w.mu.RUnlock()
|
||||
|
||||
if watcher == nil {
|
||||
w.log.Debug().Msg("TLS certificate watcher loop exited: watcher not initialized")
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-done:
|
||||
w.log.Debug().Msg("TLS certificate watcher loop exited: stop signal received")
|
||||
|
||||
return
|
||||
case event, ok := <-watcher.Events:
|
||||
if !ok {
|
||||
w.log.Debug().Msg("TLS certificate watcher loop exited: events channel closed")
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
if event.Op&(fsnotify.Write|fsnotify.Create|fsnotify.Remove|fsnotify.Rename|fsnotify.Chmod) != 0 {
|
||||
w.log.Debug().Str("file", event.Name).Str("op", event.Op.String()).
|
||||
Msg("certificate file change detected")
|
||||
|
||||
if event.Op&(fsnotify.Remove|fsnotify.Rename) != 0 {
|
||||
if !w.retryAddWatch(event.Name, watcher, done) {
|
||||
w.log.Warn().Str("file", event.Name).
|
||||
Msg("failed to re-add watch after retries, switching to stat-based polling")
|
||||
w.disableUseInotify()
|
||||
}
|
||||
}
|
||||
|
||||
select {
|
||||
case <-done:
|
||||
w.log.Debug().Msg("TLS certificate watcher loop exited: stop signal received before reload")
|
||||
|
||||
return
|
||||
default:
|
||||
}
|
||||
|
||||
// Debounce multiple file events to coalesce cert and key changes into a single reload
|
||||
w.scheduleReload()
|
||||
}
|
||||
case <-w.getDebounceChannel():
|
||||
// Debounce timer expired, perform the reload
|
||||
w.debounceMutex.Lock()
|
||||
w.debounceTimer = nil
|
||||
w.debounceMutex.Unlock()
|
||||
|
||||
select {
|
||||
case <-done:
|
||||
w.log.Debug().Msg("TLS certificate watcher loop exited: stop signal received before debounced reload")
|
||||
|
||||
return
|
||||
default:
|
||||
}
|
||||
|
||||
w.log.Debug().Str("cert", w.certPath).Str("key", w.keyPath).
|
||||
Msg("reloading TLS certificate after debounced file change")
|
||||
if err := w.ReloadCertificate(); err != nil {
|
||||
w.log.Error().Err(err).Msg("failed to reload certificate on file change")
|
||||
}
|
||||
case err, ok := <-watcher.Errors:
|
||||
if !ok {
|
||||
w.log.Debug().Msg("TLS certificate watcher loop exited: errors channel closed")
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
w.log.Error().Err(err).Msg("failed to watch certificate files")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (w *TlsConfigWatcher) getDebounceChannel() <-chan time.Time {
|
||||
w.debounceMutex.Lock()
|
||||
defer w.debounceMutex.Unlock()
|
||||
|
||||
if w.debounceTimer == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
return w.debounceTimer.C
|
||||
}
|
||||
|
||||
func (w *TlsConfigWatcher) scheduleReload() {
|
||||
w.debounceMutex.Lock()
|
||||
defer w.debounceMutex.Unlock()
|
||||
|
||||
// If a reload is already pending, just reset the timer to restart the debounce window
|
||||
if w.debounceTimer != nil {
|
||||
if !w.debounceTimer.Stop() {
|
||||
select {
|
||||
case <-w.debounceTimer.C:
|
||||
default:
|
||||
}
|
||||
}
|
||||
w.debounceTimer.Reset(tlsCertificateEventDebounceInterval)
|
||||
w.log.Debug().Msg("debounce timer reset for additional file change event")
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
// First event after debounce window closed, schedule a new reload.
|
||||
// Use time.NewTimer instead of time.AfterFunc since we only care about the timer's
|
||||
// channel firing in the select statement, not about executing a callback function.
|
||||
w.debounceTimer = time.NewTimer(tlsCertificateEventDebounceInterval)
|
||||
w.log.Debug().Str("interval", tlsCertificateEventDebounceInterval.String()).
|
||||
Msg("debounce timer started for file change events")
|
||||
}
|
||||
|
||||
func (w *TlsConfigWatcher) retryAddWatch(file string, watcher *fsnotify.Watcher, done <-chan struct{}) bool {
|
||||
for attempt := range 5 {
|
||||
select {
|
||||
case <-done:
|
||||
return false
|
||||
case <-time.After(time.Duration(50*(attempt+1)) * time.Millisecond):
|
||||
}
|
||||
|
||||
if err := watcher.Add(file); err == nil {
|
||||
w.log.Debug().Str("file", file).Int("attempt", attempt+1).
|
||||
Msg("re-added watch after file removal/rename")
|
||||
|
||||
return true
|
||||
}
|
||||
|
||||
w.log.Debug().Str("file", file).Int("attempt", attempt+1).
|
||||
Msg("retrying watch add after failure")
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
// GetCertificate is a callback used by tls.Config that dynamically loads TLS certificates.
|
||||
// This allows certificates to be reloaded when they change on disk without restarting the server.
|
||||
// It uses fsnotify for file watching when available or falls back to stat-based checking.
|
||||
func (w *TlsConfigWatcher) GetCertificate(hello *tls.ClientHelloInfo) (*tls.Certificate, error) {
|
||||
var needsReload bool
|
||||
|
||||
var cert *tls.Certificate
|
||||
|
||||
// First check if certificate is not yet loaded
|
||||
w.mu.RLock()
|
||||
cert = w.tlsCert
|
||||
if cert == nil {
|
||||
w.mu.RUnlock()
|
||||
needsReload = true
|
||||
} else {
|
||||
w.mu.RUnlock()
|
||||
useInotify := w.UseInotify()
|
||||
|
||||
// If file watching is not being used, perform stat-based fallback polling.
|
||||
// Rate limit the stat checks to avoid I/O overhead on every TLS handshake.
|
||||
if !useInotify {
|
||||
now := time.Now().UnixNano()
|
||||
lastCheckTime := w.tlsCertLastStatCheckTime.Load()
|
||||
if now-lastCheckTime >= int64(tlsCertificateStatCheckInterval) {
|
||||
needsReload = w.CheckCertificateModTime()
|
||||
// Update the last check time only if we actually performed the check
|
||||
w.tlsCertLastStatCheckTime.Store(now)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Only reload if necessary or if certificate is not yet loaded.
|
||||
// Use atomic flag to ensure only one goroutine performs the reload when multiple
|
||||
// concurrent requests detect a stale certificate, preventing duplicate reload operations.
|
||||
//
|
||||
// If another goroutine is already performing a reload, wait for it to complete before
|
||||
// returning to avoid spurious ErrCertificateNotLoaded errors during concurrent initial loads.
|
||||
if needsReload {
|
||||
// Only proceed with reload if no reload is already in progress
|
||||
if w.tlsCertReloadInProgress.CompareAndSwap(false, true) {
|
||||
defer w.tlsCertReloadInProgress.Store(false)
|
||||
if err := w.ReloadCertificate(); err != nil {
|
||||
return nil, fmt.Errorf("failed to reload certificate: %w", err)
|
||||
}
|
||||
} else {
|
||||
// Another goroutine is reloading - wait for it to complete (up to 5 seconds)
|
||||
// to avoid returning ErrCertificateNotLoaded during concurrent initial loads
|
||||
deadline := time.Now().Add(5 * time.Second)
|
||||
for w.tlsCertReloadInProgress.Load() && time.Now().Before(deadline) {
|
||||
time.Sleep(10 * time.Millisecond)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
w.mu.RLock()
|
||||
cert = w.tlsCert
|
||||
w.mu.RUnlock()
|
||||
|
||||
if cert == nil {
|
||||
return nil, errors.ErrCertificateNotLoaded
|
||||
}
|
||||
|
||||
return cert, nil
|
||||
}
|
||||
|
||||
// ReloadCertificate loads the TLS certificate and key from disk.
|
||||
func (w *TlsConfigWatcher) ReloadCertificate() error {
|
||||
cert, err := tls.LoadX509KeyPair(w.certPath, w.keyPath)
|
||||
if err != nil {
|
||||
w.log.Error().Err(err).Str("cert", w.certPath).Str("key", w.keyPath).
|
||||
Msg("failed to load certificate and key pair")
|
||||
|
||||
return err
|
||||
}
|
||||
|
||||
// Update modification times
|
||||
certInfo, certStatErr := tlsFileStat(w.certPath)
|
||||
if certStatErr != nil {
|
||||
w.log.Warn().Err(certStatErr).Str("cert", w.certPath).
|
||||
Msg("failed to stat certificate file")
|
||||
}
|
||||
keyInfo, keyStatErr := tlsFileStat(w.keyPath)
|
||||
if keyStatErr != nil {
|
||||
w.log.Warn().Err(keyStatErr).Str("key", w.keyPath).
|
||||
Msg("failed to stat key file")
|
||||
}
|
||||
|
||||
// Edge case: If both stat calls fail, we don't update the modification times.
|
||||
// This prevents incorrectly treating a transient stat failure as "no change".
|
||||
// However, this means CheckCertificateModTime could return false positives on the next call
|
||||
// if either stat call fails again. Log a warning if this occurs.
|
||||
if certStatErr != nil && keyStatErr != nil {
|
||||
w.log.Warn().Msg("both cert and key stat failed during reload - mod times not updated, " +
|
||||
"next stat-based check may return false positives")
|
||||
}
|
||||
|
||||
w.mu.Lock()
|
||||
w.tlsCert = &cert
|
||||
if certInfo != nil {
|
||||
w.tlsCertModTime = certInfo.ModTime()
|
||||
}
|
||||
if keyInfo != nil {
|
||||
w.tlsKeyModTime = keyInfo.ModTime()
|
||||
}
|
||||
w.mu.Unlock()
|
||||
|
||||
w.log.Debug().Str("cert", w.certPath).Str("key", w.keyPath).
|
||||
Msg("TLS certificate reloaded")
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// CheckCertificateModTime checks if certificate or key files have been modified since last load.
|
||||
// This is used as a fallback when inotify is not available.
|
||||
func (w *TlsConfigWatcher) CheckCertificateModTime() bool {
|
||||
certInfo, err := tlsFileStat(w.certPath)
|
||||
if err != nil {
|
||||
w.log.Error().Err(err).Str("cert", w.certPath).Msg("failed to stat certificate file")
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
keyInfo, err := tlsFileStat(w.keyPath)
|
||||
if err != nil {
|
||||
w.log.Error().Err(err).Str("key", w.keyPath).Msg("failed to stat key file")
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
w.mu.RLock()
|
||||
certModTime := w.tlsCertModTime
|
||||
keyModTime := w.tlsKeyModTime
|
||||
w.mu.RUnlock()
|
||||
|
||||
// Check if either file has been modified since last load
|
||||
if certInfo.ModTime().After(certModTime) || keyInfo.ModTime().After(keyModTime) {
|
||||
w.log.Debug().Msg("certificate or key file modification detected via stat")
|
||||
|
||||
return true
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
+1
-1
@@ -15,7 +15,7 @@ tests=("pushpull" "pushpull_authn" "delete_images" "referrers" "metadata" "anony
|
||||
"annotations" "detect_manifest_collision" "cve" "sync" "sync_docker" "sync_replica_cluster"
|
||||
"scrub" "garbage_collect" "metrics" "metrics_minimal" "multiarch_index" "docker_compat" "redis_local" "redis_session_store"
|
||||
"events_nats" "events_http" "events_nats_lint_failure" "events_http_lint_failure" "events_sink_failure" "events_config_decoding"
|
||||
"fips140" "fips140_authn" "openid_claim_mapping" "upgrade" "upgrade_minimal")
|
||||
"fips140" "fips140_authn" "openid_claim_mapping" "upgrade" "upgrade_minimal" "dynamic_tls")
|
||||
|
||||
for test in ${tests[*]}; do
|
||||
${BATS} ${BATS_FLAGS} ${SCRIPTPATH}/${test}.bats > ${test}.log & pids+=($!)
|
||||
|
||||
@@ -0,0 +1,275 @@
|
||||
# Note: Intended to be run as "make run-blackbox-tests" or "make run-blackbox-ci"
|
||||
# Makefile target installs & checks all necessary tooling
|
||||
# Extra tools that are not covered in Makefile target needs to be added in verify_prerequisites()
|
||||
|
||||
load helpers_zot
|
||||
load ../port_helper
|
||||
|
||||
# Check that every external tool required by this test file is on PATH.
# Prints a message to fd 3 and returns 1 for the first missing tool,
# returns 0 when all of them are available.
function verify_prerequisites {
    local tool

    for tool in curl jq openssl; do
        if ! command -v "${tool}" >/dev/null 2>&1; then
            echo "you need to install ${tool} as a prerequisite to running the tests" >&3
            return 1
        fi
    done

    return 0
}
|
||||
|
||||
# Create an unencrypted RSA-2048 key and a matching self-signed certificate.
# Usage: generate_self_signed_cert <cert_path> <key_path> [common_name] [days]
#   common_name defaults to "localhost", days (validity) defaults to 365.
function generate_self_signed_cert() {
    local cert_path=${1}
    local key_path=${2}
    local common_name=${3:-"localhost"}
    local days=${4:-365}
    local subject="/C=US/ST=Test/L=Test/O=Zot/CN=${common_name}"

    openssl req -x509 -nodes -newkey rsa:2048 \
        -days ${days} \
        -subj "${subject}" \
        -keyout "${key_path}" \
        -out "${cert_path}"
}
|
||||
|
||||
# Poll a shell command until it succeeds.
# Usage: wait_for_condition <max_attempts> <interval_seconds> "<command>"
# The command string is passed to eval once per attempt; the function sleeps
# <interval_seconds> between attempts (not after the last one).
# Returns 0 as soon as the command succeeds, 1 once all attempts are used up.
function wait_for_condition() {
    local max_attempts=${1}
    local interval=${2}
    local condition_cmd=${3}
    local attempt

    for ((attempt = 1; attempt <= max_attempts; attempt++)); do
        if eval "${condition_cmd}"; then
            echo "Condition met after $attempt attempts" >&3
            return 0
        fi

        # No point in sleeping after the final attempt
        if ((attempt < max_attempts)); then
            sleep "${interval}"
        fi
    done

    echo "Condition timed out after $max_attempts attempts" >&3
    return 1
}
|
||||
|
||||
# One-time setup for this file: fetch the test image, generate an initial
# self-signed certificate, write a TLS-enabled zot config, start zot and wait
# until it answers over HTTPS. The chosen port is saved to
# ${BATS_FILE_TMPDIR}/zot.port for the individual tests.
function setup_file() {
    # Abort early when a required tool is missing
    verify_prerequisites || exit 1

    # Download test data to folder common for the entire suite, not just this file
    skopeo --insecure-policy copy --format=oci docker://ghcr.io/project-zot/test-images/busybox:1.36 oci:${TEST_DATA_DIR}/busybox:1.36

    # Paths used by the TLS-enabled zot instance
    local zot_root_dir=${BATS_FILE_TMPDIR}/zot
    local zot_config_file=${BATS_FILE_TMPDIR}/zot_config.json
    local zot_cert_file=${BATS_FILE_TMPDIR}/server.cert
    local zot_key_file=${BATS_FILE_TMPDIR}/server.key
    local catalog_probe

    zot_port=$(get_free_port_for_service "zot")
    echo ${zot_port} > ${BATS_FILE_TMPDIR}/zot.port

    mkdir -p ${zot_root_dir}

    # Initial certificate/key pair the server starts with
    generate_self_signed_cert "${zot_cert_file}" "${zot_key_file}" "127.0.0.1" 365

    # zot configuration with TLS enabled and debug logging to a file
    cat > ${zot_config_file}<<EOF
{
    "distSpecVersion":"1.1.1",
    "storage":{
        "dedupe": true,
        "gc": true,
        "gcDelay": "1h",
        "gcInterval": "6h",
        "rootDirectory": "${zot_root_dir}"
    },
    "http": {
        "address": "127.0.0.1",
        "port": "${zot_port}",
        "tls": {
            "cert": "${zot_cert_file}",
            "key": "${zot_key_file}"
        }
    },
    "log":{
        "level":"debug",
        "output": "${BATS_FILE_TMPDIR}/zot.log"
    }
}
EOF

    echo ${zot_root_dir} >&3
    zot_serve ${ZOT_PATH} ${zot_config_file}

    # Poll the catalog endpoint until the server accepts HTTPS connections
    catalog_probe="curl -k --max-time 5 --connect-timeout 3 https://127.0.0.1:${zot_port}/v2/_catalog >/dev/null 2>&1"
    wait_for_condition 30 0.2 "${catalog_probe}"
}
|
||||
|
||||
# Per-test teardown: dump the zot log (if it exists) after every test.
function teardown() {
    # conditionally printing on failure is possible from teardown but not from teardown_file
    local zot_log=${BATS_FILE_TMPDIR}/zot.log

    cat ${zot_log} 2>/dev/null || true
}
|
||||
|
||||
# Per-file teardown: runs once after the last test in this file and calls
# zot_stop_all to shut down the zot instance(s) started for it.
function teardown_file() {
    zot_stop_all
}
|
||||
|
||||
@test "TLS connection succeeds with self-signed certificate" {
|
||||
zot_port=$(cat ${BATS_FILE_TMPDIR}/zot.port)
|
||||
|
||||
# Test with curl using insecure flag since we're using self-signed cert
|
||||
run curl -k --max-time 5 --connect-timeout 3 https://127.0.0.1:${zot_port}/v2/_catalog
|
||||
[ "$status" -eq 0 ]
|
||||
}
|
||||
|
||||
@test "push image with TLS enabled" {
|
||||
zot_port=$(cat ${BATS_FILE_TMPDIR}/zot.port)
|
||||
|
||||
# Use skopeo to push image over HTTPS with insecure TLS verification
|
||||
run skopeo --insecure-policy copy --dest-tls-verify=false \
|
||||
oci:${TEST_DATA_DIR}/busybox:1.36 \
|
||||
docker://127.0.0.1:${zot_port}/busybox:1.36
|
||||
[ "$status" -eq 0 ]
|
||||
}
|
||||
|
||||
@test "pull image with TLS enabled" {
|
||||
zot_port=$(cat ${BATS_FILE_TMPDIR}/zot.port)
|
||||
local temp_oci_dir=${BATS_FILE_TMPDIR}/busybox-pulled
|
||||
|
||||
mkdir -p ${temp_oci_dir}
|
||||
|
||||
# Pull the pushed image back
|
||||
run skopeo --insecure-policy copy --src-tls-verify=false \
|
||||
docker://127.0.0.1:${zot_port}/busybox:1.36 \
|
||||
oci:${temp_oci_dir}
|
||||
[ "$status" -eq 0 ]
|
||||
|
||||
# Verify OCI image was downloaded
|
||||
[ -f "${temp_oci_dir}/oci-layout" ]
|
||||
}
|
||||
|
||||
@test "dynamic certificate reload: verify server uses new certificate after update" {
|
||||
zot_port=$(cat ${BATS_FILE_TMPDIR}/zot.port)
|
||||
local zot_cert_file=${BATS_FILE_TMPDIR}/server.cert
|
||||
local zot_key_file=${BATS_FILE_TMPDIR}/server.key
|
||||
|
||||
# Get the certificate fingerprint before update
|
||||
cert_fingerprint_before=$(openssl x509 -fingerprint -sha256 -noout -in "${zot_cert_file}" 2>/dev/null | cut -d'=' -f2)
|
||||
server_fingerprint_before=$(openssl s_client -connect 127.0.0.1:${zot_port} -servername 127.0.0.1 -showcerts </dev/null 2>/dev/null \
|
||||
| openssl x509 -fingerprint -sha256 -noout 2>/dev/null | cut -d'=' -f2)
|
||||
[ -n "${server_fingerprint_before}" ]
|
||||
|
||||
# Keep fetching catalog to ensure server is responsive before cert update
|
||||
wait_for_condition 10 0.5 "curl -k --max-time 5 --connect-timeout 3 https://127.0.0.1:${zot_port}/v2/_catalog >/dev/null 2>&1"
|
||||
|
||||
# Update the certificate with a new one
|
||||
# This simulates a real-world scenario where certificates are renewed
|
||||
generate_self_signed_cert "${zot_cert_file}" "${zot_key_file}" "127.0.0.1" 365
|
||||
|
||||
# Wait for file system changes to be visible and stat cache to expire
|
||||
# (allows time for inotify to detect changes or stat-based check to trigger)
|
||||
wait_for_condition 10 0.1 "[ \"$(openssl x509 -fingerprint -sha256 -noout -in \"${zot_cert_file}\" 2>/dev/null | cut -d'=' -f2)\" != \"${cert_fingerprint_before}\" ]"
|
||||
|
||||
# Request a new fingerprint after expecting the server to reload
|
||||
wait_for_condition 20 0.2 "[ \"$(openssl s_client -connect 127.0.0.1:${zot_port} -servername 127.0.0.1 -showcerts </dev/null 2>/dev/null | openssl x509 -fingerprint -sha256 -noout 2>/dev/null | cut -d'=' -f2)\" != \"${server_fingerprint_before}\" ]" || true
|
||||
|
||||
# Make several requests to ensure server picks up the new certificate
|
||||
# The server should automatically reload it through the GetCertificate callback
|
||||
server_fingerprint_after=""
|
||||
for i in {1..10}; do
|
||||
server_fingerprint_after=$(openssl s_client -connect 127.0.0.1:${zot_port} -servername 127.0.0.1 -showcerts </dev/null 2>/dev/null \
|
||||
| openssl x509 -fingerprint -sha256 -noout 2>/dev/null | cut -d'=' -f2)
|
||||
run curl -k --max-time 5 --connect-timeout 3 https://127.0.0.1:${zot_port}/v2/_catalog
|
||||
if [ "$status" -eq 0 ] && [ -n "${server_fingerprint_after}" ] && \
|
||||
[ "${server_fingerprint_before}" != "${server_fingerprint_after}" ]; then
|
||||
# Server is using the new certificate
|
||||
echo "Request $i succeeded with new certificate" >&3
|
||||
break
|
||||
fi
|
||||
if [ $i -lt 10 ]; then
|
||||
sleep 0.2
|
||||
fi
|
||||
done
|
||||
|
||||
[ -n "${server_fingerprint_after}" ]
|
||||
[ "${server_fingerprint_before}" != "${server_fingerprint_after}" ]
|
||||
[ "$status" -eq 0 ]
|
||||
}
|
||||
|
||||
@test "TLS works with multiple concurrent connections after certificate reload" {
|
||||
zot_port=$(cat ${BATS_FILE_TMPDIR}/zot.port)
|
||||
local zot_cert_file=${BATS_FILE_TMPDIR}/server.cert
|
||||
local zot_key_file=${BATS_FILE_TMPDIR}/server.key
|
||||
|
||||
# Regenerate certificate to trigger reload
|
||||
generate_self_signed_cert "${zot_cert_file}" "${zot_key_file}" "127.0.0.1" 365
|
||||
|
||||
# Wait for certificate to be reloaded by making requests
|
||||
wait_for_condition 20 0.2 "curl -k --max-time 5 --connect-timeout 3 https://127.0.0.1:${zot_port}/v2/_catalog >/dev/null 2>&1"
|
||||
|
||||
# Test multiple concurrent requests
|
||||
local failed=0
|
||||
local pids=()
|
||||
for i in {1..5}; do
|
||||
(curl -k --max-time 5 --connect-timeout 3 https://127.0.0.1:${zot_port}/v2/_catalog > /dev/null 2>&1) &
|
||||
pids+=($!)
|
||||
done
|
||||
|
||||
# Wait for all background requests to complete
|
||||
for pid in "${pids[@]}"; do
|
||||
if ! wait "$pid"; then
|
||||
failed=$((failed + 1))
|
||||
fi
|
||||
done
|
||||
[ "$failed" -eq 0 ]
|
||||
|
||||
# If any failed, the test will fail
|
||||
# Check that at least one request succeeds by making one more
|
||||
run curl -k --max-time 5 --connect-timeout 3 https://127.0.0.1:${zot_port}/v2/_catalog
|
||||
[ "$status" -eq 0 ]
|
||||
}
|
||||
|
||||
@test "certificate reload doesn't require server restart" {
|
||||
zot_port=$(cat ${BATS_FILE_TMPDIR}/zot.port)
|
||||
local zot_cert_file=${BATS_FILE_TMPDIR}/server.cert
|
||||
local zot_key_file=${BATS_FILE_TMPDIR}/server.key
|
||||
|
||||
# Get initial server PID
|
||||
local zot_pid=$(cat ${BATS_FILE_TMPDIR}/zot.pid | awk '{print $1}')
|
||||
|
||||
# Make a request to establish the server is running
|
||||
run curl -k --max-time 5 --connect-timeout 3 https://127.0.0.1:${zot_port}/v2/_catalog
|
||||
[ "$status" -eq 0 ]
|
||||
|
||||
# Verify server is still running with same PID
|
||||
kill -0 ${zot_pid} 2>/dev/null
|
||||
[ "$?" -eq 0 ]
|
||||
|
||||
# Update certificate multiple times
|
||||
for iteration in {1..3}; do
|
||||
generate_self_signed_cert "${zot_cert_file}" "${zot_key_file}" "127.0.0.1" 365
|
||||
|
||||
# Wait for server to reload the new certificate
|
||||
wait_for_condition 20 0.2 "curl -k --max-time 5 --connect-timeout 3 https://127.0.0.1:${zot_port}/v2/_catalog >/dev/null 2>&1" || true
|
||||
|
||||
# Server should still be running with the same PID
|
||||
kill -0 ${zot_pid} 2>/dev/null
|
||||
[ "$?" -eq 0 ]
|
||||
|
||||
# Requests should still work
|
||||
run curl -k --max-time 5 --connect-timeout 3 https://127.0.0.1:${zot_port}/v2/_catalog
|
||||
[ "$status" -eq 0 ]
|
||||
done
|
||||
}
|
||||
@@ -448,5 +448,11 @@
|
||||
"begin": 11510,
|
||||
"end": 11519
|
||||
}
|
||||
},
|
||||
"blackbox/dynamic_tls.bats": {
|
||||
"zot": {
|
||||
"begin": 11520,
|
||||
"end": 11529
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user