Fix resource leaks and race conditions in TLS cert reloader

Co-authored-by: rchincha <45800463+rchincha@users.noreply.github.com>
This commit is contained in:
copilot-swe-agent[bot]
2026-02-01 10:28:16 +00:00
parent ceec4e7702
commit 2f1891ce56
3 changed files with 38 additions and 13 deletions
+9
View File
@@ -54,6 +54,7 @@ type Controller struct {
HTPasswd *HTPasswd
HTPasswdWatcher *HTPasswdWatcher
LDAPClient *LDAPClient
CertReloader *CertReloader
taskScheduler *scheduler.Scheduler
Healthz *common.Healthz
// runtime params
@@ -213,6 +214,9 @@ func (c *Controller) Run() error {
return err
}
// Store the CertReloader so it can be closed during shutdown
c.CertReloader = certReloader
// These are the same as the cipher suites in defaultCipherSuitesFIPS for TLS 1.2
// see https://cs.opensource.google/go/go/+/refs/tags/go1.24.9:src/crypto/tls/defaults.go;l=123
// Note: Order doesn't matter - Go 1.17+ automatically orders cipher suites based on
@@ -495,6 +499,11 @@ func (c *Controller) StopBackgroundTasks() {
if c.HTPasswdWatcher != nil {
_ = c.HTPasswdWatcher.Close()
}
// Close CertReloader to prevent resource leaks
if c.CertReloader != nil {
_ = c.CertReloader.Close()
}
}
func (c *Controller) StartBackgroundTasks() {
+22 -13
View File
@@ -34,6 +34,7 @@ type CertReloader struct {
lastCheck time.Time
checkCache time.Duration // Minimum time between file stat checks
stopWatcher chan struct{}
closeOnce sync.Once // Ensures Close() can be called multiple times safely
}
// NewCertReloader creates a new certificate reloader and loads the initial certificate.
@@ -57,20 +58,28 @@ func NewCertReloader(certPath, keyPath string, logger log.Logger) (*CertReloader
logger.Warn().Err(err).Msg("failed to start fsnotify watcher, falling back to periodic checking")
}
// NOTE: Do not add initialization that can fail after this point without ensuring
// the watcher is stopped (e.g., by calling Close on error), otherwise the
// watchLoop goroutine started by startWatcher could be leaked.
return reloader, nil
}
// Close stops the file watcher and releases resources.
// This method is safe to call multiple times.
func (cr *CertReloader) Close() error {
if cr.stopWatcher != nil {
close(cr.stopWatcher)
}
var err error
cr.closeOnce.Do(func() {
if cr.stopWatcher != nil {
close(cr.stopWatcher)
}
if cr.watcher != nil {
return cr.watcher.Close()
}
if cr.watcher != nil {
err = cr.watcher.Close()
}
})
return nil
return err
}
// startWatcher initializes the fsnotify watcher for certificate files.
@@ -88,12 +97,14 @@ func (cr *CertReloader) startWatcher() error {
keyDir := filepath.Dir(cr.keyPath)
if err := watcher.Add(certDir); err != nil {
watcher.Close()
return err
}
// If cert and key are in different directories, watch both
if certDir != keyDir {
if err := watcher.Add(keyDir); err != nil {
watcher.Close()
return err
}
}
@@ -189,17 +200,15 @@ func (cr *CertReloader) reload() error {
// Uses time-based caching to avoid excessive file system calls.
func (cr *CertReloader) maybeReload() error {
// Use time-based cache to reduce frequency of stat calls
cr.certMu.RLock()
// Check and update lastCheck within the same critical section to avoid race conditions
cr.certMu.Lock()
if time.Since(cr.lastCheck) < cr.checkCache {
// Recently checked, skip stat calls
cr.certMu.RUnlock()
cr.certMu.Unlock()
return nil
}
cr.certMu.RUnlock()
// Update last check time
cr.certMu.Lock()
// Update last check time within the same critical section as the cache check
cr.lastCheck = time.Now()
cr.certMu.Unlock()
+7
View File
@@ -129,6 +129,7 @@ func TestCertReloaderDirectly(t *testing.T) {
reloader, err := api.NewCertReloader(certPath, keyPath, log.NewTestLogger())
So(err, ShouldBeNil)
So(reloader, ShouldNotBeNil)
defer reloader.Close()
// Get certificate via callback
getCert := reloader.GetCertificateFunc()
@@ -140,6 +141,7 @@ func TestCertReloaderDirectly(t *testing.T) {
Convey("GetCertificateFunc should reload when certificate changes", func() {
reloader, err := api.NewCertReloader(certPath, keyPath, log.NewTestLogger())
So(err, ShouldBeNil)
defer reloader.Close()
getCert := reloader.GetCertificateFunc()
initialCert, err := getCert(nil)
@@ -182,6 +184,7 @@ func TestCertReloaderDirectly(t *testing.T) {
Convey("GetCertificateFunc should handle missing files gracefully", func() {
reloader, err := api.NewCertReloader(certPath, keyPath, log.NewTestLogger())
So(err, ShouldBeNil)
defer reloader.Close()
getCert := reloader.GetCertificateFunc()
@@ -198,6 +201,7 @@ func TestCertReloaderDirectly(t *testing.T) {
Convey("GetCertificateFunc should handle only cert file modification", func() {
reloader, err := api.NewCertReloader(certPath, keyPath, log.NewTestLogger())
So(err, ShouldBeNil)
defer reloader.Close()
getCert := reloader.GetCertificateFunc()
initialCert, err := getCert(nil)
@@ -235,6 +239,7 @@ func TestCertReloaderDirectly(t *testing.T) {
Convey("GetCertificateFunc should handle concurrent access", func() {
reloader, err := api.NewCertReloader(certPath, keyPath, log.NewTestLogger())
So(err, ShouldBeNil)
defer reloader.Close()
getCert := reloader.GetCertificateFunc()
@@ -267,6 +272,7 @@ func TestCertReloaderDirectly(t *testing.T) {
Convey("GetCertificateFunc should not reload if files haven't changed", func() {
reloader, err := api.NewCertReloader(certPath, keyPath, log.NewTestLogger())
So(err, ShouldBeNil)
defer reloader.Close()
getCert := reloader.GetCertificateFunc()
@@ -291,6 +297,7 @@ func TestCertReloaderDirectly(t *testing.T) {
Convey("GetCertificateFunc should reload when key file changes", func() {
reloader, err := api.NewCertReloader(certPath, keyPath, log.NewTestLogger())
So(err, ShouldBeNil)
defer reloader.Close()
getCert := reloader.GetCertificateFunc()
initialCert, err := getCert(nil)