mirror of
https://github.com/project-zot/zot.git
synced 2026-06-19 22:27:58 +08:00
Fix resource leaks and race conditions in TLS cert reloader
Co-authored-by: rchincha <45800463+rchincha@users.noreply.github.com>
This commit is contained in:
@@ -54,6 +54,7 @@ type Controller struct {
|
||||
HTPasswd *HTPasswd
|
||||
HTPasswdWatcher *HTPasswdWatcher
|
||||
LDAPClient *LDAPClient
|
||||
CertReloader *CertReloader
|
||||
taskScheduler *scheduler.Scheduler
|
||||
Healthz *common.Healthz
|
||||
// runtime params
|
||||
@@ -213,6 +214,9 @@ func (c *Controller) Run() error {
|
||||
return err
|
||||
}
|
||||
|
||||
// Store the CertReloader so it can be closed during shutdown
|
||||
c.CertReloader = certReloader
|
||||
|
||||
// These are the same as the cipher suites in defaultCipherSuitesFIPS for TLS 1.2
|
||||
// see https://cs.opensource.google/go/go/+/refs/tags/go1.24.9:src/crypto/tls/defaults.go;l=123
|
||||
// Note: Order doesn't matter - Go 1.17+ automatically orders cipher suites based on
|
||||
@@ -495,6 +499,11 @@ func (c *Controller) StopBackgroundTasks() {
|
||||
if c.HTPasswdWatcher != nil {
|
||||
_ = c.HTPasswdWatcher.Close()
|
||||
}
|
||||
|
||||
// Close CertReloader to prevent resource leaks
|
||||
if c.CertReloader != nil {
|
||||
_ = c.CertReloader.Close()
|
||||
}
|
||||
}
|
||||
|
||||
func (c *Controller) StartBackgroundTasks() {
|
||||
|
||||
+22
-13
@@ -34,6 +34,7 @@ type CertReloader struct {
|
||||
lastCheck time.Time
|
||||
checkCache time.Duration // Minimum time between file stat checks
|
||||
stopWatcher chan struct{}
|
||||
closeOnce sync.Once // Ensures Close() can be called multiple times safely
|
||||
}
|
||||
|
||||
// NewCertReloader creates a new certificate reloader and loads the initial certificate.
|
||||
@@ -57,20 +58,28 @@ func NewCertReloader(certPath, keyPath string, logger log.Logger) (*CertReloader
|
||||
logger.Warn().Err(err).Msg("failed to start fsnotify watcher, falling back to periodic checking")
|
||||
}
|
||||
|
||||
// NOTE: Do not add initialization that can fail after this point without ensuring
|
||||
// the watcher is stopped (e.g., by calling Close on error), otherwise the
|
||||
// watchLoop goroutine started by startWatcher could be leaked.
|
||||
|
||||
return reloader, nil
|
||||
}
|
||||
|
||||
// Close stops the file watcher and releases resources.
|
||||
// This method is safe to call multiple times.
|
||||
func (cr *CertReloader) Close() error {
|
||||
if cr.stopWatcher != nil {
|
||||
close(cr.stopWatcher)
|
||||
}
|
||||
var err error
|
||||
cr.closeOnce.Do(func() {
|
||||
if cr.stopWatcher != nil {
|
||||
close(cr.stopWatcher)
|
||||
}
|
||||
|
||||
if cr.watcher != nil {
|
||||
return cr.watcher.Close()
|
||||
}
|
||||
if cr.watcher != nil {
|
||||
err = cr.watcher.Close()
|
||||
}
|
||||
})
|
||||
|
||||
return nil
|
||||
return err
|
||||
}
|
||||
|
||||
// startWatcher initializes the fsnotify watcher for certificate files.
|
||||
@@ -88,12 +97,14 @@ func (cr *CertReloader) startWatcher() error {
|
||||
keyDir := filepath.Dir(cr.keyPath)
|
||||
|
||||
if err := watcher.Add(certDir); err != nil {
|
||||
watcher.Close()
|
||||
return err
|
||||
}
|
||||
|
||||
// If cert and key are in different directories, watch both
|
||||
if certDir != keyDir {
|
||||
if err := watcher.Add(keyDir); err != nil {
|
||||
watcher.Close()
|
||||
return err
|
||||
}
|
||||
}
|
||||
@@ -189,17 +200,15 @@ func (cr *CertReloader) reload() error {
|
||||
// Uses time-based caching to avoid excessive file system calls.
|
||||
func (cr *CertReloader) maybeReload() error {
|
||||
// Use time-based cache to reduce frequency of stat calls
|
||||
cr.certMu.RLock()
|
||||
// Check and update lastCheck within the same critical section to avoid race conditions
|
||||
cr.certMu.Lock()
|
||||
if time.Since(cr.lastCheck) < cr.checkCache {
|
||||
// Recently checked, skip stat calls
|
||||
cr.certMu.RUnlock()
|
||||
cr.certMu.Unlock()
|
||||
|
||||
return nil
|
||||
}
|
||||
cr.certMu.RUnlock()
|
||||
|
||||
// Update last check time
|
||||
cr.certMu.Lock()
|
||||
// Update last check time within the same critical section as the cache check
|
||||
cr.lastCheck = time.Now()
|
||||
cr.certMu.Unlock()
|
||||
|
||||
|
||||
@@ -129,6 +129,7 @@ func TestCertReloaderDirectly(t *testing.T) {
|
||||
reloader, err := api.NewCertReloader(certPath, keyPath, log.NewTestLogger())
|
||||
So(err, ShouldBeNil)
|
||||
So(reloader, ShouldNotBeNil)
|
||||
defer reloader.Close()
|
||||
|
||||
// Get certificate via callback
|
||||
getCert := reloader.GetCertificateFunc()
|
||||
@@ -140,6 +141,7 @@ func TestCertReloaderDirectly(t *testing.T) {
|
||||
Convey("GetCertificateFunc should reload when certificate changes", func() {
|
||||
reloader, err := api.NewCertReloader(certPath, keyPath, log.NewTestLogger())
|
||||
So(err, ShouldBeNil)
|
||||
defer reloader.Close()
|
||||
|
||||
getCert := reloader.GetCertificateFunc()
|
||||
initialCert, err := getCert(nil)
|
||||
@@ -182,6 +184,7 @@ func TestCertReloaderDirectly(t *testing.T) {
|
||||
Convey("GetCertificateFunc should handle missing files gracefully", func() {
|
||||
reloader, err := api.NewCertReloader(certPath, keyPath, log.NewTestLogger())
|
||||
So(err, ShouldBeNil)
|
||||
defer reloader.Close()
|
||||
|
||||
getCert := reloader.GetCertificateFunc()
|
||||
|
||||
@@ -198,6 +201,7 @@ func TestCertReloaderDirectly(t *testing.T) {
|
||||
Convey("GetCertificateFunc should handle only cert file modification", func() {
|
||||
reloader, err := api.NewCertReloader(certPath, keyPath, log.NewTestLogger())
|
||||
So(err, ShouldBeNil)
|
||||
defer reloader.Close()
|
||||
|
||||
getCert := reloader.GetCertificateFunc()
|
||||
initialCert, err := getCert(nil)
|
||||
@@ -235,6 +239,7 @@ func TestCertReloaderDirectly(t *testing.T) {
|
||||
Convey("GetCertificateFunc should handle concurrent access", func() {
|
||||
reloader, err := api.NewCertReloader(certPath, keyPath, log.NewTestLogger())
|
||||
So(err, ShouldBeNil)
|
||||
defer reloader.Close()
|
||||
|
||||
getCert := reloader.GetCertificateFunc()
|
||||
|
||||
@@ -267,6 +272,7 @@ func TestCertReloaderDirectly(t *testing.T) {
|
||||
Convey("GetCertificateFunc should not reload if files haven't changed", func() {
|
||||
reloader, err := api.NewCertReloader(certPath, keyPath, log.NewTestLogger())
|
||||
So(err, ShouldBeNil)
|
||||
defer reloader.Close()
|
||||
|
||||
getCert := reloader.GetCertificateFunc()
|
||||
|
||||
@@ -291,6 +297,7 @@ func TestCertReloaderDirectly(t *testing.T) {
|
||||
Convey("GetCertificateFunc should reload when key file changes", func() {
|
||||
reloader, err := api.NewCertReloader(certPath, keyPath, log.NewTestLogger())
|
||||
So(err, ShouldBeNil)
|
||||
defer reloader.Close()
|
||||
|
||||
getCert := reloader.GetCertificateFunc()
|
||||
initialCert, err := getCert(nil)
|
||||
|
||||
Reference in New Issue
Block a user