From dfb5d1df5403186d2359d95cb24c8e2f17ad930c Mon Sep 17 00:00:00 2001 From: Andrei Aaron Date: Sat, 18 Oct 2025 11:20:58 +0300 Subject: [PATCH] fix: make config read/write thread safe (#3432) * fix: make config read/write thread safe and fix some other similar issues 1. The config config has a lock, and safe methods to update and read the attributes 2. The config has methods to retrieve copies of specific attributes, such as the extyensions config, the auth config, and the authz config. These are needed, as the config object may mutate in the middle of an auth/authz requests, and we avoid partial configuration being applied for that request. 3. Fix an issue with the monitoring server not stopping when the controller is shut down. 4. Fix an issue with the HTPasswdWatcher not stopping when the background tasks are supposed to finish. 5. Fix some tests using hardcoded ports. Moved some of the methods which were on the main config to the auth, access control and extension configs Signed-off-by: Andrei Aaron --- go.mod | 1 + go.sum | 2 + pkg/api/authn.go | 83 +- pkg/api/authz.go | 64 +- pkg/api/config/config.go | 813 ++++- pkg/api/config/config_test.go | 3278 +++++++++++++++++- pkg/api/controller.go | 226 +- pkg/api/controller_test.go | 22 +- pkg/api/htpasswd.go | 167 +- pkg/api/htpasswd_test.go | 493 ++- pkg/api/proxy.go | 23 +- pkg/api/routes.go | 55 +- pkg/cli/server/config_reloader.go | 7 +- pkg/cli/server/extensions_test.go | 2 +- pkg/cli/server/root.go | 223 +- pkg/common/http_server.go | 15 +- pkg/extensions/config/config.go | 168 +- pkg/extensions/config/config_test.go | 564 +++ pkg/extensions/extension_events.go | 9 +- pkg/extensions/extension_events_disabled.go | 4 +- pkg/extensions/extension_image_trust.go | 14 +- pkg/extensions/extension_metrics.go | 26 +- pkg/extensions/extension_mgmt.go | 3 +- pkg/extensions/extension_scrub.go | 25 +- pkg/extensions/extension_search.go | 21 +- pkg/extensions/extension_sync.go | 33 +- pkg/extensions/extension_ui.go | 3 +- pkg/extensions/extension_userprefs.go | 3 +- pkg/extensions/extensions_test.go | 3 +- pkg/extensions/get_extensions.go | 11 +- pkg/extensions/monitoring/common.go | 2 + pkg/extensions/monitoring/extension.go | 5 + pkg/extensions/monitoring/minimal.go | 24 +- pkg/extensions/monitoring/monitoring_test.go | 9 +- pkg/extensions/search/search_test.go | 8 +- pkg/extensions/sync/sync_internal_test.go | 4 - pkg/extensions/sync/sync_test.go | 33 +- pkg/scheduler/scheduler.go | 5 +- pkg/storage/gc/gc_test.go | 3 + pkg/test/common/utils.go | 96 + pkg/test/common/utils_test.go | 140 + 41 files changed, 6029 insertions(+), 661 deletions(-) create mode 100644 pkg/extensions/config/config_test.go diff --git a/go.mod b/go.mod index 74e91f05..b4448533 100644 --- a/go.mod +++ b/go.mod @@ -69,6 +69,7 @@ require ( github.com/stretchr/testify v1.11.1 github.com/swaggo/http-swagger v1.3.4 github.com/swaggo/swag v1.16.6 + github.com/tiendc/go-deepcopy v1.7.1 github.com/vektah/gqlparser/v2 v2.5.30 github.com/zitadel/oidc/v3 v3.45.0 go.etcd.io/bbolt v1.4.3 diff --git a/go.sum b/go.sum index 14e82bef..026ec3b3 100644 --- a/go.sum +++ b/go.sum @@ -2063,6 +2063,8 @@ github.com/theupdateframework/go-tuf/v2 v2.1.1 h1:OWcoHItwsGO+7m0wLa7FDWPR4oB1cj github.com/theupdateframework/go-tuf/v2 v2.1.1/go.mod h1:V675cQGhZONR0OGQ8r1feO0uwtsTBYPDWHzAAPn5rjE= github.com/theupdateframework/notary v0.7.0 h1:QyagRZ7wlSpjT5N2qQAh/pN+DVqgekv4DzbAiAiEL3c= github.com/theupdateframework/notary v0.7.0/go.mod h1:c9DRxcmhHmVLDay4/2fUYdISnHqbFDGRSlXPO0AhYWw= +github.com/tiendc/go-deepcopy v1.7.1 h1:LnubftI6nYaaMOcaz0LphzwraqN8jiWTwm416sitff4= +github.com/tiendc/go-deepcopy v1.7.1/go.mod h1:4bKjNC2r7boYOkD2IOuZpYjmlDdzjbpTRyCx+goBCJQ= github.com/tink-crypto/tink-go-awskms/v2 v2.1.0 h1:N9UxlsOzu5mttdjhxkDLbzwtEecuXmlxZVo/ds7JKJI= github.com/tink-crypto/tink-go-awskms/v2 v2.1.0/go.mod h1:PxSp9GlOkKL9rlybW804uspnHuO9nbD98V/fDX4uSis= github.com/tink-crypto/tink-go-gcpkms/v2 v2.2.0 h1:3B9i6XBXNTRspfkTC0asN5W0K6GhOSgcujNiECNRNb0= diff --git a/pkg/api/authn.go b/pkg/api/authn.go index 7a6c7aed..cb1b3a84 100644 --- a/pkg/api/authn.go +++ b/pkg/api/authn.go @@ -55,7 +55,8 @@ func AuthHandler(ctlr *Controller) mux.MiddlewareFunc { log: ctlr.Log, } - if ctlr.Config.IsBearerAuthEnabled() { + authConfig := ctlr.Config.CopyAuthConfig() + if authConfig.IsBearerAuthEnabled() { return bearerAuthHandler(ctlr) } @@ -103,6 +104,12 @@ func (amw *AuthnMiddleware) basicAuthn(ctlr *Controller, userAc *reqCtx.UserAcce ) (bool, error) { cookieStore := ctlr.CookieStore + // Get auth config once to avoid multiple calls + authConfig := ctlr.Config.CopyAuthConfig() + if authConfig == nil { + return false, nil + } + identity, passphrase, err := getUsernamePasswordBasicAuth(request) if err != nil { ctlr.Log.Error().Err(err).Msg("failed to parse authorization header") @@ -116,7 +123,8 @@ func (amw *AuthnMiddleware) basicAuthn(ctlr *Controller, userAc *reqCtx.UserAcce // Process request var groups []string - if ctlr.Config.HTTP.AccessControl != nil { + accessControl := ctlr.Config.CopyAccessControlConfig() + if accessControl != nil { ac := NewAccessController(ctlr.Config) groups = ac.getUserGroups(identity) } @@ -145,13 +153,14 @@ func (amw *AuthnMiddleware) basicAuthn(ctlr *Controller, userAc *reqCtx.UserAcce } // next, LDAP if configured (network-based which can lose connectivity) - if ctlr.Config.HTTP.Auth != nil && ctlr.Config.HTTP.Auth.LDAP != nil { + if authConfig.IsLdapAuthEnabled() { ok, _, ldapgroups, err := amw.ldapClient.Authenticate(identity, passphrase) if ok && err == nil { // Process request var groups []string - if ctlr.Config.HTTP.AccessControl != nil { + accessControl := ctlr.Config.CopyAccessControlConfig() + if accessControl != nil { ac := NewAccessController(ctlr.Config) groups = ac.getUserGroups(identity) } @@ -181,7 +190,7 @@ func (amw *AuthnMiddleware) basicAuthn(ctlr *Controller, userAc *reqCtx.UserAcce } // last try API keys - if ctlr.Config.IsAPIKeyEnabled() { + if authConfig.IsAPIKeyEnabled() { apiKey := passphrase if !strings.HasPrefix(apiKey, constants.APIKeysPrefix) { @@ -248,16 +257,20 @@ func (amw *AuthnMiddleware) basicAuthn(ctlr *Controller, userAc *reqCtx.UserAcce } func (amw *AuthnMiddleware) tryAuthnHandlers(ctlr *Controller) mux.MiddlewareFunc { //nolint: gocyclo + // Get auth config once to avoid multiple calls + authConfig := ctlr.Config.CopyAuthConfig() + // no password based authN, if neither LDAP nor HTTP BASIC is enabled - if !ctlr.Config.IsBasicAuthnEnabled() { + if !authConfig.IsBasicAuthnEnabled() { return noPasswdAuth(ctlr) } - delay := ctlr.Config.HTTP.Auth.FailDelay + delay := authConfig.GetFailDelay() + realm := ctlr.Config.GetRealm() // ldap and htpasswd based authN - if ctlr.Config.IsLdapAuthEnabled() { - ldapConfig := ctlr.Config.HTTP.Auth.LDAP + if authConfig.IsLdapAuthEnabled() { + ldapConfig := authConfig.LDAP ctlr.LDAPClient = &LDAPClient{ Host: ldapConfig.Address, @@ -278,17 +291,17 @@ func (amw *AuthnMiddleware) tryAuthnHandlers(ctlr *Controller) mux.MiddlewareFun amw.ldapClient = ctlr.LDAPClient - if ctlr.Config.HTTP.Auth.LDAP.CACert != "" { - caCert, err := os.ReadFile(ctlr.Config.HTTP.Auth.LDAP.CACert) + if authConfig.LDAP.CACert != "" { + caCert, err := os.ReadFile(authConfig.LDAP.CACert) if err != nil { - amw.log.Panic().Err(err).Str("caCert", ctlr.Config.HTTP.Auth.LDAP.CACert). + amw.log.Panic().Err(err).Str("caCert", authConfig.LDAP.CACert). Msg("failed to read caCert") } caCertPool := x509.NewCertPool() if !caCertPool.AppendCertsFromPEM(caCert) { - amw.log.Panic().Err(zerr.ErrBadCACert).Str("caCert", ctlr.Config.HTTP.Auth.LDAP.CACert). + amw.log.Panic().Err(zerr.ErrBadCACert).Str("caCert", authConfig.LDAP.CACert). Msg("failed to read caCert") } @@ -297,7 +310,7 @@ func (amw *AuthnMiddleware) tryAuthnHandlers(ctlr *Controller) mux.MiddlewareFun // default to system cert pool caCertPool, err := x509.SystemCertPool() if err != nil { - amw.log.Panic().Err(zerr.ErrBadCACert).Str("caCert", ctlr.Config.HTTP.Auth.LDAP.CACert). + amw.log.Panic().Err(zerr.ErrBadCACert).Str("caCert", authConfig.LDAP.CACert). Msg("failed to get system cert pool") } @@ -305,26 +318,26 @@ func (amw *AuthnMiddleware) tryAuthnHandlers(ctlr *Controller) mux.MiddlewareFun } } - if ctlr.Config.IsHtpasswdAuthEnabled() { - err := amw.htpasswd.Reload(ctlr.Config.HTTP.Auth.HTPasswd.Path) + if authConfig.IsHtpasswdAuthEnabled() { + err := amw.htpasswd.Reload(authConfig.HTPasswd.Path) if err != nil { - amw.log.Panic().Err(err).Str("credsFile", ctlr.Config.HTTP.Auth.HTPasswd.Path). + amw.log.Panic().Err(err).Str("credsFile", authConfig.HTPasswd.Path). Msg("failed to open creds-file") } } // openid based authN - if ctlr.Config.IsOpenIDAuthEnabled() { + if authConfig.IsOpenIDAuthEnabled() { ctlr.RelyingParties = make(map[string]rp.RelyingParty) - for provider := range ctlr.Config.HTTP.Auth.OpenID.Providers { + for provider := range authConfig.OpenID.Providers { if config.IsOpenIDSupported(provider) { - rp := NewRelyingPartyOIDC(context.TODO(), ctlr.Config, provider, ctlr.Config.HTTP.Auth.SessionHashKey, - ctlr.Config.HTTP.Auth.SessionEncryptKey, ctlr.Log) + rp := NewRelyingPartyOIDC(context.TODO(), ctlr.Config, provider, authConfig.SessionHashKey, + authConfig.SessionEncryptKey, ctlr.Log) ctlr.RelyingParties[provider] = rp } else if config.IsOauth2Supported(provider) { - rp := NewRelyingPartyGithub(ctlr.Config, provider, ctlr.Config.HTTP.Auth.SessionHashKey, - ctlr.Config.HTTP.Auth.SessionEncryptKey, ctlr.Log) + rp := NewRelyingPartyGithub(ctlr.Config, provider, authConfig.SessionHashKey, + authConfig.SessionEncryptKey, ctlr.Log) ctlr.RelyingParties[provider] = rp } } @@ -340,7 +353,10 @@ func (amw *AuthnMiddleware) tryAuthnHandlers(ctlr *Controller) mux.MiddlewareFun } isMgmtRequested := request.RequestURI == constants.FullMgmt - allowAnonymous := ctlr.Config.HTTP.AccessControl.AnonymousPolicyExists() + + // Get access control config safely + accessControlConfig := ctlr.Config.CopyAccessControlConfig() + allowAnonymous := accessControlConfig != nil && accessControlConfig.AnonymousPolicyExists() // build user access control info userAc := reqCtx.NewUserAccessControl() @@ -370,7 +386,7 @@ func (amw *AuthnMiddleware) tryAuthnHandlers(ctlr *Controller) mux.MiddlewareFun if errors.Is(err, zerr.ErrUserDataNotFound) { ctlr.Log.Err(err).Msg("failed to find user profile in DB") - authFail(response, request, ctlr.Config.HTTP.Realm, delay) + authFail(response, request, realm, delay) } response.WriteHeader(http.StatusInternalServerError) @@ -397,22 +413,25 @@ func (amw *AuthnMiddleware) tryAuthnHandlers(ctlr *Controller) mux.MiddlewareFun return } - authFail(response, request, ctlr.Config.HTTP.Realm, delay) + authFail(response, request, realm, delay) }) } } func bearerAuthHandler(ctlr *Controller) mux.MiddlewareFunc { + // Get auth config safely + authConfig := ctlr.Config.CopyAuthConfig() + // although the configuration option is called 'cert', this function will also parse a public key directly // see https://github.com/project-zot/zot/issues/3173 for info - publicKey, err := loadPublicKeyFromFile(ctlr.Config.HTTP.Auth.Bearer.Cert) + publicKey, err := loadPublicKeyFromFile(authConfig.Bearer.Cert) if err != nil { ctlr.Log.Panic().Err(err).Msg("failed to load public key for bearer authentication") } authorizer := NewBearerAuthorizer( - ctlr.Config.HTTP.Auth.Bearer.Realm, - ctlr.Config.HTTP.Auth.Bearer.Service, + authConfig.Bearer.Realm, + authConfig.Bearer.Service, publicKey, ) @@ -509,7 +528,11 @@ func noPasswdAuth(ctlr *Controller) mux.MiddlewareFunc { } if ctlr.Config.IsMTLSAuthEnabled() && userAc.IsAnonymous() { - authFail(response, request, ctlr.Config.HTTP.Realm, ctlr.Config.HTTP.Auth.FailDelay) + authConfig := ctlr.Config.CopyAuthConfig() + failDelay := authConfig.GetFailDelay() + realm := ctlr.Config.GetRealm() + + authFail(response, request, realm, failDelay) return } diff --git a/pkg/api/authz.go b/pkg/api/authz.go index 81292664..0c6475fa 100644 --- a/pkg/api/authz.go +++ b/pkg/api/authz.go @@ -42,16 +42,20 @@ type AccessController struct { } func NewAccessController(conf *config.Config) *AccessController { - if conf.HTTP.AccessControl == nil { + // Get access control config safely + accessControlConfig := conf.CopyAccessControlConfig() + logConfig := conf.CopyLogConfig() + + if accessControlConfig == nil { return &AccessController{ Config: &config.AccessControlConfig{}, - Log: log.NewLogger(conf.Log.Level, conf.Log.Output), + Log: log.NewLogger(logConfig.Level, logConfig.Output), } } return &AccessController{ - Config: conf.HTTP.AccessControl, - Log: log.NewLogger(conf.Log.Level, conf.Log.Output), + Config: accessControlConfig, + Log: log.NewLogger(logConfig.Level, logConfig.Output), } } @@ -117,14 +121,17 @@ func (ac *AccessController) can(userAc *reqCtx.UserAccessControl, action, reposi username := userAc.GetUsername() // check matched repo based policy - pg, ok := ac.Config.Repositories[longestMatchedPattern] + repositories := ac.Config.GetRepositories() + pg, ok := repositories[longestMatchedPattern] + if ok { can = ac.isPermitted(userGroups, username, action, pg) } // check admins based policy if !can { - if ac.isAdmin(username, userGroups) && common.Contains(ac.Config.AdminPolicy.Actions, action) { + adminPolicy := ac.Config.GetAdminPolicy() + if ac.isAdmin(username, userGroups) && common.Contains(adminPolicy.Actions, action) { can = true } } @@ -134,7 +141,8 @@ func (ac *AccessController) can(userAc *reqCtx.UserAccessControl, action, reposi // isAdmin . func (ac *AccessController) isAdmin(username string, userGroups []string) bool { - if common.Contains(ac.Config.AdminPolicy.Users, username) || ac.isAnyGroupInAdminPolicy(userGroups) { + adminPolicy := ac.Config.GetAdminPolicy() + if common.Contains(adminPolicy.Users, username) || ac.isAnyGroupInAdminPolicy(userGroups) { return true } @@ -142,8 +150,9 @@ func (ac *AccessController) isAdmin(username string, userGroups []string) bool { } func (ac *AccessController) isAnyGroupInAdminPolicy(userGroups []string) bool { + adminPolicy := ac.Config.GetAdminPolicy() for _, group := range userGroups { - if common.Contains(ac.Config.AdminPolicy.Groups, group) { + if common.Contains(adminPolicy.Groups, group) { return true } } @@ -154,7 +163,8 @@ func (ac *AccessController) isAnyGroupInAdminPolicy(userGroups []string) bool { func (ac *AccessController) getUserGroups(username string) []string { var groupNames []string - for groupName, group := range ac.Config.Groups { + groups := ac.Config.GetGroups() + for groupName, group := range groups { for _, user := range group.Users { // find if the user is part of any groups if user == username { @@ -241,6 +251,11 @@ func (ac *AccessController) isPermitted(userGroups []string, username, action st func BaseAuthzHandler(ctlr *Controller) mux.MiddlewareFunc { return func(next http.Handler) http.Handler { return http.HandlerFunc(func(response http.ResponseWriter, request *http.Request) { + // Get configs safely + authConfig := ctlr.Config.CopyAuthConfig() + realm := ctlr.Config.GetRealm() + failDelay := authConfig.GetFailDelay() + /* NOTE: since we only do READ actions in extensions, this middleware is enough for them because it populates the context with user relevant data to be processed by each individual extension @@ -271,7 +286,7 @@ func BaseAuthzHandler(ctlr *Controller) mux.MiddlewareFunc { // get access control context made in authn.go userAc, err := reqCtx.UserAcFromContext(request.Context()) if err != nil { // should never happen - authFail(response, request, ctlr.Config.HTTP.Realm, ctlr.Config.HTTP.Auth.FailDelay) + authFail(response, request, realm, failDelay) return } @@ -287,6 +302,11 @@ func BaseAuthzHandler(ctlr *Controller) mux.MiddlewareFunc { func DistSpecAuthzHandler(ctlr *Controller) mux.MiddlewareFunc { return func(next http.Handler) http.Handler { return http.HandlerFunc(func(response http.ResponseWriter, request *http.Request) { + // Get configs safely + authConfig := ctlr.Config.CopyAuthConfig() + realm := ctlr.Config.GetRealm() + failDelay := authConfig.GetFailDelay() + if request.Method == http.MethodOptions { next.ServeHTTP(response, request) @@ -310,7 +330,7 @@ func DistSpecAuthzHandler(ctlr *Controller) mux.MiddlewareFunc { // get userAc built in authn and previous authz middlewares userAc, err := reqCtx.UserAcFromContext(request.Context()) if err != nil { // should never happen - authFail(response, request, ctlr.Config.HTTP.Realm, ctlr.Config.HTTP.Auth.FailDelay) + authFail(response, request, realm, failDelay) return } @@ -341,7 +361,7 @@ func DistSpecAuthzHandler(ctlr *Controller) mux.MiddlewareFunc { can := acCtrlr.can(userAc, action, resource) //nolint:contextcheck if !can { - common.AuthzFail(response, request, userAc.GetUsername(), ctlr.Config.HTTP.Realm, ctlr.Config.HTTP.Auth.FailDelay) + common.AuthzFail(response, request, userAc.GetUsername(), realm, failDelay) } else { next.ServeHTTP(response, request) //nolint:contextcheck } @@ -352,17 +372,25 @@ func DistSpecAuthzHandler(ctlr *Controller) mux.MiddlewareFunc { func MetricsAuthzHandler(ctlr *Controller) mux.MiddlewareFunc { return func(next http.Handler) http.Handler { return http.HandlerFunc(func(response http.ResponseWriter, request *http.Request) { - if ctlr.Config.HTTP.AccessControl == nil { + // Get configs safely + authConfig := ctlr.Config.CopyAuthConfig() + realm := ctlr.Config.GetRealm() + failDelay := authConfig.GetFailDelay() + + accessControlConfig := ctlr.Config.CopyAccessControlConfig() + + if accessControlConfig == nil { // allow access to authenticated user as anonymous policy does not exist next.ServeHTTP(response, request) return } - if len(ctlr.Config.HTTP.AccessControl.Metrics.Users) == 0 { + metricsConfig := accessControlConfig.GetMetrics() + if len(metricsConfig.Users) == 0 { log := ctlr.Log log.Warn().Msg("auth is enabled but no metrics users in accessControl: /metrics is unaccesible") - common.AuthzFail(response, request, "", ctlr.Config.HTTP.Realm, ctlr.Config.HTTP.Auth.FailDelay) + common.AuthzFail(response, request, "", realm, failDelay) return } @@ -370,14 +398,14 @@ func MetricsAuthzHandler(ctlr *Controller) mux.MiddlewareFunc { // get access control context made in authn.go userAc, err := reqCtx.UserAcFromContext(request.Context()) if err != nil { // should never happen - common.AuthzFail(response, request, "", ctlr.Config.HTTP.Realm, ctlr.Config.HTTP.Auth.FailDelay) + common.AuthzFail(response, request, "", realm, failDelay) return } username := userAc.GetUsername() - if !common.Contains(ctlr.Config.HTTP.AccessControl.Metrics.Users, username) { - common.AuthzFail(response, request, username, ctlr.Config.HTTP.Realm, ctlr.Config.HTTP.Auth.FailDelay) + if !common.Contains(metricsConfig.Users, username) { + common.AuthzFail(response, request, username, realm, failDelay) return } diff --git a/pkg/api/config/config.go b/pkg/api/config/config.go index c5c9f360..9bb9a469 100644 --- a/pkg/api/config/config.go +++ b/pkg/api/config/config.go @@ -3,9 +3,11 @@ package config import ( "encoding/json" "os" + "sync" "time" distspec "github.com/opencontainers/distribution-spec/specs-go" + "github.com/tiendc/go-deepcopy" "zotregistry.dev/zot/v2/pkg/compat" extconf "zotregistry.dev/zot/v2/pkg/extensions/config" @@ -80,6 +82,59 @@ type AuthConfig struct { SessionDriver map[string]any `mapstructure:",omitempty"` } +// IsLdapAuthEnabled checks if LDAP authentication is enabled in this auth config. +func (a *AuthConfig) IsLdapAuthEnabled() bool { + return a != nil && a.LDAP != nil +} + +// IsHtpasswdAuthEnabled checks if HTPasswd authentication is enabled in this auth config. +func (a *AuthConfig) IsHtpasswdAuthEnabled() bool { + return a != nil && a.HTPasswd.Path != "" +} + +// IsBearerAuthEnabled checks if Bearer authentication is enabled in this auth config. +func (a *AuthConfig) IsBearerAuthEnabled() bool { + return a != nil && a.Bearer != nil && a.Bearer.Cert != "" && a.Bearer.Realm != "" && a.Bearer.Service != "" +} + +// IsOpenIDAuthEnabled checks if OpenID authentication is enabled in this auth config. +func (a *AuthConfig) IsOpenIDAuthEnabled() bool { + if a == nil || a.OpenID == nil { + return false + } + + for provider := range a.OpenID.Providers { + if IsOpenIDSupported(provider) || IsOauth2Supported(provider) { + return true + } + } + + return false +} + +// IsAPIKeyEnabled checks if API Key authentication is enabled in this auth config. +func (a *AuthConfig) IsAPIKeyEnabled() bool { + return a != nil && a.APIKey +} + +// IsBasicAuthnEnabled checks if any basic authentication method is enabled in this auth config. +func (a *AuthConfig) IsBasicAuthnEnabled() bool { + if a == nil { + return false + } + + return a.IsHtpasswdAuthEnabled() || a.IsLdapAuthEnabled() || a.IsOpenIDAuthEnabled() || a.IsAPIKeyEnabled() +} + +// GetFailDelay returns the configured fail delay for authentication attempts. +func (a *AuthConfig) GetFailDelay() int { + if a == nil { + return 0 + } + + return a.FailDelay +} + type BearerConfig struct { Realm string Service string @@ -156,6 +211,11 @@ type ClusterConfig struct { Proxy *ClusterRequestProxyConfig `json:"-" mapstructure:"-"` } +// IsClustered returns true if the cluster configuration represents a multi-node cluster. +func (c *ClusterConfig) IsClustered() bool { + return c != nil && len(c.Members) > 1 +} + type ClusterRequestProxyConfig struct { // holds the cluster socket (IP:port) derived from the host's // interface configuration and the listening port of the HTTP server. @@ -187,20 +247,34 @@ type LDAPConfig struct { } func (ldapConf *LDAPConfig) BindDN() string { + if ldapConf == nil { + return "" + } + return ldapConf.bindDN } func (ldapConf *LDAPConfig) SetBindDN(bindDN string) *LDAPConfig { + if ldapConf == nil { + return nil + } ldapConf.bindDN = bindDN return ldapConf } func (ldapConf *LDAPConfig) BindPassword() string { + if ldapConf == nil { + return "" + } + return ldapConf.bindPassword } func (ldapConf *LDAPConfig) SetBindPassword(bindPassword string) *LDAPConfig { + if ldapConf == nil { + return nil + } ldapConf.bindPassword = bindPassword return ldapConf @@ -224,6 +298,11 @@ type AccessControlConfig struct { Metrics Metrics } +// IsAuthzEnabled checks if authorization is enabled (access control is configured). +func (config *AccessControlConfig) IsAuthzEnabled() bool { + return config != nil +} + func (config *AccessControlConfig) AnonymousPolicyExists() bool { if config == nil { return false @@ -238,6 +317,89 @@ func (config *AccessControlConfig) AnonymousPolicyExists() bool { return false } +// ContainsOnlyAnonymousPolicy checks if the access control configuration contains only anonymous policies. +func (config *AccessControlConfig) ContainsOnlyAnonymousPolicy() bool { + if config == nil { + return true + } + + // Check if admin policy has any actions or users + if len(config.AdminPolicy.Actions)+len(config.AdminPolicy.Users) > 0 { + return false + } + + anonymousPolicyPresent := false + + for _, repository := range config.Repositories { + // Check if repository has default policy + if len(repository.DefaultPolicy) > 0 { + return false + } + + // Check if repository has anonymous policy + if len(repository.AnonymousPolicy) > 0 { + anonymousPolicyPresent = true + } + + // Check if repository has any non-empty policies + for _, policy := range repository.Policies { + if len(policy.Actions)+len(policy.Users) > 0 { + return false + } + } + } + + return anonymousPolicyPresent +} + +// GetRepositories safely gets a copy of the repositories configuration. +func (config *AccessControlConfig) GetRepositories() Repositories { + if config == nil { + return nil + } + + // Return a copy to avoid race conditions + reposCopy := make(Repositories) + for k, v := range config.Repositories { + reposCopy[k] = v + } + + return reposCopy +} + +// GetAdminPolicy safely gets a copy of the admin policy. +func (config *AccessControlConfig) GetAdminPolicy() Policy { + if config == nil { + return Policy{} + } + + return config.AdminPolicy +} + +// GetMetrics safely gets a copy of the metrics configuration. +func (config *AccessControlConfig) GetMetrics() Metrics { + if config == nil { + return Metrics{} + } + + return config.Metrics +} + +// GetGroups safely gets a copy of the groups configuration. +func (config *AccessControlConfig) GetGroups() Groups { + if config == nil { + return nil + } + + // Return a copy to avoid race conditions + groupsCopy := make(Groups) + for k, v := range config.Groups { + groupsCopy[k] = v + } + + return groupsCopy +} + type ( Repositories map[string]PolicyGroup Groups map[string]Group @@ -275,6 +437,9 @@ type Config struct { Extensions *extconf.ExtensionConfig Scheduler *SchedulerConfig `json:"scheduler" mapstructure:",omitempty"` Cluster *ClusterConfig `json:"cluster" mapstructure:",omitempty"` + + // Mutex to protect concurrent access to config fields + mu sync.RWMutex } func New() *Config { @@ -303,35 +468,107 @@ func (expConfig StorageConfig) ParamsEqual(actConfig StorageConfig) bool { expConfig.GCDelay == actConfig.GCDelay && expConfig.GCInterval == actConfig.GCInterval } -// SameFile compare two files. -// This method will first do the stat of two file and compare using os.SameFile method. -func SameFile(str1, str2 string) (bool, error) { - sFile, err := os.Stat(str1) - if err != nil { - return false, err +// isRetentionEnabledInternal checks if retention is enabled without acquiring a lock (internal use only). +func (c *Config) isRetentionEnabledInternal() bool { + if c == nil { + return false } - tFile, err := os.Stat(str2) - if err != nil { - return false, err + var needsMetaDB bool + + for _, retentionPolicy := range c.Storage.Retention.Policies { + for _, tagRetentionPolicy := range retentionPolicy.KeepTags { + if c.isTagsRetentionEnabled(tagRetentionPolicy) { + needsMetaDB = true + } + } } - return os.SameFile(sFile, tFile), nil + for _, subpath := range c.Storage.SubPaths { + for _, retentionPolicy := range subpath.Retention.Policies { + for _, tagRetentionPolicy := range retentionPolicy.KeepTags { + if c.isTagsRetentionEnabled(tagRetentionPolicy) { + needsMetaDB = true + } + } + } + } + + return needsMetaDB } -func DeepCopy(src, dst interface{}) error { - bytes, err := json.Marshal(src) - if err != nil { - return err +// isTagsRetentionEnabled checks if tags retention is enabled for a specific policy (internal use only). +func (c *Config) isTagsRetentionEnabled(tagRetentionPolicy KeepTagsPolicy) bool { + if tagRetentionPolicy.MostRecentlyPulledCount != 0 || + tagRetentionPolicy.MostRecentlyPushedCount != 0 || + tagRetentionPolicy.PulledWithin != nil || + tagRetentionPolicy.PushedWithin != nil { + return true } - err = json.Unmarshal(bytes, dst) + return false +} - return err +// isBasicAuthnEnabled checks if any basic authentication method is enabled (internal, no locking). +func (c *Config) isBasicAuthnEnabled() bool { + if c == nil { + return false + } + + // Check HTPasswd + if c.HTTP.Auth != nil && c.HTTP.Auth.HTPasswd.Path != "" { + return true + } + + // Check LDAP + if c.HTTP.Auth != nil && c.HTTP.Auth.LDAP != nil { + return true + } + + // Check API Key + if c.HTTP.Auth != nil && c.HTTP.Auth.APIKey { + return true + } + + // Check OpenID + if c.HTTP.Auth != nil && c.HTTP.Auth.OpenID != nil { + for provider := range c.HTTP.Auth.OpenID.Providers { + if isOpenIDAuthProviderEnabled(c, provider) { + return true + } + } + } + + return false +} + +// isOpenIDAuthProviderEnabled checks if a specific OpenID provider is enabled (internal use only). +func isOpenIDAuthProviderEnabled(config *Config, provider string) bool { + if providerConfig, ok := config.HTTP.Auth.OpenID.Providers[provider]; ok { + if IsOpenIDSupported(provider) { + if providerConfig.ClientID != "" || providerConfig.Issuer != "" || + len(providerConfig.Scopes) > 0 { + return true + } + } else if IsOauth2Supported(provider) { + if providerConfig.ClientID != "" || len(providerConfig.Scopes) > 0 { + return true + } + } + } + + return false } // Sanitize makes a sanitized copy of the config removing any secrets. func (c *Config) Sanitize() *Config { + if c == nil { + return nil + } + + c.mu.RLock() + defer c.mu.RUnlock() + sanitizedConfig := &Config{} if err := DeepCopy(c, sanitizedConfig); err != nil { @@ -370,7 +607,7 @@ func (c *Config) Sanitize() *Config { } } - if c.IsEventRecorderEnabled() { + if c.Extensions.IsEventRecorderEnabled() { for i, sink := range c.Extensions.Events.Sinks { if sink.Credentials == nil { continue @@ -387,24 +624,367 @@ func (c *Config) Sanitize() *Config { return sanitizedConfig } -func (c *Config) IsLdapAuthEnabled() bool { - if c.HTTP.Auth != nil && c.HTTP.Auth.LDAP != nil { - return true +// UpdateReloadableConfig updates only the fields that can be reloaded at runtime. +func (c *Config) UpdateReloadableConfig(newConfig *Config) { + if c == nil { + return } - return false + c.mu.Lock() + defer c.mu.Unlock() + + // Update storage configuration + c.Storage.GC = newConfig.Storage.GC + c.Storage.Dedupe = newConfig.Storage.Dedupe + c.Storage.GCDelay = newConfig.Storage.GCDelay + c.Storage.GCInterval = newConfig.Storage.GCInterval + + // Only update retention if we have a metaDB already in place + if c.isRetentionEnabledInternal() { + c.Storage.Retention = newConfig.Storage.Retention + } + + // Update storage subpaths configuration + for subPath, storageConfig := range newConfig.Storage.SubPaths { + subPathConfig, ok := c.Storage.SubPaths[subPath] + if !ok { + continue + } + + subPathConfig.GC = storageConfig.GC + subPathConfig.Dedupe = storageConfig.Dedupe + subPathConfig.GCDelay = storageConfig.GCDelay + subPathConfig.GCInterval = storageConfig.GCInterval + + // Only update retention if we have a metaDB already in place + if c.isRetentionEnabledInternal() { + subPathConfig.Retention = storageConfig.Retention + } + + c.Storage.SubPaths[subPath] = subPathConfig + } + + // Update authentication configuration + if c.HTTP.Auth != nil && newConfig.HTTP.Auth != nil { + c.HTTP.Auth.HTPasswd = newConfig.HTTP.Auth.HTPasswd + c.HTTP.Auth.LDAP = newConfig.HTTP.Auth.LDAP + c.HTTP.Auth.APIKey = newConfig.HTTP.Auth.APIKey + c.HTTP.Auth.OpenID = newConfig.HTTP.Auth.OpenID + } + + // Initialize and update AccessControlConfig + if newConfig.HTTP.AccessControl != nil && c.HTTP.AccessControl == nil { + c.HTTP.AccessControl = &AccessControlConfig{} + } + + if newConfig.HTTP.AccessControl == nil { + c.HTTP.AccessControl = nil + } else { + // Update AccessControlConfig fields directly + c.HTTP.AccessControl.Repositories = newConfig.HTTP.AccessControl.Repositories + c.HTTP.AccessControl.AdminPolicy = newConfig.HTTP.AccessControl.AdminPolicy + c.HTTP.AccessControl.Metrics = newConfig.HTTP.AccessControl.Metrics + c.HTTP.AccessControl.Groups = newConfig.HTTP.AccessControl.Groups + } + + // Initialize and update ExtensionConfig + if newConfig.Extensions != nil && c.Extensions == nil { + c.Extensions = &extconf.ExtensionConfig{} + } + + if newConfig.Extensions == nil { + c.Extensions = nil + } else if c.Extensions != nil { + // Update sync extension + c.Extensions.Sync = newConfig.Extensions.Sync + + // Update search extension + if newConfig.Extensions.Search != nil && newConfig.Extensions.Search.CVE != nil { + // Only update if search is enabled + if c.Extensions.IsSearchEnabled() { + if c.Extensions.Search != nil { + c.Extensions.Search.CVE = newConfig.Extensions.Search.CVE + } + } + } else { + // Remove search CVE config if not present in new config + if c.Extensions.Search != nil { + c.Extensions.Search.CVE = nil + } + } + + // Update scrub extension + c.Extensions.Scrub = newConfig.Extensions.Scrub + } } -func (c *Config) IsAuthzEnabled() bool { - return c.HTTP.AccessControl != nil +// CopyAuthConfig returns a copy of the auth config if it exists. +func (c *Config) CopyAuthConfig() *AuthConfig { + if c == nil { + return nil + } + + c.mu.RLock() + defer c.mu.RUnlock() + + if c.HTTP.Auth == nil { + return nil + } + + // Return a deep copy using tiendc/go-deepcopy to avoid race conditions + authCopy := &AuthConfig{} + _ = deepcopy.Copy(authCopy, c.HTTP.Auth) + + return authCopy } +// CopyAccessControlConfig returns a copy of the access control config if it exists. +func (c *Config) CopyAccessControlConfig() *AccessControlConfig { + if c == nil { + return nil + } + + c.mu.RLock() + defer c.mu.RUnlock() + + if c.HTTP.AccessControl == nil { + return nil + } + + // Return a deep copy using tiendc/go-deepcopy to avoid race conditions + accessControlCopy := &AccessControlConfig{} + _ = deepcopy.Copy(accessControlCopy, c.HTTP.AccessControl) + + return accessControlCopy +} + +// CopyStorageConfig returns a copy of the storage config. +func (c *Config) CopyStorageConfig() GlobalStorageConfig { + if c == nil { + return GlobalStorageConfig{} + } + + c.mu.RLock() + defer c.mu.RUnlock() + + // Return a deep copy using tiendc/go-deepcopy to avoid race conditions + storageCopy := GlobalStorageConfig{} + _ = deepcopy.Copy(&storageCopy, &c.Storage) + + return storageCopy +} + +// CopyExtensionsConfig returns a copy of the extensions config if it exists. +func (c *Config) CopyExtensionsConfig() *extconf.ExtensionConfig { + if c == nil { + return nil + } + + c.mu.RLock() + defer c.mu.RUnlock() + + if c.Extensions == nil { + return nil + } + + // Return a deep copy using tiendc/go-deepcopy to avoid race conditions + extensionsCopy := &extconf.ExtensionConfig{} + _ = deepcopy.Copy(extensionsCopy, c.Extensions) + + return extensionsCopy +} + +// CopyLogConfig returns a copy of the log config if it exists. +func (c *Config) CopyLogConfig() *LogConfig { + if c == nil { + return nil + } + + c.mu.RLock() + defer c.mu.RUnlock() + + if c.Log == nil { + return nil + } + + // Return a copy to avoid race conditions + logCopy := *c.Log + + return &logCopy +} + +// CopyClusterConfig returns a copy of the cluster config if it exists. +func (c *Config) CopyClusterConfig() *ClusterConfig { + if c == nil { + return nil + } + + c.mu.RLock() + defer c.mu.RUnlock() + + if c.Cluster == nil { + return nil + } + + // Return a deep copy using tiendc/go-deepcopy to avoid race conditions + clusterCopy := &ClusterConfig{} + _ = deepcopy.Copy(clusterCopy, c.Cluster) + + return clusterCopy +} + +// CopySchedulerConfig returns a copy of the scheduler config if it exists. +func (c *Config) CopySchedulerConfig() *SchedulerConfig { + if c == nil { + return nil + } + + c.mu.RLock() + defer c.mu.RUnlock() + + if c.Scheduler == nil { + return nil + } + + // Return a copy to avoid race conditions + schedulerCopy := *c.Scheduler + + return &schedulerCopy +} + +// CopyTLSConfig returns a copy of the TLS config. +func (c *Config) CopyTLSConfig() *TLSConfig { + if c == nil { + return nil + } + + c.mu.RLock() + defer c.mu.RUnlock() + + if c.HTTP.TLS == nil { + return nil + } + + // Return a copy to avoid race conditions + tlsCopy := *c.HTTP.TLS + + return &tlsCopy +} + +// CopyRatelimit returns a copy of the rate limit config. +func (c *Config) CopyRatelimit() *RatelimitConfig { + if c == nil { + return nil + } + + c.mu.RLock() + defer c.mu.RUnlock() + + if c.HTTP.Ratelimit == nil { + return nil + } + + // Return a deep copy using tiendc/go-deepcopy to avoid race conditions + ratelimitCopy := &RatelimitConfig{} + _ = deepcopy.Copy(ratelimitCopy, c.HTTP.Ratelimit) + + return ratelimitCopy +} + +// GetVersionInfo returns version information (read-only, safe to access directly). +func (c *Config) GetVersionInfo() (string, string, string, string) { + if c == nil { + return "", "", "", "" + } + + c.mu.RLock() + defer c.mu.RUnlock() + + return c.Commit, c.BinaryType, c.GoVersion, c.DistSpecVersion +} + +// GetRealm returns the HTTP realm value. +func (c *Config) GetRealm() string { + if c == nil { + return "" + } + + c.mu.RLock() + defer c.mu.RUnlock() + + return c.HTTP.Realm +} + +// GetCompat returns a copy of the compatibility config. +func (c *Config) GetCompat() []compat.MediaCompatibility { + if c == nil { + return nil + } + + c.mu.RLock() + defer c.mu.RUnlock() + + if c.HTTP.Compat == nil { + return nil + } + + // Return a copy to avoid race conditions + compatCopy := make([]compat.MediaCompatibility, len(c.HTTP.Compat)) + copy(compatCopy, c.HTTP.Compat) + + return compatCopy +} + +// GetHTTPAddress returns the HTTP address. +func (c *Config) GetHTTPAddress() string { + if c == nil { + return "" + } + + c.mu.RLock() + defer c.mu.RUnlock() + + return c.HTTP.Address +} + +// GetHTTPPort returns the HTTP port. +func (c *Config) GetHTTPPort() string { + if c == nil { + return "" + } + + c.mu.RLock() + defer c.mu.RUnlock() + + return c.HTTP.Port +} + +// GetAllowOrigin returns the CORS allow origin configuration. +func (c *Config) GetAllowOrigin() string { + if c == nil { + return "" + } + + c.mu.RLock() + defer c.mu.RUnlock() + + return c.HTTP.AllowOrigin +} + +// IsMTLSAuthEnabled checks if mTLS authentication is enabled. func (c *Config) IsMTLSAuthEnabled() bool { + if c == nil { + return false + } + + c.mu.RLock() + defer c.mu.RUnlock() + if c.HTTP.TLS != nil && c.HTTP.TLS.Key != "" && c.HTTP.TLS.Cert != "" && c.HTTP.TLS.CACert != "" && - !c.IsBasicAuthnEnabled() && + !c.isBasicAuthnEnabled() && !c.HTTP.AccessControl.AnonymousPolicyExists() { return true } @@ -412,157 +992,31 @@ func (c *Config) IsMTLSAuthEnabled() bool { return false } -func (c *Config) IsHtpasswdAuthEnabled() bool { - if c.HTTP.Auth != nil && c.HTTP.Auth.HTPasswd.Path != "" { - return true - } - - return false -} - -func (c *Config) IsBearerAuthEnabled() bool { - if c.HTTP.Auth != nil && - c.HTTP.Auth.Bearer != nil && - c.HTTP.Auth.Bearer.Cert != "" && - c.HTTP.Auth.Bearer.Realm != "" && - c.HTTP.Auth.Bearer.Service != "" { - return true - } - - return false -} - -func (c *Config) IsOpenIDAuthEnabled() bool { - if c.HTTP.Auth != nil && - c.HTTP.Auth.OpenID != nil { - for provider := range c.HTTP.Auth.OpenID.Providers { - if isOpenIDAuthProviderEnabled(c, provider) { - return true - } - } - } - - return false -} - -func (c *Config) IsAPIKeyEnabled() bool { - if c.HTTP.Auth != nil && c.HTTP.Auth.APIKey { - return true - } - - return false -} - -func (c *Config) IsBasicAuthnEnabled() bool { - if c.IsHtpasswdAuthEnabled() || c.IsLdapAuthEnabled() || - c.IsOpenIDAuthEnabled() || c.IsAPIKeyEnabled() { - return true - } - - return false -} - -func isOpenIDAuthProviderEnabled(config *Config, provider string) bool { - if providerConfig, ok := config.HTTP.Auth.OpenID.Providers[provider]; ok { - if IsOpenIDSupported(provider) { - if providerConfig.ClientID != "" || providerConfig.Issuer != "" || - len(providerConfig.Scopes) > 0 { - return true - } - } else if IsOauth2Supported(provider) { - if providerConfig.ClientID != "" || len(providerConfig.Scopes) > 0 { - return true - } - } - } - - return false -} - -func (c *Config) IsMetricsEnabled() bool { - return c.Extensions != nil && c.Extensions.Metrics != nil && *c.Extensions.Metrics.Enable -} - -func (c *Config) IsSearchEnabled() bool { - return c.Extensions != nil && c.Extensions.Search != nil && *c.Extensions.Search.Enable -} - -func (c *Config) IsCveScanningEnabled() bool { - return c.IsSearchEnabled() && c.Extensions.Search.CVE != nil -} - -func (c *Config) IsUIEnabled() bool { - return c.Extensions != nil && c.Extensions.UI != nil && *c.Extensions.UI.Enable -} - -func (c *Config) AreUserPrefsEnabled() bool { - return c.IsSearchEnabled() && c.IsUIEnabled() -} - -func (c *Config) IsMgmtEnabled() bool { - return c.IsSearchEnabled() -} - -func (c *Config) IsImageTrustEnabled() bool { - return c.Extensions != nil && c.Extensions.Trust != nil && *c.Extensions.Trust.Enable -} - -// check if tags retention is enabled. +// IsRetentionEnabled checks if tags retention is enabled. func (c *Config) IsRetentionEnabled() bool { - var needsMetaDB bool - - for _, retentionPolicy := range c.Storage.Retention.Policies { - for _, tagRetentionPolicy := range retentionPolicy.KeepTags { - if c.isTagsRetentionEnabled(tagRetentionPolicy) { - needsMetaDB = true - } - } + if c == nil { + return false } - for _, subpath := range c.Storage.SubPaths { - for _, retentionPolicy := range subpath.Retention.Policies { - for _, tagRetentionPolicy := range retentionPolicy.KeepTags { - if c.isTagsRetentionEnabled(tagRetentionPolicy) { - needsMetaDB = true - } - } - } - } + c.mu.RLock() + defer c.mu.RUnlock() - return needsMetaDB -} - -func (c *Config) isTagsRetentionEnabled(tagRetentionPolicy KeepTagsPolicy) bool { - if tagRetentionPolicy.MostRecentlyPulledCount != 0 || - tagRetentionPolicy.MostRecentlyPushedCount != 0 || - tagRetentionPolicy.PulledWithin != nil || - tagRetentionPolicy.PushedWithin != nil { - return true - } - - return false -} - -func (c *Config) IsCosignEnabled() bool { - return c.IsImageTrustEnabled() && c.Extensions.Trust.Cosign -} - -func (c *Config) IsNotationEnabled() bool { - return c.IsImageTrustEnabled() && c.Extensions.Trust.Notation -} - -func (c *Config) IsSyncEnabled() bool { - return c.Extensions != nil && c.Extensions.Sync != nil && *c.Extensions.Sync.Enable + return c.isRetentionEnabledInternal() } +// IsCompatEnabled checks if compatibility mode is enabled. func (c *Config) IsCompatEnabled() bool { + if c == nil { + return false + } + + c.mu.RLock() + defer c.mu.RUnlock() + return len(c.HTTP.Compat) > 0 } -func (c *Config) IsEventRecorderEnabled() bool { - return c.Extensions != nil && c.Extensions.Events != nil && *c.Extensions.Events.Enable -} - +// IsOpenIDSupported checks if the provider supports OpenID. func IsOpenIDSupported(provider string) bool { for _, supportedProvider := range openIDSupportedProviders { if supportedProvider == provider { @@ -573,6 +1027,7 @@ func IsOpenIDSupported(provider string) bool { return false } +// IsOauth2Supported checks if the provider supports OAuth2. func IsOauth2Supported(provider string) bool { for _, supportedProvider := range oauth2SupportedProviders { if supportedProvider == provider { @@ -582,3 +1037,31 @@ func IsOauth2Supported(provider string) bool { return false } + +// SameFile compare two files. +// This method will first do the stat of two file and compare using os.SameFile method. +func SameFile(str1, str2 string) (bool, error) { + sFile, err := os.Stat(str1) + if err != nil { + return false, err + } + + tFile, err := os.Stat(str2) + if err != nil { + return false, err + } + + return os.SameFile(sFile, tFile), nil +} + +// DeepCopy performs a deep copy of src into dst using JSON marshaling/unmarshaling. +func DeepCopy(src, dst interface{}) error { + bytes, err := json.Marshal(src) + if err != nil { + return err + } + + err = json.Unmarshal(bytes, dst) + + return err +} diff --git a/pkg/api/config/config_test.go b/pkg/api/config/config_test.go index 8eaaabd6..b5127300 100644 --- a/pkg/api/config/config_test.go +++ b/pkg/api/config/config_test.go @@ -7,8 +7,10 @@ import ( . "github.com/smartystreets/goconvey/convey" "zotregistry.dev/zot/v2/pkg/api/config" + "zotregistry.dev/zot/v2/pkg/compat" extconf "zotregistry.dev/zot/v2/pkg/extensions/config" - "zotregistry.dev/zot/v2/pkg/extensions/config/events" + eventsconf "zotregistry.dev/zot/v2/pkg/extensions/config/events" + syncconf "zotregistry.dev/zot/v2/pkg/extensions/config/sync" ) func TestConfig(t *testing.T) { @@ -69,26 +71,333 @@ func TestConfig(t *testing.T) { }) Convey("Test DeepCopy() & Sanitize()", t, func() { - conf := config.New() - So(conf, ShouldNotBeNil) + Convey("Test DeepCopy negative cases", func() { + conf := config.New() + So(conf, ShouldNotBeNil) - authConfig := &config.AuthConfig{LDAP: (&config.LDAPConfig{}).SetBindPassword("oina")} - conf.HTTP.Auth = authConfig + // negative + obj := make(chan int) + err := config.DeepCopy(conf, obj) + So(err, ShouldNotBeNil) + err = config.DeepCopy(obj, conf) + So(err, ShouldNotBeNil) + }) - So(func() { conf.Sanitize() }, ShouldNotPanic) + Convey("Test Sanitize() with LDAP bind password", func() { + conf := config.New() + So(conf, ShouldNotBeNil) - conf = conf.Sanitize() - So(conf.HTTP.Auth.LDAP.BindPassword(), ShouldEqual, "******") + // Set LDAP bind password + authConfig := &config.AuthConfig{LDAP: (&config.LDAPConfig{}).SetBindPassword("secret-ldap-password")} + conf.HTTP.Auth = authConfig - // negative - obj := make(chan int) - err := config.DeepCopy(conf, obj) - So(err, ShouldNotBeNil) - err = config.DeepCopy(obj, conf) - So(err, ShouldNotBeNil) + So(func() { conf.Sanitize() }, ShouldNotPanic) + + sanitizedConf := conf.Sanitize() + So(sanitizedConf.HTTP.Auth.LDAP.BindPassword(), ShouldEqual, "******") + + // Verify original config is not modified + So(conf.HTTP.Auth.LDAP.BindPassword(), ShouldEqual, "secret-ldap-password") + }) + + Convey("Test Sanitize() with OpenID client secrets", func() { + conf := config.New() + So(conf, ShouldNotBeNil) + + // Set OpenID client secrets + authConfig := &config.AuthConfig{ + OpenID: &config.OpenIDConfig{ + Providers: map[string]config.OpenIDProviderConfig{ + "google": { + Name: "Google", + ClientID: "google-client-id", + ClientSecret: "google-client-secret", + Issuer: "https://accounts.google.com", + Scopes: []string{"openid", "email"}, + }, + "github": { + Name: "GitHub", + ClientID: "github-client-id", + ClientSecret: "github-client-secret", + Scopes: []string{"user:email"}, + }, + }, + }, + } + conf.HTTP.Auth = authConfig + + So(func() { conf.Sanitize() }, ShouldNotPanic) + + sanitizedConf := conf.Sanitize() + + // Verify OpenID client secrets are sanitized + So(sanitizedConf.HTTP.Auth.OpenID.Providers["google"].ClientSecret, ShouldEqual, "******") + So(sanitizedConf.HTTP.Auth.OpenID.Providers["github"].ClientSecret, ShouldEqual, "******") + + // Verify other fields are preserved + So(sanitizedConf.HTTP.Auth.OpenID.Providers["google"].ClientID, ShouldEqual, "google-client-id") + So(sanitizedConf.HTTP.Auth.OpenID.Providers["google"].Name, ShouldEqual, "Google") + So(sanitizedConf.HTTP.Auth.OpenID.Providers["google"].Issuer, ShouldEqual, "https://accounts.google.com") + So(sanitizedConf.HTTP.Auth.OpenID.Providers["google"].Scopes, ShouldResemble, []string{"openid", "email"}) + + // Verify original config is not modified + So(conf.HTTP.Auth.OpenID.Providers["google"].ClientSecret, ShouldEqual, "google-client-secret") + So(conf.HTTP.Auth.OpenID.Providers["github"].ClientSecret, ShouldEqual, "github-client-secret") + }) + + Convey("Test Sanitize() with Event sink credentials", func() { + conf := config.New() + So(conf, ShouldNotBeNil) + + // Enable events extension and set sink credentials + enabled := true + conf.Extensions = &extconf.ExtensionConfig{ + Events: &eventsconf.Config{ + Enable: &enabled, + Sinks: []eventsconf.SinkConfig{ + { + Type: eventsconf.HTTP, + Address: "https://example.com/webhook", + Credentials: &eventsconf.Credentials{ + Username: "webhook-user", + Password: "webhook-password", + }, + }, + { + Type: eventsconf.NATS, + Address: "nats://localhost:4222", + Credentials: &eventsconf.Credentials{ + Username: "nats-user", + Password: "nats-token", + }, + }, + }, + }, + } + + So(func() { conf.Sanitize() }, ShouldNotPanic) + + sanitizedConf := conf.Sanitize() + + // Verify event sink credentials passwords are sanitized + So(sanitizedConf.Extensions.Events.Sinks[0].Credentials.Password, ShouldEqual, "******") + So(sanitizedConf.Extensions.Events.Sinks[1].Credentials.Password, ShouldEqual, "******") + + // Verify other fields are preserved + So(sanitizedConf.Extensions.Events.Sinks[0].Credentials.Username, ShouldEqual, "webhook-user") + So(sanitizedConf.Extensions.Events.Sinks[1].Credentials.Username, ShouldEqual, "nats-user") + So(sanitizedConf.Extensions.Events.Sinks[0].Type, ShouldEqual, eventsconf.HTTP) + So(sanitizedConf.Extensions.Events.Sinks[1].Type, ShouldEqual, eventsconf.NATS) + + // Verify original config is not modified + So(conf.Extensions.Events.Sinks[0].Credentials.Password, ShouldEqual, "webhook-password") + So(conf.Extensions.Events.Sinks[1].Credentials.Password, ShouldEqual, "nats-token") + }) + + Convey("Test Sanitize() with Event sink credentials including nil credentials", func() { + conf := config.New() + So(conf, ShouldNotBeNil) + + // Enable events extension with mixed sink credentials (some nil, some not) + enabled := true + conf.Extensions = &extconf.ExtensionConfig{ + Events: &eventsconf.Config{ + Enable: &enabled, + Sinks: []eventsconf.SinkConfig{ + { + Type: eventsconf.HTTP, + Address: "https://example.com/webhook", + Credentials: &eventsconf.Credentials{ + Username: "webhook-user", + Password: "webhook-password", + }, + }, + { + Type: eventsconf.NATS, + Address: "nats://localhost:4222", + Credentials: nil, // This should trigger the continue statement + }, + { + Type: eventsconf.HTTP, + Address: "https://another.com/webhook", + Credentials: &eventsconf.Credentials{ + Username: "another-user", + Password: "another-password", + }, + }, + }, + }, + } + + So(func() { conf.Sanitize() }, ShouldNotPanic) + + sanitizedConf := conf.Sanitize() + + // Verify that sinks with credentials have their passwords sanitized + So(sanitizedConf.Extensions.Events.Sinks[0].Credentials.Password, ShouldEqual, "******") + So(sanitizedConf.Extensions.Events.Sinks[2].Credentials.Password, ShouldEqual, "******") + + // Verify that sink with nil credentials is preserved as-is (no panic, no modification) + So(sanitizedConf.Extensions.Events.Sinks[1].Credentials, ShouldBeNil) + So(sanitizedConf.Extensions.Events.Sinks[1].Type, ShouldEqual, eventsconf.NATS) + So(sanitizedConf.Extensions.Events.Sinks[1].Address, ShouldEqual, "nats://localhost:4222") + + // Verify other fields are preserved + So(sanitizedConf.Extensions.Events.Sinks[0].Credentials.Username, ShouldEqual, "webhook-user") + So(sanitizedConf.Extensions.Events.Sinks[2].Credentials.Username, ShouldEqual, "another-user") + So(sanitizedConf.Extensions.Events.Sinks[0].Type, ShouldEqual, eventsconf.HTTP) + So(sanitizedConf.Extensions.Events.Sinks[2].Type, ShouldEqual, eventsconf.HTTP) + + // Verify original config is not modified + So(conf.Extensions.Events.Sinks[0].Credentials.Password, ShouldEqual, "webhook-password") + So(conf.Extensions.Events.Sinks[2].Credentials.Password, ShouldEqual, "another-password") + So(conf.Extensions.Events.Sinks[1].Credentials, ShouldBeNil) + }) + + Convey("Test Sanitize() with all sensitive data types", func() { + conf := config.New() + So(conf, ShouldNotBeNil) + + // Set all types of sensitive data + authConfig := &config.AuthConfig{ + LDAP: (&config.LDAPConfig{}).SetBindPassword("ldap-secret"), + OpenID: &config.OpenIDConfig{ + Providers: map[string]config.OpenIDProviderConfig{ + "azure": { + Name: "Azure AD", + ClientID: "azure-client-id", + ClientSecret: "azure-client-secret", + Issuer: "https://login.microsoftonline.com/...", + }, + }, + }, + } + conf.HTTP.Auth = authConfig + + // Enable events extension + enabled := true + conf.Extensions = &extconf.ExtensionConfig{ + Events: &eventsconf.Config{ + Enable: &enabled, + Sinks: []eventsconf.SinkConfig{ + { + Type: eventsconf.HTTP, + Address: "https://smtp.example.com/webhook", + Credentials: &eventsconf.Credentials{ + Username: "smtp-user", + Password: "smtp-password", + }, + }, + }, + }, + } + + So(func() { conf.Sanitize() }, ShouldNotPanic) + + sanitizedConf := conf.Sanitize() + + // Verify all sensitive data is sanitized + So(sanitizedConf.HTTP.Auth.LDAP.BindPassword(), ShouldEqual, "******") + So(sanitizedConf.HTTP.Auth.OpenID.Providers["azure"].ClientSecret, ShouldEqual, "******") + So(sanitizedConf.Extensions.Events.Sinks[0].Credentials.Password, ShouldEqual, "******") + + // Verify non-sensitive data is preserved + So(sanitizedConf.HTTP.Auth.OpenID.Providers["azure"].ClientID, ShouldEqual, "azure-client-id") + So(sanitizedConf.HTTP.Auth.OpenID.Providers["azure"].Name, ShouldEqual, "Azure AD") + So(sanitizedConf.Extensions.Events.Sinks[0].Credentials.Username, ShouldEqual, "smtp-user") + So(sanitizedConf.Extensions.Events.Sinks[0].Type, ShouldEqual, eventsconf.HTTP) + }) + + Convey("Test Sanitize() with nil sensitive data", func() { + conf := config.New() + So(conf, ShouldNotBeNil) + + // Set config with nil sensitive data + authConfig := &config.AuthConfig{ + LDAP: nil, // No LDAP config + OpenID: nil, // No OpenID config + } + conf.HTTP.Auth = authConfig + + // No events extension + conf.Extensions = nil + + So(func() { conf.Sanitize() }, ShouldNotPanic) + + sanitizedConf := conf.Sanitize() + + // Verify nil configs are handled gracefully + So(sanitizedConf.HTTP.Auth.LDAP, ShouldBeNil) + So(sanitizedConf.HTTP.Auth.OpenID, ShouldBeNil) + So(sanitizedConf.Extensions, ShouldBeNil) + }) + + Convey("Test Sanitize() with empty sensitive data", func() { + conf := config.New() + So(conf, ShouldNotBeNil) + + // Set config with empty sensitive data + authConfig := &config.AuthConfig{ + LDAP: (&config.LDAPConfig{}).SetBindPassword(""), // Empty password + OpenID: &config.OpenIDConfig{ + Providers: map[string]config.OpenIDProviderConfig{ + "empty": { + Name: "Empty Provider", + ClientID: "empty-client-id", + ClientSecret: "", // Empty secret + }, + }, + }, + } + conf.HTTP.Auth = authConfig + + // Enable events extension with empty password + enabled := true + conf.Extensions = &extconf.ExtensionConfig{ + Events: &eventsconf.Config{ + Enable: &enabled, + Sinks: []eventsconf.SinkConfig{ + { + Type: eventsconf.HTTP, + Address: "https://example.com/webhook", + Credentials: &eventsconf.Credentials{ + Username: "user", + Password: "", // Empty password + }, + }, + }, + }, + } + + So(func() { conf.Sanitize() }, ShouldNotPanic) + + sanitizedConf := conf.Sanitize() + + // Verify empty passwords behavior + // LDAP empty password should remain empty + So(sanitizedConf.HTTP.Auth.LDAP.BindPassword(), ShouldEqual, "") + // OpenID empty secret is always sanitized + So(sanitizedConf.HTTP.Auth.OpenID.Providers["empty"].ClientSecret, ShouldEqual, "******") + // Event sink empty password is always sanitized + So(sanitizedConf.Extensions.Events.Sinks[0].Credentials.Password, ShouldEqual, "******") + }) + + Convey("Test Sanitize() with nil config", func() { + var conf *config.Config = nil + + So(func() { conf.Sanitize() }, ShouldNotPanic) + + sanitizedConf := conf.Sanitize() + So(sanitizedConf, ShouldBeNil) + }) }) Convey("Test IsRetentionEnabled()", t, func() { + // Test nil config + var nilConf *config.Config = nil + + So(nilConf.IsRetentionEnabled(), ShouldBeFalse) + conf := config.New() So(conf.IsRetentionEnabled(), ShouldBeFalse) @@ -130,31 +439,2960 @@ func TestConfig(t *testing.T) { conf.Storage.SubPaths = subPaths So(conf.IsRetentionEnabled(), ShouldBeTrue) + + // Test MostRecentlyPushedCount + conf = config.New() + conf.Storage.Retention.Policies = []config.RetentionPolicy{ + { + Repositories: []string{"repo"}, + KeepTags: []config.KeepTagsPolicy{ + { + Patterns: []string{"tag"}, + MostRecentlyPushedCount: 3, + }, + }, + }, + } + So(conf.IsRetentionEnabled(), ShouldBeTrue) + + // Test PulledWithin + conf = config.New() + duration := time.Hour * 24 + conf.Storage.Retention.Policies = []config.RetentionPolicy{ + { + Repositories: []string{"repo"}, + KeepTags: []config.KeepTagsPolicy{ + { + Patterns: []string{"tag"}, + PulledWithin: &duration, + }, + }, + }, + } + So(conf.IsRetentionEnabled(), ShouldBeTrue) + + // Test PushedWithin + conf = config.New() + conf.Storage.Retention.Policies = []config.RetentionPolicy{ + { + Repositories: []string{"repo"}, + KeepTags: []config.KeepTagsPolicy{ + { + Patterns: []string{"tag"}, + PushedWithin: &duration, + }, + }, + }, + } + So(conf.IsRetentionEnabled(), ShouldBeTrue) + + // Test SubPaths with retention policies + conf = config.New() + conf.Storage.SubPaths = map[string]config.StorageConfig{ + "subpath1": { + Retention: config.ImageRetention{ + Policies: []config.RetentionPolicy{ + { + Repositories: []string{"repo1"}, + KeepTags: []config.KeepTagsPolicy{ + { + Patterns: []string{"latest"}, + MostRecentlyPulledCount: 5, + }, + }, + }, + }, + }, + }, + } + So(conf.IsRetentionEnabled(), ShouldBeTrue) + + // Test empty policies with no retention criteria + conf = config.New() + conf.Storage.Retention.Policies = []config.RetentionPolicy{ + { + Repositories: []string{"repo"}, + KeepTags: []config.KeepTagsPolicy{ + { + Patterns: []string{"tag"}, + // No retention criteria set + }, + }, + }, + } + So(conf.IsRetentionEnabled(), ShouldBeFalse) }) Convey("Test IsEventRecorderEnabled()", t, func() { conf := config.New() - So(conf.IsEventRecorderEnabled(), ShouldBeFalse) + extensionsConfig := conf.CopyExtensionsConfig() + So(extensionsConfig.IsEventRecorderEnabled(), ShouldBeFalse) // Enable the event recorder enable := true conf.Extensions = &extconf.ExtensionConfig{} - conf.Extensions.Events = &events.Config{ + conf.Extensions.Events = &eventsconf.Config{ Enable: &enable, } - So(conf.IsEventRecorderEnabled(), ShouldBeTrue) + extensionsConfig = conf.CopyExtensionsConfig() + So(extensionsConfig.IsEventRecorderEnabled(), ShouldBeTrue) // Disabled scenario disable := false conf.Extensions.Events.Enable = &disable - So(conf.IsEventRecorderEnabled(), ShouldBeFalse) + extensionsConfig = conf.CopyExtensionsConfig() + So(extensionsConfig.IsEventRecorderEnabled(), ShouldBeFalse) // nil pointers conf.Extensions.Events = nil - So(conf.IsEventRecorderEnabled(), ShouldBeFalse) + extensionsConfig = conf.CopyExtensionsConfig() + So(extensionsConfig.IsEventRecorderEnabled(), ShouldBeFalse) conf.Extensions = nil - So(conf.IsEventRecorderEnabled(), ShouldBeFalse) + extensionsConfig = conf.CopyExtensionsConfig() + So(extensionsConfig.IsEventRecorderEnabled(), ShouldBeFalse) + }) + + Convey("Test AccessControlConfig.ContainsOnlyAnonymousPolicy()", t, func() { + Convey("When accessControlConfig is nil", func() { + var accessControlConfig *config.AccessControlConfig = nil + + result := accessControlConfig.ContainsOnlyAnonymousPolicy() + So(result, ShouldBeTrue) + }) + + Convey("When accessControlConfig has admin policies", func() { + accessControlConfig := &config.AccessControlConfig{} + accessControlConfig.AdminPolicy = config.Policy{ + Actions: []string{"read"}, + Users: []string{"admin"}, + } + + result := accessControlConfig.ContainsOnlyAnonymousPolicy() + So(result, ShouldBeFalse) + }) + + Convey("When accessControlConfig has only anonymous policies", func() { + accessControlConfig := &config.AccessControlConfig{} + accessControlConfig.Repositories = config.Repositories{ + "repo1": config.PolicyGroup{ + AnonymousPolicy: []string{"read"}, + }, + } + + result := accessControlConfig.ContainsOnlyAnonymousPolicy() + So(result, ShouldBeTrue) + }) + + Convey("When accessControlConfig has default policies", func() { + accessControlConfig := &config.AccessControlConfig{} + accessControlConfig.Repositories = config.Repositories{ + "repo1": config.PolicyGroup{ + DefaultPolicy: []string{"read"}, + }, + } + + result := accessControlConfig.ContainsOnlyAnonymousPolicy() + So(result, ShouldBeFalse) + }) + + Convey("When accessControlConfig has non-empty repository policies", func() { + accessControlConfig := &config.AccessControlConfig{} + accessControlConfig.Repositories = config.Repositories{ + "repo1": config.PolicyGroup{ + Policies: []config.Policy{ + { + Actions: []string{"read"}, + Users: []string{"user1"}, + }, + }, + }, + } + + result := accessControlConfig.ContainsOnlyAnonymousPolicy() + So(result, ShouldBeFalse) + }) + + Convey("When accessControlConfig has empty admin policy and no repositories", func() { + accessControlConfig := &config.AccessControlConfig{} + accessControlConfig.AdminPolicy = config.Policy{ + Actions: []string{}, + Users: []string{}, + } + accessControlConfig.Repositories = config.Repositories{} + + result := accessControlConfig.ContainsOnlyAnonymousPolicy() + So(result, ShouldBeFalse) + }) + + Convey("When accessControlConfig has empty policies in repository", func() { + accessControlConfig := &config.AccessControlConfig{} + accessControlConfig.Repositories = config.Repositories{ + "repo1": config.PolicyGroup{ + AnonymousPolicy: []string{"read"}, + Policies: []config.Policy{ + { + Actions: []string{}, + Users: []string{}, + }, + }, + }, + } + + result := accessControlConfig.ContainsOnlyAnonymousPolicy() + So(result, ShouldBeTrue) + }) + }) + + Convey("Test AuthConfig methods", t, func() { + Convey("Test IsLdapAuthEnabled()", func() { + // Test with nil AuthConfig + var authConfig *config.AuthConfig = nil + + So(authConfig.IsLdapAuthEnabled(), ShouldBeFalse) + + // Test with AuthConfig but nil LDAP + authConfig = &config.AuthConfig{} + So(authConfig.IsLdapAuthEnabled(), ShouldBeFalse) + + // Test with AuthConfig and LDAP configured + authConfig = &config.AuthConfig{ + LDAP: &config.LDAPConfig{}, + } + So(authConfig.IsLdapAuthEnabled(), ShouldBeTrue) + }) + + Convey("Test IsHtpasswdAuthEnabled()", func() { + // Test with nil AuthConfig + var authConfig *config.AuthConfig = nil + + So(authConfig.IsHtpasswdAuthEnabled(), ShouldBeFalse) + + // Test with AuthConfig but empty HTPasswd path + authConfig = &config.AuthConfig{ + HTPasswd: config.AuthHTPasswd{Path: ""}, + } + So(authConfig.IsHtpasswdAuthEnabled(), ShouldBeFalse) + + // Test with AuthConfig and HTPasswd configured + authConfig = &config.AuthConfig{ + HTPasswd: config.AuthHTPasswd{Path: "/path/to/htpasswd"}, + } + So(authConfig.IsHtpasswdAuthEnabled(), ShouldBeTrue) + }) + + Convey("Test IsBearerAuthEnabled()", func() { + // Test with nil AuthConfig + var authConfig *config.AuthConfig = nil + + So(authConfig.IsBearerAuthEnabled(), ShouldBeFalse) + + // Test with AuthConfig but nil Bearer + authConfig = &config.AuthConfig{} + So(authConfig.IsBearerAuthEnabled(), ShouldBeFalse) + + // Test with AuthConfig and Bearer configured with all required fields + authConfig = &config.AuthConfig{ + Bearer: &config.BearerConfig{ + Cert: "/path/to/cert.pem", + Realm: "test-realm", + Service: "test-service", + }, + } + So(authConfig.IsBearerAuthEnabled(), ShouldBeTrue) + }) + + Convey("Test IsOpenIDAuthEnabled()", func() { + // Test with nil AuthConfig + var authConfig *config.AuthConfig = nil + + So(authConfig.IsOpenIDAuthEnabled(), ShouldBeFalse) + + // Test with AuthConfig but nil OpenID + authConfig = &config.AuthConfig{} + So(authConfig.IsOpenIDAuthEnabled(), ShouldBeFalse) + + // Test with AuthConfig and OpenID configured with providers + authConfig = &config.AuthConfig{ + OpenID: &config.OpenIDConfig{ + Providers: map[string]config.OpenIDProviderConfig{ + "google": { + ClientID: "client-id", + }, + }, + }, + } + So(authConfig.IsOpenIDAuthEnabled(), ShouldBeTrue) + }) + + Convey("Test IsAPIKeyEnabled()", func() { + // Test with nil AuthConfig + var authConfig *config.AuthConfig = nil + + So(authConfig.IsAPIKeyEnabled(), ShouldBeFalse) + + // Test with AuthConfig but APIKey disabled + authConfig = &config.AuthConfig{ + APIKey: false, + } + So(authConfig.IsAPIKeyEnabled(), ShouldBeFalse) + + // Test with AuthConfig and APIKey enabled + authConfig = &config.AuthConfig{ + APIKey: true, + } + So(authConfig.IsAPIKeyEnabled(), ShouldBeTrue) + }) + + Convey("Test IsBasicAuthnEnabled()", func() { + // Test with nil AuthConfig + var authConfig *config.AuthConfig = nil + + So(authConfig.IsBasicAuthnEnabled(), ShouldBeFalse) + + // Test with AuthConfig but no basic auth methods + authConfig = &config.AuthConfig{} + So(authConfig.IsBasicAuthnEnabled(), ShouldBeFalse) + + // Test with HTPasswd enabled + authConfig = &config.AuthConfig{ + HTPasswd: config.AuthHTPasswd{Path: "/path/to/htpasswd"}, + } + So(authConfig.IsBasicAuthnEnabled(), ShouldBeTrue) + + // Test with LDAP enabled + authConfig = &config.AuthConfig{ + LDAP: &config.LDAPConfig{}, + } + So(authConfig.IsBasicAuthnEnabled(), ShouldBeTrue) + + // Test with OpenID enabled (with ClientID) + authConfig = &config.AuthConfig{ + OpenID: &config.OpenIDConfig{ + Providers: map[string]config.OpenIDProviderConfig{ + "google": { + ClientID: "client-id", + Scopes: []string{"openid", "email"}, + }, + }, + }, + } + So(authConfig.IsBasicAuthnEnabled(), ShouldBeTrue) + + // Test with OpenID enabled (with Issuer) + authConfig = &config.AuthConfig{ + OpenID: &config.OpenIDConfig{ + Providers: map[string]config.OpenIDProviderConfig{ + "google": { + ClientID: "", + Issuer: "https://accounts.google.com", + Scopes: []string{}, + }, + }, + }, + } + So(authConfig.IsBasicAuthnEnabled(), ShouldBeTrue) + + // Test with OpenID enabled (with Scopes only) + authConfig = &config.AuthConfig{ + OpenID: &config.OpenIDConfig{ + Providers: map[string]config.OpenIDProviderConfig{ + "google": { + ClientID: "", + Issuer: "", + Scopes: []string{"openid", "email"}, + }, + }, + }, + } + So(authConfig.IsBasicAuthnEnabled(), ShouldBeTrue) + + // Test with OAuth2 provider (github) + authConfig = &config.AuthConfig{ + OpenID: &config.OpenIDConfig{ + Providers: map[string]config.OpenIDProviderConfig{ + "github": { + ClientID: "github-client-id", + Scopes: []string{"user:email"}, + }, + }, + }, + } + So(authConfig.IsBasicAuthnEnabled(), ShouldBeTrue) + + // Test with OpenID but no valid providers (empty config) + // Note: AuthConfig.IsOpenIDAuthEnabled() only checks if provider is supported, + // not if the configuration is valid, so this returns true + authConfig = &config.AuthConfig{ + OpenID: &config.OpenIDConfig{ + Providers: map[string]config.OpenIDProviderConfig{ + "google": { + ClientID: "", + Issuer: "", + Scopes: []string{}, + }, + }, + }, + } + So(authConfig.IsBasicAuthnEnabled(), ShouldBeTrue) + + // Test with OpenID but unsupported provider + authConfig = &config.AuthConfig{ + OpenID: &config.OpenIDConfig{ + Providers: map[string]config.OpenIDProviderConfig{ + "unsupported": { + ClientID: "client-id", + Scopes: []string{"scope"}, + }, + }, + }, + } + So(authConfig.IsBasicAuthnEnabled(), ShouldBeFalse) + + // Test with APIKey enabled + authConfig = &config.AuthConfig{ + APIKey: true, + } + So(authConfig.IsBasicAuthnEnabled(), ShouldBeTrue) + }) + + Convey("Test GetFailDelay()", func() { + // Test with nil AuthConfig + var authConfig *config.AuthConfig = nil + + So(authConfig.GetFailDelay(), ShouldEqual, 0) + + // Test with AuthConfig and custom FailDelay + authConfig = &config.AuthConfig{ + FailDelay: 5, + } + So(authConfig.GetFailDelay(), ShouldEqual, 5) + }) + }) + + Convey("Test LDAPConfig methods", t, func() { + Convey("Test BindDN()", func() { + ldapConfig := &config.LDAPConfig{} + So(ldapConfig.BindDN(), ShouldEqual, "") + + ldapConfig.SetBindDN("cn=admin,dc=example,dc=com") + So(ldapConfig.BindDN(), ShouldEqual, "cn=admin,dc=example,dc=com") + }) + + Convey("Test BindPassword()", func() { + ldapConfig := &config.LDAPConfig{} + So(ldapConfig.BindPassword(), ShouldEqual, "") + + ldapConfig.SetBindPassword("secretpassword") + So(ldapConfig.BindPassword(), ShouldEqual, "secretpassword") + }) + }) + + Convey("Test AccessControlConfig methods", t, func() { + Convey("Test IsAuthzEnabled()", func() { + // Test with nil AccessControlConfig + var accessControlConfig *config.AccessControlConfig = nil + + So(accessControlConfig.IsAuthzEnabled(), ShouldBeFalse) + + // Test with AccessControlConfig + accessControlConfig = &config.AccessControlConfig{} + So(accessControlConfig.IsAuthzEnabled(), ShouldBeTrue) + }) + + Convey("Test AnonymousPolicyExists()", func() { + // Test with nil AccessControlConfig + var accessControlConfig *config.AccessControlConfig = nil + + So(accessControlConfig.AnonymousPolicyExists(), ShouldBeFalse) + + // Test with AccessControlConfig but no repositories + accessControlConfig = &config.AccessControlConfig{} + So(accessControlConfig.AnonymousPolicyExists(), ShouldBeFalse) + + // Test with AccessControlConfig and repository with anonymous policy + accessControlConfig = &config.AccessControlConfig{} + accessControlConfig.Repositories = config.Repositories{ + "repo1": config.PolicyGroup{ + AnonymousPolicy: []string{"read"}, + }, + } + So(accessControlConfig.AnonymousPolicyExists(), ShouldBeTrue) + + // Test with AccessControlConfig and repository without anonymous policy + accessControlConfig = &config.AccessControlConfig{} + accessControlConfig.Repositories = config.Repositories{ + "repo1": config.PolicyGroup{ + DefaultPolicy: []string{"read"}, + }, + } + So(accessControlConfig.AnonymousPolicyExists(), ShouldBeFalse) + }) + + Convey("Test GetRepositories()", func() { + repositories := config.Repositories{ + "repo1": config.PolicyGroup{ + AnonymousPolicy: []string{"read"}, + }, + } + accessControlConfig := &config.AccessControlConfig{} + accessControlConfig.Repositories = repositories + So(accessControlConfig.GetRepositories(), ShouldResemble, repositories) + }) + + Convey("Test GetAdminPolicy()", func() { + adminPolicy := config.Policy{ + Actions: []string{"read", "write"}, + Users: []string{"admin"}, + } + accessControlConfig := &config.AccessControlConfig{} + accessControlConfig.AdminPolicy = adminPolicy + So(accessControlConfig.GetAdminPolicy(), ShouldResemble, adminPolicy) + }) + + Convey("Test GetMetrics()", func() { + metrics := config.Metrics{ + Users: []string{"metrics-user"}, + } + accessControlConfig := &config.AccessControlConfig{} + accessControlConfig.Metrics = metrics + So(accessControlConfig.GetMetrics(), ShouldResemble, metrics) + }) + + Convey("Test GetGroups()", func() { + groups := config.Groups{ + "developers": config.Group{ + Users: []string{"dev1", "dev2"}, + }, + } + accessControlConfig := &config.AccessControlConfig{} + accessControlConfig.Groups = groups + So(accessControlConfig.GetGroups(), ShouldResemble, groups) + }) + }) + + Convey("Test Config getter methods", t, func() { + Convey("Test CopyAuthConfig()", func() { + Convey("Test with non-nil Auth", func() { + cfg := &config.Config{ + HTTP: config.HTTPConfig{ + Auth: &config.AuthConfig{ + FailDelay: 5, + }, + }, + } + authConfig := cfg.CopyAuthConfig() + So(authConfig, ShouldNotBeNil) + So(authConfig.GetFailDelay(), ShouldEqual, 5) + }) + + Convey("Test with nil Auth", func() { + cfg := &config.Config{ + HTTP: config.HTTPConfig{ + Auth: nil, + }, + } + authConfig := cfg.CopyAuthConfig() + So(authConfig, ShouldBeNil) + }) + + Convey("Test that returned AuthConfig is isolated from config mutations", func() { + // Create initial config with AuthConfig containing nested structures + cfg := &config.Config{ + HTTP: config.HTTPConfig{ + Auth: &config.AuthConfig{ + FailDelay: 5, + HTPasswd: config.AuthHTPasswd{ + Path: "/etc/htpasswd", + }, + LDAP: &config.LDAPConfig{ + Address: "ldap.example.com", + Port: 389, + }, + Bearer: &config.BearerConfig{ + Realm: "test-realm", + Service: "test-service", + Cert: "/path/to/cert", + }, + OpenID: &config.OpenIDConfig{ + Providers: map[string]config.OpenIDProviderConfig{ + "google": { + Name: "Google", + ClientID: "google-client-id", + Scopes: []string{"openid", "email"}, + }, + }, + }, + APIKey: false, + SessionKeysFile: "/etc/session-keys", + SessionHashKey: []byte("hash-key"), + SessionEncryptKey: []byte("encrypt-key"), + SessionDriver: map[string]any{ + "type": "redis", + "host": "localhost", + }, + }, + }, + } + + // Get the AuthConfig reference + authConfig := cfg.CopyAuthConfig() + So(authConfig, ShouldNotBeNil) + So(authConfig.GetFailDelay(), ShouldEqual, 5) + So(authConfig.IsHtpasswdAuthEnabled(), ShouldBeTrue) + So(authConfig.IsLdapAuthEnabled(), ShouldBeTrue) + So(authConfig.IsBearerAuthEnabled(), ShouldBeTrue) + So(authConfig.IsOpenIDAuthEnabled(), ShouldBeTrue) + So(authConfig.IsAPIKeyEnabled(), ShouldBeFalse) + + // Test deep copy isolation by modifying nested structures + authConfig.LDAP.Address = "modified-ldap.example.com" + authConfig.Bearer.Realm = "modified-realm" + authConfig.OpenID.Providers["google"].Scopes[0] = "modified-scope" + authConfig.SessionHashKey[0] = 'M' + authConfig.SessionDriver["type"] = "modified-driver" + + // Verify original is unchanged + So(cfg.HTTP.Auth.LDAP.Address, ShouldEqual, "ldap.example.com") + So(cfg.HTTP.Auth.Bearer.Realm, ShouldEqual, "test-realm") + So(cfg.HTTP.Auth.OpenID.Providers["google"].Scopes[0], ShouldEqual, "openid") + So(cfg.HTTP.Auth.SessionHashKey[0], ShouldEqual, byte('h')) + So(cfg.HTTP.Auth.SessionDriver["type"], ShouldEqual, "redis") + }) + + Convey("Test that returned AuthConfig is isolated when config is updated via UpdateReloadableConfig", func() { + // Create initial config with AuthConfig + cfg := &config.Config{ + HTTP: config.HTTPConfig{ + Auth: &config.AuthConfig{ + FailDelay: 5, + HTPasswd: config.AuthHTPasswd{ + Path: "/etc/htpasswd", + }, + APIKey: false, + }, + }, + } + + // Get the AuthConfig reference + authConfig := cfg.CopyAuthConfig() + So(authConfig, ShouldNotBeNil) + So(authConfig.GetFailDelay(), ShouldEqual, 5) + So(authConfig.IsHtpasswdAuthEnabled(), ShouldBeTrue) + So(authConfig.IsAPIKeyEnabled(), ShouldBeFalse) + + // Create new config with updated AuthConfig + // Note: UpdateReloadableConfig updates HTPasswd, LDAP, APIKey, and OpenID fields + newConfig := &config.Config{ + HTTP: config.HTTPConfig{ + Auth: &config.AuthConfig{ + FailDelay: 15, // This field is NOT updated by UpdateReloadableConfig + HTPasswd: config.AuthHTPasswd{ + Path: "/etc/updated-htpasswd", // This field IS updated by UpdateReloadableConfig + }, + APIKey: true, // This field IS updated by UpdateReloadableConfig + }, + }, + } + + // Update the config using UpdateReloadableConfig + cfg.UpdateReloadableConfig(newConfig) + + // Verify that the returned AuthConfig is not affected by the update + // CopyAuthConfig() returns a copy, so the returned object should be isolated + So(authConfig.GetFailDelay(), ShouldEqual, 5) // Should remain unchanged + So(authConfig.IsHtpasswdAuthEnabled(), ShouldBeTrue) // Should remain unchanged (old path) + So(authConfig.IsAPIKeyEnabled(), ShouldBeFalse) // Should remain unchanged + + // Verify that a new CopyAuthConfig() call returns the updated values + newAuthConfig := cfg.CopyAuthConfig() + So(newAuthConfig, ShouldNotBeNil) + // Should remain unchanged (not updated by UpdateReloadableConfig) + So(newAuthConfig.GetFailDelay(), ShouldEqual, 5) + So(newAuthConfig.IsHtpasswdAuthEnabled(), ShouldBeTrue) // Should be updated (new path) + // Should be updated by UpdateReloadableConfig + So(newAuthConfig.IsAPIKeyEnabled(), ShouldBeTrue) + }) + + Convey("Test that returned AuthConfig is isolated when config is set to nil", func() { + // Create initial config with AuthConfig + cfg := &config.Config{ + HTTP: config.HTTPConfig{ + Auth: &config.AuthConfig{ + FailDelay: 5, + HTPasswd: config.AuthHTPasswd{ + Path: "/etc/htpasswd", + }, + APIKey: false, + }, + }, + } + + // Get the AuthConfig reference + authConfig := cfg.CopyAuthConfig() + So(authConfig, ShouldNotBeNil) + So(authConfig.GetFailDelay(), ShouldEqual, 5) + So(authConfig.IsHtpasswdAuthEnabled(), ShouldBeTrue) + So(authConfig.IsAPIKeyEnabled(), ShouldBeFalse) + + // Set the AuthConfig to nil + cfg.HTTP.Auth = nil + + // Verify that the returned AuthConfig is not affected by setting to nil + So(authConfig, ShouldNotBeNil) // Should remain unchanged + So(authConfig.GetFailDelay(), ShouldEqual, 5) // Should remain unchanged + So(authConfig.IsHtpasswdAuthEnabled(), ShouldBeTrue) // Should remain unchanged + So(authConfig.IsAPIKeyEnabled(), ShouldBeFalse) // Should remain unchanged + + // Verify that a new CopyAuthConfig() call returns nil + newAuthConfig := cfg.CopyAuthConfig() + So(newAuthConfig, ShouldBeNil) // Should be nil + }) + }) + + Convey("Test CopyAccessControlConfig()", func() { + Convey("Test with non-nil AccessControl", func() { + testAccessControlConfig := &config.AccessControlConfig{ + Repositories: config.Repositories{ + "repo1": config.PolicyGroup{ + Policies: []config.Policy{ + { + Users: []string{"user1", "user2"}, + Actions: []string{"read", "write"}, + Groups: []string{"group1"}, + }, + }, + DefaultPolicy: []string{"read"}, + AnonymousPolicy: []string{"read"}, + }, + }, + AdminPolicy: config.Policy{ + Users: []string{"admin1"}, + Actions: []string{"read", "write", "delete"}, + Groups: []string{"admin-group"}, + }, + Groups: config.Groups{ + "group1": config.Group{ + Users: []string{"user1", "user2"}, + }, + }, + Metrics: config.Metrics{ + Users: []string{"metrics-user"}, + }, + } + cfg := &config.Config{ + HTTP: config.HTTPConfig{ + AccessControl: testAccessControlConfig, + }, + } + accessControlConfig := cfg.CopyAccessControlConfig() + So(accessControlConfig, ShouldNotBeNil) + So(accessControlConfig.IsAuthzEnabled(), ShouldBeTrue) + + // Test deep copy isolation + accessControlConfig.Repositories["repo1"].Policies[0].Users[0] = "modified-user" + accessControlConfig.Repositories["repo1"].DefaultPolicy[0] = "modified-policy" + accessControlConfig.AdminPolicy.Users[0] = "modified-admin" + accessControlConfig.Groups["group1"].Users[0] = "modified-group-user" + accessControlConfig.Metrics.Users[0] = "modified-metrics-user" + + // Verify original is unchanged + So(cfg.HTTP.AccessControl.Repositories["repo1"].Policies[0].Users[0], ShouldEqual, "user1") + So(cfg.HTTP.AccessControl.Repositories["repo1"].DefaultPolicy[0], ShouldEqual, "read") + So(cfg.HTTP.AccessControl.AdminPolicy.Users[0], ShouldEqual, "admin1") + So(cfg.HTTP.AccessControl.Groups["group1"].Users[0], ShouldEqual, "user1") + So(cfg.HTTP.AccessControl.Metrics.Users[0], ShouldEqual, "metrics-user") + }) + + Convey("Test with nil AccessControl", func() { + cfg := &config.Config{ + HTTP: config.HTTPConfig{ + AccessControl: nil, + }, + } + accessControlConfig := cfg.CopyAccessControlConfig() + So(accessControlConfig, ShouldBeNil) + }) + }) + + Convey("Test CopyStorageConfig()", func() { + Convey("Test with non-nil Storage", func() { + cfg := &config.Config{ + Storage: config.GlobalStorageConfig{ + StorageConfig: config.StorageConfig{ + RootDirectory: "/tmp/storage", + GC: true, + }, + }, + } + storageConfig := cfg.CopyStorageConfig() + So(storageConfig, ShouldNotBeNil) + So(storageConfig.RootDirectory, ShouldEqual, "/tmp/storage") + So(storageConfig.GC, ShouldBeTrue) + }) + + Convey("Test with nil Storage", func() { + cfg := &config.Config{ + Storage: config.GlobalStorageConfig{}, + } + storageConfig := cfg.CopyStorageConfig() + So(storageConfig, ShouldNotBeNil) // GlobalStorageConfig is a struct, not a pointer, so it's never nil + So(storageConfig.RootDirectory, ShouldEqual, "") + So(storageConfig.GC, ShouldBeFalse) + }) + + Convey("Test StorageConfig deep copy isolation", func() { + cfg := &config.Config{ + Storage: config.GlobalStorageConfig{ + StorageConfig: config.StorageConfig{ + RootDirectory: "/tmp/storage", + GC: true, + Retention: config.ImageRetention{ + DryRun: true, + Policies: []config.RetentionPolicy{ + { + Repositories: []string{"repo1", "repo2"}, + KeepTags: []config.KeepTagsPolicy{ + { + Patterns: []string{"pattern1", "pattern2"}, + }, + }, + }, + }, + }, + StorageDriver: map[string]interface{}{ + "type": "filesystem", + }, + CacheDriver: map[string]interface{}{ + "type": "redis", + }, + }, + SubPaths: map[string]config.StorageConfig{ + "/subpath1": { + RootDirectory: "/tmp/subpath1", + Retention: config.ImageRetention{ + Policies: []config.RetentionPolicy{ + { + Repositories: []string{"subrepo1"}, + }, + }, + }, + StorageDriver: map[string]interface{}{ + "type": "s3", + }, + }, + }, + }, + } + + // Get a copy of the storage config + storageConfig := cfg.CopyStorageConfig() + So(storageConfig, ShouldNotBeNil) + + // Mutate the copy's fields + storageConfig.RootDirectory = "/modified/storage" + storageConfig.GC = false + storageConfig.Retention.Policies[0].Repositories[0] = "modified-repo" + storageConfig.Retention.Policies[0].KeepTags[0].Patterns[0] = "modified-pattern" + storageConfig.StorageDriver["type"] = "modified-driver" + storageConfig.CacheDriver["type"] = "modified-cache" + + // Mutate SubPaths by getting a copy, modifying it, and putting it back + subPathConfig := storageConfig.SubPaths["/subpath1"] + subPathConfig.RootDirectory = "/modified/subpath1" + subPathConfig.Retention.Policies[0].Repositories[0] = "modified-subrepo" + subPathConfig.StorageDriver["type"] = "modified-s3" + storageConfig.SubPaths["/subpath1"] = subPathConfig + + // Verify original config is unchanged + So(cfg.Storage.RootDirectory, ShouldEqual, "/tmp/storage") + So(cfg.Storage.GC, ShouldBeTrue) + So(cfg.Storage.Retention.Policies[0].Repositories[0], ShouldEqual, "repo1") + So(cfg.Storage.Retention.Policies[0].KeepTags[0].Patterns[0], ShouldEqual, "pattern1") + So(cfg.Storage.StorageDriver["type"], ShouldEqual, "filesystem") + So(cfg.Storage.CacheDriver["type"], ShouldEqual, "redis") + So(cfg.Storage.SubPaths["/subpath1"].RootDirectory, ShouldEqual, "/tmp/subpath1") + So(cfg.Storage.SubPaths["/subpath1"].Retention.Policies[0].Repositories[0], ShouldEqual, "subrepo1") + So(cfg.Storage.SubPaths["/subpath1"].StorageDriver["type"], ShouldEqual, "s3") + + // Verify copy has the mutations + So(storageConfig.RootDirectory, ShouldEqual, "/modified/storage") + So(storageConfig.GC, ShouldBeFalse) + So(storageConfig.Retention.Policies[0].Repositories[0], ShouldEqual, "modified-repo") + So(storageConfig.Retention.Policies[0].KeepTags[0].Patterns[0], ShouldEqual, "modified-pattern") + So(storageConfig.StorageDriver["type"], ShouldEqual, "modified-driver") + So(storageConfig.CacheDriver["type"], ShouldEqual, "modified-cache") + So(storageConfig.SubPaths["/subpath1"].RootDirectory, ShouldEqual, "/modified/subpath1") + So(storageConfig.SubPaths["/subpath1"].Retention.Policies[0].Repositories[0], ShouldEqual, "modified-subrepo") + So(storageConfig.SubPaths["/subpath1"].StorageDriver["type"], ShouldEqual, "modified-s3") + }) + }) + + Convey("Test CopyLogConfig()", func() { + Convey("Test with non-nil Log", func() { + cfg := &config.Config{ + Log: &config.LogConfig{ + Level: "info", + Output: "/tmp/logs", + }, + } + logConfig := cfg.CopyLogConfig() + So(logConfig, ShouldNotBeNil) + So(logConfig.Level, ShouldEqual, "info") + So(logConfig.Output, ShouldEqual, "/tmp/logs") + }) + + Convey("Test with nil Log", func() { + cfg := &config.Config{ + Log: nil, + } + logConfig := cfg.CopyLogConfig() + So(logConfig, ShouldBeNil) + }) + }) + + Convey("Test CopyClusterConfig()", func() { + Convey("Test with non-nil Cluster", func() { + cfg := &config.Config{ + Cluster: &config.ClusterConfig{ + Members: []string{"node1", "node2"}, + }, + } + clusterConfig := cfg.CopyClusterConfig() + So(clusterConfig, ShouldNotBeNil) + So(len(clusterConfig.Members), ShouldEqual, 2) + }) + + Convey("Test with nil Cluster", func() { + cfg := &config.Config{ + Cluster: nil, + } + clusterConfig := cfg.CopyClusterConfig() + So(clusterConfig, ShouldBeNil) + }) + + Convey("Test ClusterConfig deep copy isolation", func() { + cfg := &config.Config{ + Cluster: &config.ClusterConfig{ + Members: []string{"node1", "node2"}, + HashKey: "test-key", + TLS: &config.TLSConfig{ + Cert: "test-cert", + Key: "test-key", + CACert: "test-ca", + }, + Proxy: &config.ClusterRequestProxyConfig{ + LocalMemberClusterSocket: "127.0.0.1:8080", + LocalMemberClusterSocketIndex: 1, + }, + }, + } + + // Get a copy of the cluster config + clusterConfig := cfg.CopyClusterConfig() + So(clusterConfig, ShouldNotBeNil) + + // Mutate the copy + clusterConfig.Members[0] = "modified-node" + clusterConfig.HashKey = "modified-key" + clusterConfig.TLS.Cert = "modified-cert" + clusterConfig.Proxy.LocalMemberClusterSocket = "modified-socket" + + // Verify original config is unchanged + So(cfg.Cluster.Members[0], ShouldEqual, "node1") + So(cfg.Cluster.HashKey, ShouldEqual, "test-key") + So(cfg.Cluster.TLS.Cert, ShouldEqual, "test-cert") + So(cfg.Cluster.Proxy.LocalMemberClusterSocket, ShouldEqual, "127.0.0.1:8080") + + // Verify copy has the mutations + So(clusterConfig.Members[0], ShouldEqual, "modified-node") + So(clusterConfig.HashKey, ShouldEqual, "modified-key") + So(clusterConfig.TLS.Cert, ShouldEqual, "modified-cert") + So(clusterConfig.Proxy.LocalMemberClusterSocket, ShouldEqual, "modified-socket") + }) + }) + + Convey("Test CopySchedulerConfig()", func() { + Convey("Test with non-nil Scheduler", func() { + cfg := &config.Config{ + Scheduler: &config.SchedulerConfig{ + NumWorkers: 4, + }, + } + schedulerConfig := cfg.CopySchedulerConfig() + So(schedulerConfig, ShouldNotBeNil) + So(schedulerConfig.NumWorkers, ShouldEqual, 4) + }) + + Convey("Test with nil Scheduler", func() { + cfg := &config.Config{ + Scheduler: nil, + } + schedulerConfig := cfg.CopySchedulerConfig() + So(schedulerConfig, ShouldBeNil) + }) + }) + + Convey("Test GetVersionInfo()", func() { + Convey("Test with non-nil version info", func() { + cfg := &config.Config{ + Commit: "abc123", + BinaryType: "server", + GoVersion: "go1.21", + DistSpecVersion: "1.1.1", + } + commit, binaryType, goVersion, distSpecVersion := cfg.GetVersionInfo() + So(commit, ShouldEqual, "abc123") + So(binaryType, ShouldEqual, "server") + So(goVersion, ShouldEqual, "go1.21") + So(distSpecVersion, ShouldEqual, "1.1.1") + }) + + Convey("Test with empty version info", func() { + cfg := &config.Config{ + Commit: "", + BinaryType: "", + GoVersion: "", + DistSpecVersion: "", + } + commit, binaryType, goVersion, distSpecVersion := cfg.GetVersionInfo() + So(commit, ShouldEqual, "") + So(binaryType, ShouldEqual, "") + So(goVersion, ShouldEqual, "") + So(distSpecVersion, ShouldEqual, "") + }) + }) + + Convey("Test GetRealm()", func() { + Convey("Test with non-empty Realm", func() { + cfg := &config.Config{ + HTTP: config.HTTPConfig{ + Realm: "my-realm", + }, + } + realm := cfg.GetRealm() + So(realm, ShouldEqual, "my-realm") + }) + + Convey("Test with empty Realm", func() { + cfg := &config.Config{ + HTTP: config.HTTPConfig{ + Realm: "", + }, + } + realm := cfg.GetRealm() + So(realm, ShouldEqual, "") + }) + + Convey("Test with nil Config", func() { + var cfg *config.Config = nil + realm := cfg.GetRealm() + So(realm, ShouldEqual, "") + }) + }) + + Convey("Test CopyTLSConfig()", func() { + Convey("Test with non-empty TLS config", func() { + cfg := &config.Config{ + HTTP: config.HTTPConfig{ + TLS: &config.TLSConfig{ + Cert: "/path/to/cert.pem", + Key: "/path/to/key.pem", + CACert: "/path/to/ca.pem", + }, + }, + } + tlsConfig := cfg.CopyTLSConfig() + So(tlsConfig, ShouldNotBeNil) + So(tlsConfig.Cert, ShouldEqual, "/path/to/cert.pem") + So(tlsConfig.Key, ShouldEqual, "/path/to/key.pem") + So(tlsConfig.CACert, ShouldEqual, "/path/to/ca.pem") + + // Test copy isolation + tlsConfig.Cert = "/modified/cert.pem" + + So(cfg.HTTP.TLS.Cert, ShouldEqual, "/path/to/cert.pem") + }) + + Convey("Test with nil TLS config", func() { + cfg := &config.Config{ + HTTP: config.HTTPConfig{ + TLS: nil, + }, + } + tlsConfig := cfg.CopyTLSConfig() + So(tlsConfig, ShouldBeNil) + }) + + Convey("Test with nil Config", func() { + var cfg *config.Config = nil + tlsConfig := cfg.CopyTLSConfig() + So(tlsConfig, ShouldBeNil) + }) + }) + + Convey("Test GetCompat()", func() { + Convey("Test with non-empty compat config", func() { + cfg := &config.Config{ + HTTP: config.HTTPConfig{ + Compat: []compat.MediaCompatibility{ + "docker2s2", + "oci1", + }, + }, + } + compatConfig := cfg.GetCompat() + So(compatConfig, ShouldNotBeNil) + So(len(compatConfig), ShouldEqual, 2) + So(string(compatConfig[0]), ShouldEqual, "docker2s2") + So(string(compatConfig[1]), ShouldEqual, "oci1") + + // Test copy isolation + compatConfig[0] = "modified-compat" + + So(string(cfg.HTTP.Compat[0]), ShouldEqual, "docker2s2") + }) + + Convey("Test with nil compat config", func() { + cfg := &config.Config{ + HTTP: config.HTTPConfig{ + Compat: nil, + }, + } + compatConfig := cfg.GetCompat() + So(compatConfig, ShouldBeNil) + }) + + Convey("Test with empty compat config", func() { + cfg := &config.Config{ + HTTP: config.HTTPConfig{ + Compat: []compat.MediaCompatibility{}, + }, + } + compatConfig := cfg.GetCompat() + So(compatConfig, ShouldNotBeNil) + So(len(compatConfig), ShouldEqual, 0) + }) + + Convey("Test with nil Config", func() { + var cfg *config.Config = nil + compatConfig := cfg.GetCompat() + So(compatConfig, ShouldBeNil) + }) + }) + + Convey("Test GetHTTPAddress()", func() { + Convey("Test with non-empty address", func() { + cfg := &config.Config{ + HTTP: config.HTTPConfig{ + Address: "192.168.1.100", + }, + } + address := cfg.GetHTTPAddress() + So(address, ShouldEqual, "192.168.1.100") + }) + + Convey("Test with empty address", func() { + cfg := &config.Config{ + HTTP: config.HTTPConfig{ + Address: "", + }, + } + address := cfg.GetHTTPAddress() + So(address, ShouldEqual, "") + }) + + Convey("Test with nil Config", func() { + var cfg *config.Config = nil + address := cfg.GetHTTPAddress() + So(address, ShouldEqual, "") + }) + }) + + Convey("Test GetHTTPPort()", func() { + Convey("Test with non-empty port", func() { + cfg := &config.Config{ + HTTP: config.HTTPConfig{ + Port: "8080", + }, + } + port := cfg.GetHTTPPort() + So(port, ShouldEqual, "8080") + }) + + Convey("Test with empty port", func() { + cfg := &config.Config{ + HTTP: config.HTTPConfig{ + Port: "", + }, + } + port := cfg.GetHTTPPort() + So(port, ShouldEqual, "") + }) + + Convey("Test with nil Config", func() { + var cfg *config.Config = nil + port := cfg.GetHTTPPort() + So(port, ShouldEqual, "") + }) + }) + + Convey("Test GetAllowOrigin()", func() { + Convey("Test with non-empty allow origin", func() { + cfg := &config.Config{ + HTTP: config.HTTPConfig{ + AllowOrigin: "http://localhost:3000,https://example.com", + }, + } + allowOrigin := cfg.GetAllowOrigin() + So(allowOrigin, ShouldEqual, "http://localhost:3000,https://example.com") + }) + + Convey("Test with empty allow origin", func() { + cfg := &config.Config{ + HTTP: config.HTTPConfig{ + AllowOrigin: "", + }, + } + allowOrigin := cfg.GetAllowOrigin() + So(allowOrigin, ShouldEqual, "") + }) + + Convey("Test with nil Config", func() { + var cfg *config.Config = nil + allowOrigin := cfg.GetAllowOrigin() + So(allowOrigin, ShouldEqual, "") + }) + }) + + Convey("Test CopyRatelimit()", func() { + Convey("Test with non-empty ratelimit config", func() { + rate := 100 + cfg := &config.Config{ + HTTP: config.HTTPConfig{ + Ratelimit: &config.RatelimitConfig{ + Rate: &rate, + Methods: []config.MethodRatelimitConfig{ + { + Method: "GET", + Rate: 50, + }, + { + Method: "POST", + Rate: 25, + }, + }, + }, + }, + } + ratelimitConfig := cfg.CopyRatelimit() + So(ratelimitConfig, ShouldNotBeNil) + So(*ratelimitConfig.Rate, ShouldEqual, 100) + So(len(ratelimitConfig.Methods), ShouldEqual, 2) + So(ratelimitConfig.Methods[0].Method, ShouldEqual, "GET") + So(ratelimitConfig.Methods[0].Rate, ShouldEqual, 50) + So(ratelimitConfig.Methods[1].Method, ShouldEqual, "POST") + So(ratelimitConfig.Methods[1].Rate, ShouldEqual, 25) + + // Test deep copy isolation + *ratelimitConfig.Rate = 200 + ratelimitConfig.Methods[0].Rate = 75 + ratelimitConfig.Methods[0].Method = "PUT" + + So(*cfg.HTTP.Ratelimit.Rate, ShouldEqual, 100) + So(cfg.HTTP.Ratelimit.Methods[0].Rate, ShouldEqual, 50) + So(cfg.HTTP.Ratelimit.Methods[0].Method, ShouldEqual, "GET") + }) + + Convey("Test with nil ratelimit config", func() { + cfg := &config.Config{ + HTTP: config.HTTPConfig{ + Ratelimit: nil, + }, + } + ratelimitConfig := cfg.CopyRatelimit() + So(ratelimitConfig, ShouldBeNil) + }) + + Convey("Test with nil Config", func() { + var cfg *config.Config = nil + ratelimitConfig := cfg.CopyRatelimit() + So(ratelimitConfig, ShouldBeNil) + }) + }) + }) + + Convey("Test Config utility methods", t, func() { + Convey("Test IsMTLSAuthEnabled()", func() { + // Test with nil Config + var cfg *config.Config = nil + + So(cfg.IsMTLSAuthEnabled(), ShouldBeFalse) + + // Test with Config but no TLS + cfg = &config.Config{} + So(cfg.IsMTLSAuthEnabled(), ShouldBeFalse) + + // Test with Config and TLS but no client cert + cfg = &config.Config{ + HTTP: config.HTTPConfig{ + TLS: &config.TLSConfig{ + Cert: "/path/to/cert.pem", + Key: "/path/to/key.pem", + }, + }, + } + So(cfg.IsMTLSAuthEnabled(), ShouldBeFalse) + + // Test with Config and TLS with CA cert (mTLS) + cfg = &config.Config{ + HTTP: config.HTTPConfig{ + TLS: &config.TLSConfig{ + Cert: "/path/to/cert.pem", + Key: "/path/to/key.pem", + CACert: "/path/to/ca-cert.pem", + }, + }, + } + So(cfg.IsMTLSAuthEnabled(), ShouldBeTrue) + + // Test with HTPasswd enabled (should disable mTLS) + cfg = &config.Config{ + HTTP: config.HTTPConfig{ + Auth: &config.AuthConfig{ + HTPasswd: config.AuthHTPasswd{ + Path: "/path/to/htpasswd", + }, + }, + TLS: &config.TLSConfig{ + Cert: "/path/to/cert.pem", + Key: "/path/to/key.pem", + CACert: "/path/to/ca-cert.pem", + }, + }, + } + So(cfg.IsMTLSAuthEnabled(), ShouldBeFalse) // Basic auth enabled, so mTLS disabled + + // Test with LDAP enabled (should disable mTLS) + cfg = &config.Config{ + HTTP: config.HTTPConfig{ + Auth: &config.AuthConfig{ + LDAP: &config.LDAPConfig{}, + }, + TLS: &config.TLSConfig{ + Cert: "/path/to/cert.pem", + Key: "/path/to/key.pem", + CACert: "/path/to/ca-cert.pem", + }, + }, + } + So(cfg.IsMTLSAuthEnabled(), ShouldBeFalse) // Basic auth enabled, so mTLS disabled + + // Test with API Key enabled (should disable mTLS) + cfg = &config.Config{ + HTTP: config.HTTPConfig{ + Auth: &config.AuthConfig{ + APIKey: true, + }, + TLS: &config.TLSConfig{ + Cert: "/path/to/cert.pem", + Key: "/path/to/key.pem", + CACert: "/path/to/ca-cert.pem", + }, + }, + } + So(cfg.IsMTLSAuthEnabled(), ShouldBeFalse) // Basic auth enabled, so mTLS disabled + + // Test with OpenID enabled (valid config - should disable mTLS) + cfg = &config.Config{ + HTTP: config.HTTPConfig{ + Auth: &config.AuthConfig{ + OpenID: &config.OpenIDConfig{ + Providers: map[string]config.OpenIDProviderConfig{ + "google": { + ClientID: "client-id", + Issuer: "", + Scopes: []string{}, + }, + }, + }, + }, + TLS: &config.TLSConfig{ + Cert: "/path/to/cert.pem", + Key: "/path/to/key.pem", + CACert: "/path/to/ca-cert.pem", + }, + }, + } + So(cfg.IsMTLSAuthEnabled(), ShouldBeFalse) // Basic auth enabled, so mTLS disabled + + // Test with OpenID enabled (with Issuer - should disable mTLS) + cfg = &config.Config{ + HTTP: config.HTTPConfig{ + Auth: &config.AuthConfig{ + OpenID: &config.OpenIDConfig{ + Providers: map[string]config.OpenIDProviderConfig{ + "google": { + ClientID: "", + Issuer: "https://accounts.google.com", + Scopes: []string{}, + }, + }, + }, + }, + TLS: &config.TLSConfig{ + Cert: "/path/to/cert.pem", + Key: "/path/to/key.pem", + CACert: "/path/to/ca-cert.pem", + }, + }, + } + So(cfg.IsMTLSAuthEnabled(), ShouldBeFalse) // Basic auth enabled, so mTLS disabled + + // Test with OpenID enabled (with Scopes - should disable mTLS) + cfg = &config.Config{ + HTTP: config.HTTPConfig{ + Auth: &config.AuthConfig{ + OpenID: &config.OpenIDConfig{ + Providers: map[string]config.OpenIDProviderConfig{ + "google": { + ClientID: "", + Issuer: "", + Scopes: []string{"openid", "email"}, + }, + }, + }, + }, + TLS: &config.TLSConfig{ + Cert: "/path/to/cert.pem", + Key: "/path/to/key.pem", + CACert: "/path/to/ca-cert.pem", + }, + }, + } + So(cfg.IsMTLSAuthEnabled(), ShouldBeFalse) // Basic auth enabled, so mTLS disabled + + // Test with OAuth2 provider (github) with ClientID (should disable mTLS) + cfg = &config.Config{ + HTTP: config.HTTPConfig{ + Auth: &config.AuthConfig{ + OpenID: &config.OpenIDConfig{ + Providers: map[string]config.OpenIDProviderConfig{ + "github": { + ClientID: "github-client-id", + Scopes: []string{}, + }, + }, + }, + }, + TLS: &config.TLSConfig{ + Cert: "/path/to/cert.pem", + Key: "/path/to/key.pem", + CACert: "/path/to/ca-cert.pem", + }, + }, + } + So(cfg.IsMTLSAuthEnabled(), ShouldBeFalse) // Basic auth enabled, so mTLS disabled + + // Test with OAuth2 provider (github) with Scopes (should disable mTLS) + cfg = &config.Config{ + HTTP: config.HTTPConfig{ + Auth: &config.AuthConfig{ + OpenID: &config.OpenIDConfig{ + Providers: map[string]config.OpenIDProviderConfig{ + "github": { + ClientID: "", + Scopes: []string{"user:email"}, + }, + }, + }, + }, + TLS: &config.TLSConfig{ + Cert: "/path/to/cert.pem", + Key: "/path/to/key.pem", + CACert: "/path/to/ca-cert.pem", + }, + }, + } + So(cfg.IsMTLSAuthEnabled(), ShouldBeFalse) // Basic auth enabled, so mTLS disabled + + // Test with OpenID but empty config (should enable mTLS) + cfg = &config.Config{ + HTTP: config.HTTPConfig{ + Auth: &config.AuthConfig{ + OpenID: &config.OpenIDConfig{ + Providers: map[string]config.OpenIDProviderConfig{ + "google": { + ClientID: "", + Issuer: "", + Scopes: []string{}, + }, + }, + }, + }, + TLS: &config.TLSConfig{ + Cert: "/path/to/cert.pem", + Key: "/path/to/key.pem", + CACert: "/path/to/ca-cert.pem", + }, + }, + } + So(cfg.IsMTLSAuthEnabled(), ShouldBeTrue) // No basic auth, so mTLS enabled + + // Test with OpenID but unsupported provider (should enable mTLS) + cfg = &config.Config{ + HTTP: config.HTTPConfig{ + Auth: &config.AuthConfig{ + OpenID: &config.OpenIDConfig{ + Providers: map[string]config.OpenIDProviderConfig{ + "unsupported": { + ClientID: "client-id", + Scopes: []string{"scope"}, + }, + }, + }, + }, + TLS: &config.TLSConfig{ + Cert: "/path/to/cert.pem", + Key: "/path/to/key.pem", + CACert: "/path/to/ca-cert.pem", + }, + }, + } + So(cfg.IsMTLSAuthEnabled(), ShouldBeTrue) // No basic auth, so mTLS enabled + + // Test with no authentication methods (should enable mTLS) + cfg = &config.Config{ + HTTP: config.HTTPConfig{ + Auth: &config.AuthConfig{}, + TLS: &config.TLSConfig{ + Cert: "/path/to/cert.pem", + Key: "/path/to/key.pem", + CACert: "/path/to/ca-cert.pem", + }, + }, + } + So(cfg.IsMTLSAuthEnabled(), ShouldBeTrue) // No basic auth, so mTLS enabled + + // Test with nil Auth (should enable mTLS) + cfg = &config.Config{ + HTTP: config.HTTPConfig{ + Auth: nil, + TLS: &config.TLSConfig{ + Cert: "/path/to/cert.pem", + Key: "/path/to/key.pem", + CACert: "/path/to/ca-cert.pem", + }, + }, + } + So(cfg.IsMTLSAuthEnabled(), ShouldBeTrue) // No basic auth, so mTLS enabled + }) + + Convey("Test IsCompatEnabled()", func() { + // Test with nil Config + var cfg *config.Config = nil + + So(cfg.IsCompatEnabled(), ShouldBeFalse) + + // Test with Config but no Compat + cfg = &config.Config{} + So(cfg.IsCompatEnabled(), ShouldBeFalse) + + // Test with Config and Compat enabled + cfg = &config.Config{ + HTTP: config.HTTPConfig{ + Compat: []compat.MediaCompatibility{compat.DockerManifestV2SchemaV2}, + }, + } + So(cfg.IsCompatEnabled(), ShouldBeTrue) + }) + + Convey("Test IsOpenIDSupported()", func() { + // Test with unsupported provider + So(config.IsOpenIDSupported("unsupported"), ShouldBeFalse) + + // Test with supported provider + So(config.IsOpenIDSupported("google"), ShouldBeTrue) + }) + + Convey("Test IsOauth2Supported()", func() { + // Test with unsupported provider + So(config.IsOauth2Supported("unsupported"), ShouldBeFalse) + + // Test with supported provider + So(config.IsOauth2Supported("github"), ShouldBeTrue) + }) + + Convey("Test IsClustered() with nil ClusterConfig", func() { + var clusterConfig *config.ClusterConfig = nil + + So(clusterConfig.IsClustered(), ShouldBeFalse) + }) + + Convey("Test IsClustered() with empty members", func() { + clusterConfig := &config.ClusterConfig{ + Members: []string{}, + } + So(clusterConfig.IsClustered(), ShouldBeFalse) + }) + + Convey("Test IsClustered() with single member", func() { + clusterConfig := &config.ClusterConfig{ + Members: []string{"node1:8080"}, + } + So(clusterConfig.IsClustered(), ShouldBeFalse) + }) + + Convey("Test IsClustered() with multiple members", func() { + clusterConfig := &config.ClusterConfig{ + Members: []string{"node1:8080", "node2:8080"}, + } + So(clusterConfig.IsClustered(), ShouldBeTrue) + }) + }) + + Convey("Test CopyExtensionsConfig methods", t, func() { + Convey("Test IsSearchEnabled()", func() { + // Test with nil Config + var cfg *config.Config = nil + + So(cfg.CopyExtensionsConfig().IsSearchEnabled(), ShouldBeFalse) + + // Test with Config but nil Extensions + cfg = &config.Config{} + So(cfg.CopyExtensionsConfig().IsSearchEnabled(), ShouldBeFalse) + + // Test with Config and Extensions but nil Search + cfg = &config.Config{ + Extensions: &extconf.ExtensionConfig{}, + } + So(cfg.CopyExtensionsConfig().IsSearchEnabled(), ShouldBeFalse) + + // Test with Config and Extensions and Search but disabled + disabled := false + cfg = &config.Config{ + Extensions: &extconf.ExtensionConfig{ + Search: &extconf.SearchConfig{ + BaseConfig: extconf.BaseConfig{ + Enable: &disabled, + }, + }, + }, + } + So(cfg.CopyExtensionsConfig().IsSearchEnabled(), ShouldBeFalse) + + // Test with Config and Extensions and Search enabled + enabled := true + cfg = &config.Config{ + Extensions: &extconf.ExtensionConfig{ + Search: &extconf.SearchConfig{ + BaseConfig: extconf.BaseConfig{ + Enable: &enabled, + }, + }, + }, + } + So(cfg.CopyExtensionsConfig().IsSearchEnabled(), ShouldBeTrue) + }) + }) + + Convey("Test UpdateReloadableConfig()", t, func() { + Convey("Test with nil Config", func() { + var cfg *config.Config = nil + newConfig := &config.Config{} + + So(func() { cfg.UpdateReloadableConfig(newConfig) }, ShouldNotPanic) + }) + + Convey("Test with nil newConfig.HTTP.Auth", func() { + // Create initial config with Auth + cfg := &config.Config{ + HTTP: config.HTTPConfig{ + Auth: &config.AuthConfig{ + FailDelay: 5, + HTPasswd: config.AuthHTPasswd{ + Path: "/etc/htpasswd", + }, + APIKey: false, + }, + }, + } + + // Create new config with nil Auth + newConfig := &config.Config{ + HTTP: config.HTTPConfig{ + Auth: nil, // This should not cause a panic + }, + } + + // This should not panic even though newConfig.HTTP.Auth is nil + So(func() { cfg.UpdateReloadableConfig(newConfig) }, ShouldNotPanic) + + // Verify that the original Auth config remains unchanged + So(cfg.HTTP.Auth, ShouldNotBeNil) + So(cfg.HTTP.Auth.FailDelay, ShouldEqual, 5) + So(cfg.HTTP.Auth.HTPasswd.Path, ShouldEqual, "/etc/htpasswd") + So(cfg.HTTP.Auth.APIKey, ShouldBeFalse) + }) + + Convey("Test with AccessControl update", func() { + cfgAccessControl := &config.AccessControlConfig{} + cfgAccessControl.AdminPolicy = config.Policy{ + Actions: []string{"read"}, + } + cfg := &config.Config{ + HTTP: config.HTTPConfig{ + AccessControl: cfgAccessControl, + }, + } + newConfigAccessControl := &config.AccessControlConfig{} + newConfigAccessControl.AdminPolicy = config.Policy{ + Actions: []string{"read", "write"}, + } + newConfig := &config.Config{ + HTTP: config.HTTPConfig{ + AccessControl: newConfigAccessControl, + }, + } + cfg.UpdateReloadableConfig(newConfig) + So(cfg.CopyAccessControlConfig().GetAdminPolicy().Actions, ShouldResemble, []string{"read", "write"}) + }) + + Convey("Test with Extensions update", func() { + // First set up a config with search enabled + enabled := true + cfg := &config.Config{ + Extensions: &extconf.ExtensionConfig{ + Search: &extconf.SearchConfig{ + BaseConfig: extconf.BaseConfig{ + Enable: &enabled, + }, + }, + }, + } + + // Create new config with CVE config + newConfig := &config.Config{ + Extensions: &extconf.ExtensionConfig{ + Search: &extconf.SearchConfig{ + BaseConfig: extconf.BaseConfig{ + Enable: &enabled, + }, + CVE: &extconf.CVEConfig{ + UpdateInterval: time.Hour * 2, + }, + }, + }, + } + cfg.UpdateReloadableConfig(newConfig) + // The search should still be enabled and CVE config should be updated + So(cfg.CopyExtensionsConfig().IsSearchEnabled(), ShouldBeTrue) + }) + + Convey("Test search CVE config removal when new config has nil Search.CVE", func() { + // First set up a config with search enabled and CVE config + enabled := true + cfg := &config.Config{ + Extensions: &extconf.ExtensionConfig{ + Search: &extconf.SearchConfig{ + BaseConfig: extconf.BaseConfig{ + Enable: &enabled, + }, + CVE: &extconf.CVEConfig{ + UpdateInterval: time.Hour, + }, + }, + }, + } + So(cfg.CopyExtensionsConfig().IsSearchEnabled(), ShouldBeTrue) + So(cfg.Extensions.Search.CVE, ShouldNotBeNil) + + // Create new config with Search but nil CVE + newConfig := &config.Config{ + Extensions: &extconf.ExtensionConfig{ + Search: &extconf.SearchConfig{ + BaseConfig: extconf.BaseConfig{ + Enable: &enabled, + }, + CVE: nil, // This should trigger the removal + }, + }, + } + cfg.UpdateReloadableConfig(newConfig) + + // Verify that the CVE config was removed + So(cfg.Extensions.Search.CVE, ShouldBeNil) + So(cfg.Extensions.Search.Enable, ShouldNotBeNil) + So(*cfg.Extensions.Search.Enable, ShouldBeTrue) + }) + + Convey("Test search CVE config removal when new config has nil Search", func() { + // First set up a config with search enabled and CVE config + enabled := true + cfg := &config.Config{ + Extensions: &extconf.ExtensionConfig{ + Search: &extconf.SearchConfig{ + BaseConfig: extconf.BaseConfig{ + Enable: &enabled, + }, + CVE: &extconf.CVEConfig{ + UpdateInterval: time.Hour, + }, + }, + }, + } + So(cfg.CopyExtensionsConfig().IsSearchEnabled(), ShouldBeTrue) + So(cfg.Extensions.Search.CVE, ShouldNotBeNil) + + // Create new config with Extensions but nil Search + newConfig := &config.Config{ + Extensions: &extconf.ExtensionConfig{ + Search: nil, // This should trigger the removal + }, + } + cfg.UpdateReloadableConfig(newConfig) + + // Verify that the CVE config was removed + So(cfg.Extensions.Search.CVE, ShouldBeNil) + So(cfg.Extensions.Search.Enable, ShouldNotBeNil) + So(*cfg.Extensions.Search.Enable, ShouldBeTrue) + }) + }) + + Convey("Test isOpenIDAuthProviderEnabled indirectly via IsMTLSAuthEnabled", t, func() { + Convey("Test with OpenID provider with empty config", func() { + cfg := &config.Config{ + HTTP: config.HTTPConfig{ + TLS: &config.TLSConfig{ + Key: "key", + Cert: "cert", + CACert: "cacert", + }, + Auth: &config.AuthConfig{ + OpenID: &config.OpenIDConfig{ + Providers: map[string]config.OpenIDProviderConfig{ + "google": { + ClientID: "", + Issuer: "", + Scopes: []string{}, + }, + }, + }, + }, + AccessControl: &config.AccessControlConfig{}, + }, + } + // This should return true because isOpenIDAuthProviderEnabled returns false for empty config, + // so isBasicAuthnEnabled returns false, making IsMTLSAuthEnabled return true + So(cfg.IsMTLSAuthEnabled(), ShouldBeTrue) + }) + + Convey("Test with OpenID provider with valid config", func() { + cfg := &config.Config{ + HTTP: config.HTTPConfig{ + TLS: &config.TLSConfig{ + Key: "key", + Cert: "cert", + CACert: "cacert", + }, + Auth: &config.AuthConfig{ + OpenID: &config.OpenIDConfig{ + Providers: map[string]config.OpenIDProviderConfig{ + "google": { + ClientID: "client-id", + Issuer: "", + Scopes: []string{}, + }, + }, + }, + }, + AccessControl: &config.AccessControlConfig{}, + }, + } + // This should return false because isOpenIDAuthProviderEnabled returns true for valid config, + // so isBasicAuthnEnabled returns true, making IsMTLSAuthEnabled return false + So(cfg.IsMTLSAuthEnabled(), ShouldBeFalse) + }) + + Convey("Test with unsupported OpenID provider", func() { + cfg := &config.Config{ + HTTP: config.HTTPConfig{ + TLS: &config.TLSConfig{ + Key: "key", + Cert: "cert", + CACert: "cacert", + }, + Auth: &config.AuthConfig{ + OpenID: &config.OpenIDConfig{ + Providers: map[string]config.OpenIDProviderConfig{ + "unsupported": { + ClientID: "client-id", + Scopes: []string{"scope"}, + }, + }, + }, + }, + AccessControl: &config.AccessControlConfig{}, + }, + } + // This should return true because isOpenIDAuthProviderEnabled returns false for unsupported provider, + // so isBasicAuthnEnabled returns false, making IsMTLSAuthEnabled return true + So(cfg.IsMTLSAuthEnabled(), ShouldBeTrue) + }) + }) + + Convey("Test nil receiver coverage for all methods", t, func() { + Convey("Test AuthConfig methods with nil receiver", func() { + var authConfig *config.AuthConfig = nil + + So(authConfig.IsLdapAuthEnabled(), ShouldBeFalse) + So(authConfig.IsHtpasswdAuthEnabled(), ShouldBeFalse) + So(authConfig.IsBearerAuthEnabled(), ShouldBeFalse) + So(authConfig.IsOpenIDAuthEnabled(), ShouldBeFalse) + So(authConfig.IsAPIKeyEnabled(), ShouldBeFalse) + So(authConfig.IsBasicAuthnEnabled(), ShouldBeFalse) + So(authConfig.GetFailDelay(), ShouldEqual, 0) + }) + + Convey("Test LDAPConfig methods with nil receiver", func() { + var ldapConfig *config.LDAPConfig = nil + + So(ldapConfig.BindDN(), ShouldEqual, "") + So(ldapConfig.BindPassword(), ShouldEqual, "") + So(ldapConfig.SetBindDN("test"), ShouldBeNil) + So(ldapConfig.SetBindPassword("test"), ShouldBeNil) + }) + + Convey("Test AccessControlConfig methods with nil receiver", func() { + var accessControlConfig *config.AccessControlConfig = nil + + So(accessControlConfig.IsAuthzEnabled(), ShouldBeFalse) + So(accessControlConfig.AnonymousPolicyExists(), ShouldBeFalse) + So(accessControlConfig.ContainsOnlyAnonymousPolicy(), ShouldBeTrue) + + // Test getter methods + So(accessControlConfig.GetRepositories(), ShouldBeNil) + So(accessControlConfig.GetAdminPolicy(), ShouldResemble, config.Policy{}) + So(accessControlConfig.GetMetrics(), ShouldResemble, config.Metrics{}) + So(accessControlConfig.GetGroups(), ShouldBeNil) + }) + + Convey("Test Config methods with nil receiver", func() { + var cfg *config.Config = nil + + // Test getter methods + So(cfg.CopyAuthConfig(), ShouldBeNil) + So(cfg.CopyAccessControlConfig(), ShouldBeNil) + So(cfg.GetHTTPAddress(), ShouldEqual, "") + So(cfg.GetHTTPPort(), ShouldEqual, "") + So(cfg.GetAllowOrigin(), ShouldEqual, "") + So(cfg.CopyTLSConfig(), ShouldBeNil) + So(cfg.CopyRatelimit(), ShouldBeNil) + So(cfg.GetCompat(), ShouldBeNil) + So(cfg.CopyStorageConfig(), ShouldResemble, config.GlobalStorageConfig{}) + So(cfg.CopyExtensionsConfig(), ShouldBeNil) + So(cfg.CopyLogConfig(), ShouldBeNil) + So(cfg.CopyClusterConfig(), ShouldBeNil) + So(cfg.CopySchedulerConfig(), ShouldBeNil) + + // Test GetVersionInfo + commit, binaryType, goVersion, distSpecVersion := cfg.GetVersionInfo() + So(commit, ShouldEqual, "") + So(binaryType, ShouldEqual, "") + So(goVersion, ShouldEqual, "") + So(distSpecVersion, ShouldEqual, "") + + // Test boolean methods + So(cfg.IsMTLSAuthEnabled(), ShouldBeFalse) + So(cfg.IsRetentionEnabled(), ShouldBeFalse) + So(cfg.IsCompatEnabled(), ShouldBeFalse) + + // Test Sanitize + So(cfg.Sanitize(), ShouldBeNil) + + // Test UpdateReloadableConfig (should not panic) + newConfig := &config.Config{} + + So(func() { cfg.UpdateReloadableConfig(newConfig) }, ShouldNotPanic) + }) + }) + + Convey("Test AccessControlConfig copy isolation through CopyAccessControlConfig()", t, func() { + Convey("Test that mutations to retrieved AccessControlConfig copy do not affect original config", func() { + // Create a config with initial AccessControlConfig + cfg := &config.Config{ + HTTP: config.HTTPConfig{ + AccessControl: &config.AccessControlConfig{ + AdminPolicy: config.Policy{ + Actions: []string{"read"}, + Users: []string{"admin"}, + }, + Repositories: config.Repositories{ + "repo1": config.PolicyGroup{ + DefaultPolicy: []string{"read"}, + Policies: []config.Policy{ + { + Actions: []string{"read"}, + }, + }, + }, + }, + }, + }, + } + + // Retrieve the AccessControlConfig (should be a copy) + accessControlConfig := cfg.CopyAccessControlConfig() + So(accessControlConfig, ShouldNotBeNil) + + // Mutate the retrieved AccessControlConfig copy + accessControlConfig.AdminPolicy = config.Policy{ + Actions: []string{"read", "write", "delete"}, + Users: []string{"admin", "superadmin"}, + } + + // Add a new repository to the copy + newRepositories := config.Repositories{ + "repo1": config.PolicyGroup{ + DefaultPolicy: []string{"read"}, + Policies: []config.Policy{ + { + Actions: []string{"read"}, + }, + }, + }, + "repo2": config.PolicyGroup{ + DefaultPolicy: []string{"read", "write"}, + Policies: []config.Policy{ + { + Actions: []string{"read", "write"}, + Users: []string{"user1"}, + }, + }, + }, + } + accessControlConfig.Repositories = newRepositories + + // Verify that the original config is unchanged + originalAccessControlConfig := cfg.CopyAccessControlConfig() + So(originalAccessControlConfig, ShouldNotBeNil) + + // Check that admin policy remains unchanged in original + adminPolicy := originalAccessControlConfig.GetAdminPolicy() + So(adminPolicy.Actions, ShouldResemble, []string{"read"}) + So(adminPolicy.Users, ShouldResemble, []string{"admin"}) + + // Check that repositories remain unchanged in original + repositories := originalAccessControlConfig.GetRepositories() + So(len(repositories), ShouldEqual, 1) + So(repositories["repo1"], ShouldNotBeNil) + So(repositories["repo1"].DefaultPolicy, ShouldResemble, []string{"read"}) + }) + + Convey("Test that mutations to retrieved AccessControlConfig copy work with nil initial config", func() { + // Create a config with nil AccessControlConfig + cfg := &config.Config{ + HTTP: config.HTTPConfig{ + AccessControl: nil, + }, + } + + // Retrieve the AccessControlConfig (should return nil) + accessControlConfig := cfg.CopyAccessControlConfig() + So(accessControlConfig, ShouldBeNil) + + // Create a new AccessControlConfig and set it + newAccessControlConfig := &config.AccessControlConfig{} + newAccessControlConfig.AdminPolicy = config.Policy{ + Actions: []string{"read"}, + Users: []string{"admin"}, + } + + // Manually set the AccessControlConfig on the original config + cfg.HTTP.AccessControl = newAccessControlConfig + + // Now retrieve it again and verify it works + retrievedConfig := cfg.CopyAccessControlConfig() + So(retrievedConfig, ShouldNotBeNil) + + // Mutate the retrieved config copy + retrievedConfig.AdminPolicy = config.Policy{ + Actions: []string{"read", "write"}, + Users: []string{"admin", "user"}, + } + + // Verify the original config is unchanged + finalConfig := cfg.CopyAccessControlConfig() + adminPolicy := finalConfig.GetAdminPolicy() + So(adminPolicy.Actions, ShouldResemble, []string{"read"}) + So(adminPolicy.Users, ShouldResemble, []string{"admin"}) + }) + }) + + Convey("Test AccessControlConfig copy isolation through UpdateReloadableConfig()", t, func() { + Convey("Test that AccessControlConfig copies are isolated from UpdateReloadableConfig changes", func() { + // Create initial config with AccessControlConfig + cfg := &config.Config{ + HTTP: config.HTTPConfig{ + AccessControl: &config.AccessControlConfig{ + AdminPolicy: config.Policy{ + Actions: []string{"read"}, + Users: []string{"admin"}, + }, + Repositories: config.Repositories{ + "repo1": config.PolicyGroup{ + DefaultPolicy: []string{"read"}, + Policies: []config.Policy{ + { + Actions: []string{"read"}, + }, + }, + }, + }, + }, + }, + } + + // Get initial reference to AccessControlConfig + initialAccessControlConfig := cfg.CopyAccessControlConfig() + So(initialAccessControlConfig, ShouldNotBeNil) + + // Verify initial state + initialAdminPolicy := initialAccessControlConfig.GetAdminPolicy() + So(initialAdminPolicy.Actions, ShouldResemble, []string{"read"}) + So(initialAdminPolicy.Users, ShouldResemble, []string{"admin"}) + + initialRepositories := initialAccessControlConfig.GetRepositories() + So(len(initialRepositories), ShouldEqual, 1) + So(initialRepositories["repo1"], ShouldNotBeNil) + + // Create new config with updated AccessControlConfig + newConfig := &config.Config{ + HTTP: config.HTTPConfig{ + AccessControl: &config.AccessControlConfig{ + AdminPolicy: config.Policy{ + Actions: []string{"read", "write", "delete"}, + Users: []string{"admin", "superadmin", "user"}, + }, + Repositories: config.Repositories{ + "repo1": config.PolicyGroup{ + DefaultPolicy: []string{"read", "write"}, + Policies: []config.Policy{ + { + Actions: []string{"read", "write"}, + }, + }, + }, + "repo2": config.PolicyGroup{ + DefaultPolicy: []string{"read"}, + Policies: []config.Policy{ + { + Actions: []string{"read"}, + Users: []string{"user1", "user2"}, + }, + }, + }, + }, + }, + }, + } + + // Update the config using UpdateReloadableConfig + cfg.UpdateReloadableConfig(newConfig) + + // Verify that the old copy remains unchanged (copy isolation) + updatedAdminPolicy := initialAccessControlConfig.GetAdminPolicy() + So(updatedAdminPolicy.Actions, ShouldResemble, []string{"read"}) + So(updatedAdminPolicy.Users, ShouldResemble, []string{"admin"}) + + updatedRepositories := initialAccessControlConfig.GetRepositories() + So(len(updatedRepositories), ShouldEqual, 1) + So(updatedRepositories["repo1"], ShouldNotBeNil) + So(updatedRepositories["repo1"].DefaultPolicy, ShouldResemble, []string{"read"}) + + // Verify that a new copy gets the updated data + newAccessControlConfig := cfg.CopyAccessControlConfig() + So(newAccessControlConfig, ShouldNotBeNil) + So(newAccessControlConfig, ShouldNotEqual, initialAccessControlConfig) // Different copy + + newAdminPolicy := newAccessControlConfig.GetAdminPolicy() + So(newAdminPolicy.Actions, ShouldResemble, []string{"read", "write", "delete"}) + So(newAdminPolicy.Users, ShouldResemble, []string{"admin", "superadmin", "user"}) + }) + + Convey("Test that old AccessControlConfig reference works with nil initial config", func() { + // Create config with nil AccessControlConfig + cfg := &config.Config{ + HTTP: config.HTTPConfig{ + AccessControl: nil, + }, + } + + // Get initial reference (should be nil) + initialAccessControlConfig := cfg.CopyAccessControlConfig() + So(initialAccessControlConfig, ShouldBeNil) + + // Create new config with AccessControlConfig + newConfig := &config.Config{ + HTTP: config.HTTPConfig{ + AccessControl: &config.AccessControlConfig{ + AdminPolicy: config.Policy{ + Actions: []string{"read", "write"}, + Users: []string{"admin"}, + }, + }, + }, + } + + // Update the config using UpdateReloadableConfig + cfg.UpdateReloadableConfig(newConfig) + + // Verify that a new reference now gets the data + newAccessControlConfig := cfg.CopyAccessControlConfig() + So(newAccessControlConfig, ShouldNotBeNil) + + adminPolicy := newAccessControlConfig.GetAdminPolicy() + So(adminPolicy.Actions, ShouldResemble, []string{"read", "write"}) + So(adminPolicy.Users, ShouldResemble, []string{"admin"}) + }) + + Convey("Test that old AccessControlConfig reference works when new config has nil AccessControlConfig", func() { + // Create initial config with AccessControlConfig + testAccessControlConfig := &config.AccessControlConfig{} + testAccessControlConfig.AdminPolicy = config.Policy{ + Actions: []string{"read"}, + Users: []string{"admin"}, + } + cfg := &config.Config{ + HTTP: config.HTTPConfig{ + AccessControl: testAccessControlConfig, + }, + } + + // Get initial reference + initialAccessControlConfig := cfg.CopyAccessControlConfig() + So(initialAccessControlConfig, ShouldNotBeNil) + + // Create new config with nil AccessControlConfig + newConfig := &config.Config{ + HTTP: config.HTTPConfig{ + AccessControl: nil, + }, + } + + // Update the config using UpdateReloadableConfig + cfg.UpdateReloadableConfig(newConfig) + + // Verify that a new reference now returns nil + newAccessControlConfig := cfg.CopyAccessControlConfig() + So(newAccessControlConfig, ShouldBeNil) + }) + }) + + Convey("Test ExtensionConfig copy isolation through CopyExtensionsConfig()", t, func() { + Convey("Test that mutations to retrieved ExtensionConfig copy do not affect original config", func() { + // Create a config with initial ExtensionConfig + enabled := true + cfg := &config.Config{ + Extensions: &extconf.ExtensionConfig{ + Search: &extconf.SearchConfig{ + BaseConfig: extconf.BaseConfig{ + Enable: &enabled, + }, + CVE: &extconf.CVEConfig{ + UpdateInterval: time.Hour, + Trivy: &extconf.TrivyConfig{ + DBRepository: "original/trivy-db", + }, + }, + }, + Sync: &syncconf.Config{ + Enable: &enabled, + Registries: []syncconf.RegistryConfig{ + { + URLs: []string{"http://original:5000"}, + }, + }, + }, + Metrics: &extconf.MetricsConfig{ + BaseConfig: extconf.BaseConfig{ + Enable: &enabled, + }, + Prometheus: &extconf.PrometheusConfig{ + Path: "/metrics", + }, + }, + Scrub: &extconf.ScrubConfig{ + BaseConfig: extconf.BaseConfig{ + Enable: &enabled, + }, + Interval: 24 * time.Hour, + }, + UI: &extconf.UIConfig{ + BaseConfig: extconf.BaseConfig{ + Enable: &enabled, + }, + }, + }, + } + + // Retrieve the ExtensionConfig + extensionConfig := cfg.CopyExtensionsConfig() + So(extensionConfig, ShouldNotBeNil) + + // Mutate the retrieved ExtensionConfig copy + disabled := false + extensionConfig.Search.Enable = &disabled + extensionConfig.Search.CVE.UpdateInterval = 2 * time.Hour + extensionConfig.Search.CVE.Trivy.DBRepository = "modified/trivy-db" + extensionConfig.Sync.Registries[0].URLs[0] = "http://modified:5000" + extensionConfig.Metrics.Prometheus.Path = "/custom/metrics" + extensionConfig.Scrub.Interval = 48 * time.Hour + extensionConfig.UI.Enable = &disabled + + // Verify that the original config is unchanged + So(*cfg.Extensions.Search.Enable, ShouldBeTrue) + So(cfg.Extensions.Search.CVE.UpdateInterval, ShouldEqual, time.Hour) + So(cfg.Extensions.Search.CVE.Trivy.DBRepository, ShouldEqual, "original/trivy-db") + So(cfg.Extensions.Sync.Registries[0].URLs[0], ShouldEqual, "http://original:5000") + So(cfg.Extensions.Metrics.Prometheus.Path, ShouldEqual, "/metrics") + So(cfg.Extensions.Scrub.Interval, ShouldEqual, 24*time.Hour) + So(*cfg.Extensions.UI.Enable, ShouldBeTrue) + + // Verify that the retrieved config has the mutations + So(*extensionConfig.Search.Enable, ShouldBeFalse) + So(extensionConfig.Search.CVE.UpdateInterval, ShouldEqual, 2*time.Hour) + So(extensionConfig.Search.CVE.Trivy.DBRepository, ShouldEqual, "modified/trivy-db") + So(extensionConfig.Sync.Registries[0].URLs[0], ShouldEqual, "http://modified:5000") + So(extensionConfig.Metrics.Prometheus.Path, ShouldEqual, "/custom/metrics") + So(extensionConfig.Scrub.Interval, ShouldEqual, 48*time.Hour) + So(*extensionConfig.UI.Enable, ShouldBeFalse) + }) + + Convey("Test that mutations to retrieved ExtensionConfig work with nil initial config", func() { + // Create a config with nil ExtensionConfig + cfg := &config.Config{ + Extensions: nil, + } + + // Retrieve the ExtensionConfig (should return nil) + extensionConfig := cfg.CopyExtensionsConfig() + So(extensionConfig, ShouldBeNil) + + // Create a new ExtensionConfig and set it + enabled := true + newExtensionConfig := &extconf.ExtensionConfig{ + Search: &extconf.SearchConfig{ + BaseConfig: extconf.BaseConfig{ + Enable: &enabled, + }, + }, + Metrics: &extconf.MetricsConfig{ + BaseConfig: extconf.BaseConfig{ + Enable: &enabled, + }, + Prometheus: &extconf.PrometheusConfig{ + Path: "/metrics", + }, + }, + } + + // Manually set the ExtensionConfig on the original config + cfg.Extensions = newExtensionConfig + + // Now retrieve it again and verify it works + retrievedConfig := cfg.CopyExtensionsConfig() + So(retrievedConfig, ShouldNotBeNil) + + // Mutate the retrieved config + retrievedConfig.Metrics.Prometheus.Path = "/new/metrics" + + // Verify the changes are NOT reflected in original config + finalConfig := cfg.CopyExtensionsConfig() + So(finalConfig.Metrics.Prometheus.Path, ShouldEqual, "/metrics") + }) + }) + + Convey("Test ExtensionConfig copy isolation through UpdateReloadableConfig()", t, func() { + Convey("Test that ExtensionConfig copies are isolated from UpdateReloadableConfig changes", func() { + // Create initial config with ExtensionConfig + enabled := true + cfg := &config.Config{ + Extensions: &extconf.ExtensionConfig{ + Search: &extconf.SearchConfig{ + BaseConfig: extconf.BaseConfig{ + Enable: &enabled, + }, + }, + Metrics: &extconf.MetricsConfig{ + BaseConfig: extconf.BaseConfig{ + Enable: &enabled, + }, + Prometheus: &extconf.PrometheusConfig{ + Path: "/metrics", + }, + }, + }, + } + + // Get initial reference to ExtensionConfig + initialExtensionConfig := cfg.CopyExtensionsConfig() + So(initialExtensionConfig, ShouldNotBeNil) + + // Verify initial state + So(initialExtensionConfig.Metrics.Prometheus.Path, ShouldEqual, "/metrics") + So(initialExtensionConfig.Sync, ShouldBeNil) + So(initialExtensionConfig.Search.CVE, ShouldBeNil) + So(initialExtensionConfig.Scrub, ShouldBeNil) + + // Create new config with updated ExtensionConfig + newConfig := &config.Config{ + Extensions: &extconf.ExtensionConfig{ + Search: &extconf.SearchConfig{ + BaseConfig: extconf.BaseConfig{ + Enable: &enabled, + }, + CVE: &extconf.CVEConfig{ + UpdateInterval: time.Hour * 2, + Trivy: &extconf.TrivyConfig{ + DBRepository: "updated/trivy-db", + }, + }, + }, + Metrics: &extconf.MetricsConfig{ + BaseConfig: extconf.BaseConfig{ + Enable: &enabled, + }, + Prometheus: &extconf.PrometheusConfig{ + Path: "/custom/metrics", + }, + }, + Sync: &syncconf.Config{ + Enable: &enabled, + Registries: []syncconf.RegistryConfig{ + { + URLs: []string{"http://registry1:5000", "http://registry2:5000"}, + }, + }, + }, + Scrub: &extconf.ScrubConfig{ + BaseConfig: extconf.BaseConfig{ + Enable: &enabled, + }, + Interval: time.Hour * 12, + }, + }, + } + + // Update the config using UpdateReloadableConfig + cfg.UpdateReloadableConfig(newConfig) + + // Verify that the old reference remains unchanged (copy isolation) + So(initialExtensionConfig.Metrics.Prometheus.Path, ShouldEqual, "/metrics") + So(initialExtensionConfig.Sync, ShouldBeNil) + So(initialExtensionConfig.Search.CVE, ShouldBeNil) + So(initialExtensionConfig.Scrub, ShouldBeNil) + + // Verify that a new reference gets the updated data + newExtensionConfig := cfg.CopyExtensionsConfig() + So(newExtensionConfig, ShouldNotBeNil) + So(newExtensionConfig, ShouldNotEqual, initialExtensionConfig) // Different references + + So(newExtensionConfig.Metrics.Prometheus.Path, ShouldEqual, "/metrics") + So(newExtensionConfig.Sync, ShouldNotBeNil) + So(newExtensionConfig.Search.CVE, ShouldNotBeNil) + So(newExtensionConfig.Scrub, ShouldNotBeNil) + }) + + Convey("Test that old ExtensionConfig reference works with nil initial config", func() { + // Create config with nil ExtensionConfig + cfg := &config.Config{ + Extensions: nil, + } + + // Get initial reference (should be nil) + initialExtensionConfig := cfg.CopyExtensionsConfig() + So(initialExtensionConfig, ShouldBeNil) + + // Create new config with ExtensionConfig + enabled := true + newConfig := &config.Config{ + Extensions: &extconf.ExtensionConfig{ + Search: &extconf.SearchConfig{ + BaseConfig: extconf.BaseConfig{ + Enable: &enabled, + }, + }, + Metrics: &extconf.MetricsConfig{ + BaseConfig: extconf.BaseConfig{ + Enable: &enabled, + }, + Prometheus: &extconf.PrometheusConfig{ + Path: "/new/metrics", + }, + }, + }, + } + + // Update the config using UpdateReloadableConfig + cfg.UpdateReloadableConfig(newConfig) + + // Verify that a new reference now gets the data + newExtensionConfig := cfg.CopyExtensionsConfig() + So(newExtensionConfig, ShouldNotBeNil) + + // Note: UpdateReloadableConfig creates an empty ExtensionConfig when going from nil to non-nil, + // but doesn't copy the fields from newConfig.Extensions. It only updates specific parts. + // So the Search and Metrics fields will be nil in the new ExtensionConfig. + So(newExtensionConfig.Search, ShouldBeNil) + So(newExtensionConfig.Metrics, ShouldBeNil) + }) + + Convey("Test that old ExtensionConfig reference works when new config has nil ExtensionConfig", func() { + // Create initial config with ExtensionConfig + enabled := true + cfg := &config.Config{ + Extensions: &extconf.ExtensionConfig{ + Search: &extconf.SearchConfig{ + BaseConfig: extconf.BaseConfig{ + Enable: &enabled, + }, + }, + }, + } + + // Get initial reference + initialExtensionConfig := cfg.CopyExtensionsConfig() + So(initialExtensionConfig, ShouldNotBeNil) + + // Create new config with nil ExtensionConfig + newConfig := &config.Config{ + Extensions: nil, + } + + // Update the config using UpdateReloadableConfig + cfg.UpdateReloadableConfig(newConfig) + + // Verify that the old reference remains unchanged (copy isolation) + So(initialExtensionConfig, ShouldNotBeNil) + So(initialExtensionConfig.Search, ShouldNotBeNil) + + // Verify that a new reference now returns nil + newExtensionConfig := cfg.CopyExtensionsConfig() + So(newExtensionConfig, ShouldBeNil) + }) + }) + + Convey("Test UpdateReloadableConfig LDAP config updates", t, func() { + Convey("Test LDAP config is updated in UpdateReloadableConfig", func() { + // Create initial config with LDAP + cfg := &config.Config{ + HTTP: config.HTTPConfig{ + Auth: &config.AuthConfig{ + LDAP: &config.LDAPConfig{ + Address: "ldap://old-server:389", + Port: 389, + Insecure: true, + }, + }, + }, + } + + // Create new config with updated LDAP + newConfig := &config.Config{ + HTTP: config.HTTPConfig{ + Auth: &config.AuthConfig{ + LDAP: &config.LDAPConfig{ + Address: "ldap://new-server:636", + Port: 636, + Insecure: false, + StartTLS: true, + }, + }, + }, + } + + // Update the config + cfg.UpdateReloadableConfig(newConfig) + + // Verify LDAP config was updated + So(cfg.HTTP.Auth.LDAP, ShouldNotBeNil) + So(cfg.HTTP.Auth.LDAP.Address, ShouldEqual, "ldap://new-server:636") + So(cfg.HTTP.Auth.LDAP.Port, ShouldEqual, 636) + So(cfg.HTTP.Auth.LDAP.Insecure, ShouldBeFalse) + So(cfg.HTTP.Auth.LDAP.StartTLS, ShouldBeTrue) + }) + + Convey("Test LDAP config is set to nil when new config has nil LDAP", func() { + // Create initial config with LDAP + cfg := &config.Config{ + HTTP: config.HTTPConfig{ + Auth: &config.AuthConfig{ + LDAP: &config.LDAPConfig{ + Address: "ldap://old-server:389", + }, + }, + }, + } + + // Create new config with nil LDAP + newConfig := &config.Config{ + HTTP: config.HTTPConfig{ + Auth: &config.AuthConfig{ + LDAP: nil, + }, + }, + } + + // Update the config + cfg.UpdateReloadableConfig(newConfig) + + // Verify LDAP config was set to nil + So(cfg.HTTP.Auth.LDAP, ShouldBeNil) + }) + + Convey("Test LDAP config is created when going from nil to non-nil", func() { + // Create initial config with nil LDAP + cfg := &config.Config{ + HTTP: config.HTTPConfig{ + Auth: &config.AuthConfig{ + LDAP: nil, + }, + }, + } + + // Create new config with LDAP + newConfig := &config.Config{ + HTTP: config.HTTPConfig{ + Auth: &config.AuthConfig{ + LDAP: &config.LDAPConfig{ + Address: "ldap://new-server:389", + Port: 389, + }, + }, + }, + } + + // Update the config + cfg.UpdateReloadableConfig(newConfig) + + // Verify LDAP config was created + So(cfg.HTTP.Auth.LDAP, ShouldNotBeNil) + So(cfg.HTTP.Auth.LDAP.Address, ShouldEqual, "ldap://new-server:389") + So(cfg.HTTP.Auth.LDAP.Port, ShouldEqual, 389) + }) + }) + + Convey("Test UpdateReloadableConfig Storage.SubPaths logic", t, func() { + Convey("Test existing SubPaths are updated", func() { + // Create initial config with SubPaths + cfg := &config.Config{ + Storage: config.GlobalStorageConfig{ + StorageConfig: config.StorageConfig{ + GC: true, + Dedupe: false, + }, + SubPaths: map[string]config.StorageConfig{ + "/path1": { + GC: true, + Dedupe: false, + GCDelay: time.Hour, + GCInterval: time.Hour * 24, + }, + "/path2": { + GC: false, + Dedupe: true, + GCDelay: time.Hour * 2, + GCInterval: time.Hour * 48, + }, + }, + }, + } + + // Create new config with updated SubPaths + newConfig := &config.Config{ + Storage: config.GlobalStorageConfig{ + StorageConfig: config.StorageConfig{ + GC: true, + Dedupe: false, + }, + SubPaths: map[string]config.StorageConfig{ + "/path1": { + GC: false, // Changed + Dedupe: true, // Changed + GCDelay: time.Hour * 2, // Changed + GCInterval: time.Hour * 12, // Changed + }, + "/path2": { + GC: true, // Changed + Dedupe: false, // Changed + GCDelay: time.Hour * 3, // Changed + GCInterval: time.Hour * 36, // Changed + }, + }, + }, + } + + // Update the config + cfg.UpdateReloadableConfig(newConfig) + + // Verify SubPaths were updated + So(len(cfg.Storage.SubPaths), ShouldEqual, 2) + + // Check /path1 + path1Config := cfg.Storage.SubPaths["/path1"] + So(path1Config.GC, ShouldBeFalse) + So(path1Config.Dedupe, ShouldBeTrue) + So(path1Config.GCDelay, ShouldEqual, time.Hour*2) + So(path1Config.GCInterval, ShouldEqual, time.Hour*12) + + // Check /path2 + path2Config := cfg.Storage.SubPaths["/path2"] + So(path2Config.GC, ShouldBeTrue) + So(path2Config.Dedupe, ShouldBeFalse) + So(path2Config.GCDelay, ShouldEqual, time.Hour*3) + So(path2Config.GCInterval, ShouldEqual, time.Hour*36) + }) + + Convey("Test new SubPaths are not added (only existing ones are updated)", func() { + // Create initial config with one SubPath + cfg := &config.Config{ + Storage: config.GlobalStorageConfig{ + StorageConfig: config.StorageConfig{ + GC: true, + Dedupe: false, + }, + SubPaths: map[string]config.StorageConfig{ + "/path1": { + GC: true, + Dedupe: false, + }, + }, + }, + } + + // Create new config with additional SubPath + newConfig := &config.Config{ + Storage: config.GlobalStorageConfig{ + StorageConfig: config.StorageConfig{ + GC: true, + Dedupe: false, + }, + SubPaths: map[string]config.StorageConfig{ + "/path1": { + GC: false, // Update existing + Dedupe: true, // Update existing + }, + "/path2": { // New path - should not be added + GC: true, + Dedupe: true, + }, + }, + }, + } + + // Update the config + cfg.UpdateReloadableConfig(newConfig) + + // Verify only existing SubPath was updated, new one was not added + So(len(cfg.Storage.SubPaths), ShouldEqual, 1) + _, exists := cfg.Storage.SubPaths["/path2"] + So(exists, ShouldBeFalse) // New path not added + + // Verify existing path was updated + path1Config := cfg.Storage.SubPaths["/path1"] + So(path1Config.GC, ShouldBeFalse) + So(path1Config.Dedupe, ShouldBeTrue) + }) + + Convey("Test SubPaths Retention is updated only when retention is enabled", func() { + // Create initial config with retention enabled and SubPaths + // Retention is enabled when there are policies with tag retention + cfg := &config.Config{ + Storage: config.GlobalStorageConfig{ + StorageConfig: config.StorageConfig{ + GC: true, + Dedupe: false, + Retention: config.ImageRetention{ + Policies: []config.RetentionPolicy{ + { + Repositories: []string{"repo1"}, + KeepTags: []config.KeepTagsPolicy{ + { + MostRecentlyPulledCount: 10, // This enables retention + }, + }, + }, + }, + }, + }, + SubPaths: map[string]config.StorageConfig{ + "/path1": { + GC: true, + Dedupe: false, + Retention: config.ImageRetention{ + Policies: []config.RetentionPolicy{ + { + Repositories: []string{"old-repo"}, + }, + }, + }, + }, + }, + }, + } + + // Create new config with updated SubPath retention + newConfig := &config.Config{ + Storage: config.GlobalStorageConfig{ + StorageConfig: config.StorageConfig{ + GC: true, + Dedupe: false, + Retention: config.ImageRetention{ + Policies: []config.RetentionPolicy{ + { + Repositories: []string{"repo1"}, + KeepTags: []config.KeepTagsPolicy{ + { + MostRecentlyPulledCount: 10, // This enables retention + }, + }, + }, + }, + }, + }, + SubPaths: map[string]config.StorageConfig{ + "/path1": { + GC: false, + Dedupe: true, + Retention: config.ImageRetention{ + Policies: []config.RetentionPolicy{ + { + Repositories: []string{"new-repo"}, + }, + }, + }, + }, + }, + }, + } + + // Update the config + cfg.UpdateReloadableConfig(newConfig) + + // Verify SubPath was updated including Retention + path1Config := cfg.Storage.SubPaths["/path1"] + So(path1Config.GC, ShouldBeFalse) + So(path1Config.Dedupe, ShouldBeTrue) + So(len(path1Config.Retention.Policies), ShouldEqual, 1) + So(path1Config.Retention.Policies[0].Repositories[0], ShouldEqual, "new-repo") + }) + + Convey("Test SubPaths Retention is not updated when retention is disabled", func() { + // Create initial config with retention disabled and SubPaths + cfg := &config.Config{ + Storage: config.GlobalStorageConfig{ + StorageConfig: config.StorageConfig{ + GC: true, + Dedupe: false, + // No Retention config - retention disabled + }, + SubPaths: map[string]config.StorageConfig{ + "/path1": { + GC: true, + Dedupe: false, + Retention: config.ImageRetention{ + Policies: []config.RetentionPolicy{ + { + Repositories: []string{"old-repo"}, + }, + }, + }, + }, + }, + }, + } + + // Create new config with updated SubPath retention + newConfig := &config.Config{ + Storage: config.GlobalStorageConfig{ + StorageConfig: config.StorageConfig{ + GC: true, + Dedupe: false, + // No Retention config - retention disabled + }, + SubPaths: map[string]config.StorageConfig{ + "/path1": { + GC: false, + Dedupe: true, + Retention: config.ImageRetention{ + Policies: []config.RetentionPolicy{ + { + Repositories: []string{"new-repo"}, + }, + }, + }, + }, + }, + }, + } + + // Update the config + cfg.UpdateReloadableConfig(newConfig) + + // Verify SubPath was updated but Retention was not + path1Config := cfg.Storage.SubPaths["/path1"] + So(path1Config.GC, ShouldBeFalse) + So(path1Config.Dedupe, ShouldBeTrue) + // Retention should remain unchanged (old value) + So(len(path1Config.Retention.Policies), ShouldEqual, 1) + So(path1Config.Retention.Policies[0].Repositories[0], ShouldEqual, "old-repo") + }) + + Convey("Test SubPaths with empty new config", func() { + // Create initial config with SubPaths + cfg := &config.Config{ + Storage: config.GlobalStorageConfig{ + StorageConfig: config.StorageConfig{ + GC: true, + Dedupe: false, + }, + SubPaths: map[string]config.StorageConfig{ + "/path1": { + GC: true, + Dedupe: false, + }, + "/path2": { + GC: false, + Dedupe: true, + }, + }, + }, + } + + // Create new config with empty SubPaths + newConfig := &config.Config{ + Storage: config.GlobalStorageConfig{ + StorageConfig: config.StorageConfig{ + GC: true, + Dedupe: false, + }, + SubPaths: map[string]config.StorageConfig{}, + }, + } + + // Update the config + cfg.UpdateReloadableConfig(newConfig) + + // Verify existing SubPaths remain unchanged (no updates applied) + So(len(cfg.Storage.SubPaths), ShouldEqual, 2) + path1Config := cfg.Storage.SubPaths["/path1"] + So(path1Config.GC, ShouldBeTrue) // Unchanged + So(path1Config.Dedupe, ShouldBeFalse) // Unchanged + path2Config := cfg.Storage.SubPaths["/path2"] + So(path2Config.GC, ShouldBeFalse) // Unchanged + So(path2Config.Dedupe, ShouldBeTrue) // Unchanged + }) }) } diff --git a/pkg/api/controller.go b/pkg/api/controller.go index 93df422e..30666a7f 100644 --- a/pkg/api/controller.go +++ b/pkg/api/controller.go @@ -22,15 +22,14 @@ import ( "zotregistry.dev/zot/v2/pkg/api/config" "zotregistry.dev/zot/v2/pkg/common" ext "zotregistry.dev/zot/v2/pkg/extensions" - extconf "zotregistry.dev/zot/v2/pkg/extensions/config" - "zotregistry.dev/zot/v2/pkg/extensions/events" - "zotregistry.dev/zot/v2/pkg/extensions/monitoring" - "zotregistry.dev/zot/v2/pkg/log" - "zotregistry.dev/zot/v2/pkg/meta" + events "zotregistry.dev/zot/v2/pkg/extensions/events" + monitoring "zotregistry.dev/zot/v2/pkg/extensions/monitoring" + log "zotregistry.dev/zot/v2/pkg/log" + meta "zotregistry.dev/zot/v2/pkg/meta" mTypes "zotregistry.dev/zot/v2/pkg/meta/types" - "zotregistry.dev/zot/v2/pkg/scheduler" - "zotregistry.dev/zot/v2/pkg/storage" - "zotregistry.dev/zot/v2/pkg/storage/gc" + scheduler "zotregistry.dev/zot/v2/pkg/scheduler" + storage "zotregistry.dev/zot/v2/pkg/storage" + gc "zotregistry.dev/zot/v2/pkg/storage/gc" ) const ( @@ -139,12 +138,13 @@ func (c *Controller) Run() error { engine := mux.NewRouter() // rate-limit HTTP requests if enabled - if c.Config.HTTP.Ratelimit != nil { - if c.Config.HTTP.Ratelimit.Rate != nil { - engine.Use(RateLimiter(c, *c.Config.HTTP.Ratelimit.Rate)) + ratelimitConfig := c.Config.CopyRatelimit() + if ratelimitConfig != nil { + if ratelimitConfig.Rate != nil { + engine.Use(RateLimiter(c, *ratelimitConfig.Rate)) } - for _, mrlim := range c.Config.HTTP.Ratelimit.Methods { + for _, mrlim := range ratelimitConfig.Methods { engine.Use(MethodRateLimiter(c, mrlim.Method, mrlim.Rate)) } } @@ -161,13 +161,14 @@ func (c *Controller) Run() error { c.Router = engine c.Router.UseEncodedPath() - monitoring.SetServerInfo(c.Metrics, c.Config.Commit, c.Config.BinaryType, c.Config.GoVersion, - c.Config.DistSpecVersion) + commit, binaryType, goVersion, distSpecVersion := c.Config.GetVersionInfo() + monitoring.SetServerInfo(c.Metrics, commit, binaryType, goVersion, distSpecVersion) //nolint: contextcheck _ = NewRouteHandler(c) - addr := fmt.Sprintf("%s:%s", c.Config.HTTP.Address, c.Config.HTTP.Port) + port := c.Config.GetHTTPPort() + addr := fmt.Sprintf("%s:%s", c.Config.GetHTTPAddress(), port) server := &http.Server{ Addr: addr, Handler: c.Router, @@ -182,10 +183,10 @@ func (c *Controller) Run() error { return err } - if c.Config.HTTP.Port == "0" || c.Config.HTTP.Port == "" { + if port == "0" || port == "" { chosenAddr, ok := listener.Addr().(*net.TCPAddr) if !ok { - c.Log.Error().Str("port", c.Config.HTTP.Port).Msg("invalid addr type") + c.Log.Error().Str("port", port).Msg("invalid addr type") return errors.ErrBadType } @@ -196,12 +197,13 @@ func (c *Controller) Run() error { "port is unspecified, listening on kernel chosen port", ) } else { - chosenPort, _ := strconv.ParseInt(c.Config.HTTP.Port, 10, 64) + chosenPort, _ := strconv.ParseInt(port, 10, 32) c.chosenPort = int(chosenPort) } - if c.Config.HTTP.TLS != nil && c.Config.HTTP.TLS.Key != "" && c.Config.HTTP.TLS.Cert != "" { + tlsConfig := c.Config.CopyTLSConfig() + if tlsConfig != nil && tlsConfig.Key != "" && tlsConfig.Cert != "" { server.TLSConfig = &tls.Config{ CipherSuites: []uint16{ tls.TLS_ECDHE_ECDSA_WITH_AES_256_GCM_SHA384, @@ -219,15 +221,15 @@ func (c *Controller) Run() error { MinVersion: tls.VersionTLS12, } - if c.Config.HTTP.TLS.CACert != "" { + if tlsConfig.CACert != "" { clientAuth := tls.VerifyClientCertIfGiven if c.Config.IsMTLSAuthEnabled() { clientAuth = tls.RequireAndVerifyClientCert } - caCert, err := os.ReadFile(c.Config.HTTP.TLS.CACert) + caCert, err := os.ReadFile(tlsConfig.CACert) if err != nil { - c.Log.Error().Err(err).Str("caCert", c.Config.HTTP.TLS.CACert).Msg("failed to read file") + c.Log.Error().Err(err).Str("caCert", tlsConfig.CACert).Msg("failed to read file") return err } @@ -246,7 +248,7 @@ func (c *Controller) Run() error { c.Healthz.Ready() - return server.ServeTLS(listener, c.Config.HTTP.TLS.Cert, c.Config.HTTP.TLS.Key) + return server.ServeTLS(listener, tlsConfig.Cert, tlsConfig.Key) } c.Healthz.Ready() @@ -267,10 +269,9 @@ func (c *Controller) Init() error { DumpRuntimeParams(c.Log) var enabled bool - if c.Config != nil && - c.Config.Extensions != nil && - c.Config.Extensions.Metrics != nil && - *c.Config.Extensions.Metrics.Enable { + extensionsConfig := c.Config.CopyExtensionsConfig() + + if extensionsConfig.IsMetricsEnabled() { enabled = true } @@ -291,8 +292,10 @@ func (c *Controller) Init() error { c.InitCVEInfo() c.Healthz.Started() - if c.Config.IsHtpasswdAuthEnabled() { - err := c.HTPasswdWatcher.ChangeFile(c.Config.HTTP.Auth.HTPasswd.Path) + // Get auth config safely + authConfig := c.Config.CopyAuthConfig() + if authConfig.IsHtpasswdAuthEnabled() { + err := c.HTPasswdWatcher.ChangeFile(authConfig.HTPasswd.Path) if err != nil { return err } @@ -303,9 +306,7 @@ func (c *Controller) Init() error { func (c *Controller) InitCVEInfo() { // Enable CVE extension if extension config is provided - if c.Config != nil && c.Config.Extensions != nil { - c.CveScanner = ext.GetCveScanner(c.Config, c.StoreController, c.MetaDB, c.Log) - } + c.CveScanner = ext.GetCveScanner(c.Config, c.StoreController, c.MetaDB, c.Log) } func (c *Controller) InitImageStore() error { @@ -323,11 +324,12 @@ func (c *Controller) InitImageStore() error { func (c *Controller) initCookieStore() error { // setup sessions cookie store used to preserve logged in user in web sessions - if c.Config.IsBasicAuthnEnabled() { + if c.Config.HTTP.Auth.IsBasicAuthnEnabled() { if c.Config.HTTP.Auth.SessionHashKey == nil { c.Log.Warn().Msg("hashKey is not set in config, generating a random one") - c.Config.HTTP.Auth.SessionHashKey = securecookie.GenerateRandomKey(64) //nolint: gomnd + key := securecookie.GenerateRandomKey(64) //nolint: gomnd + c.Config.HTTP.Auth.SessionHashKey = key } cookieStore, err := NewCookieStore(c.Config.HTTP.Auth, c.StoreController, c.Log) @@ -343,9 +345,16 @@ func (c *Controller) initCookieStore() error { func (c *Controller) InitMetaDB() error { // init metaDB if search is enabled or we need to store user profiles, api keys or signatures - if c.Config.IsSearchEnabled() || c.Config.IsBasicAuthnEnabled() || c.Config.IsImageTrustEnabled() || + // Get auth config safely + authConfig := c.Config.CopyAuthConfig() + extensionsConfig := c.Config.CopyExtensionsConfig() + + if extensionsConfig.IsSearchEnabled() || authConfig.IsBasicAuthnEnabled() || extensionsConfig.IsImageTrustEnabled() || c.Config.IsRetentionEnabled() { - driver, err := meta.New(c.Config.Storage.StorageConfig, c.Log) //nolint:contextcheck + // Get storage config safely + storageConfig := c.Config.CopyStorageConfig() + + driver, err := meta.New(storageConfig.StorageConfig, c.Log) //nolint:contextcheck if err != nil { return err } @@ -383,74 +392,25 @@ func (c *Controller) InitEventRecorder() error { } func (c *Controller) LoadNewConfig(newConfig *config.Config) { - // reload access control config - c.Config.HTTP.AccessControl = newConfig.HTTP.AccessControl + // Update only reloadable config fields atomically + c.Config.UpdateReloadableConfig(newConfig) - if c.Config.HTTP.Auth != nil { - c.Config.HTTP.Auth.HTPasswd = newConfig.HTTP.Auth.HTPasswd - c.Config.HTTP.Auth.LDAP = newConfig.HTTP.Auth.LDAP - - err := c.HTPasswdWatcher.ChangeFile(c.Config.HTTP.Auth.HTPasswd.Path) + // Operations that need to happen after config update + authConfig := c.Config.CopyAuthConfig() + if authConfig.IsHtpasswdAuthEnabled() { + err := c.HTPasswdWatcher.ChangeFile(authConfig.HTPasswd.Path) if err != nil { c.Log.Error().Err(err).Msg("failed to change watched htpasswd file") } - - if c.LDAPClient != nil { - c.LDAPClient.lock.Lock() - c.LDAPClient.BindDN = newConfig.HTTP.Auth.LDAP.BindDN() - c.LDAPClient.BindPassword = newConfig.HTTP.Auth.LDAP.BindPassword() - c.LDAPClient.lock.Unlock() - } } else { _ = c.HTPasswdWatcher.ChangeFile("") } - // reload periodical gc config - c.Config.Storage.GC = newConfig.Storage.GC - c.Config.Storage.Dedupe = newConfig.Storage.Dedupe - c.Config.Storage.GCDelay = newConfig.Storage.GCDelay - c.Config.Storage.GCInterval = newConfig.Storage.GCInterval - // only if we have a metaDB already in place - if c.Config.IsRetentionEnabled() { - c.Config.Storage.Retention = newConfig.Storage.Retention - } - - for subPath, storageConfig := range newConfig.Storage.SubPaths { - subPathConfig, ok := c.Config.Storage.SubPaths[subPath] - if ok { - subPathConfig.GC = storageConfig.GC - subPathConfig.Dedupe = storageConfig.Dedupe - subPathConfig.GCDelay = storageConfig.GCDelay - subPathConfig.GCInterval = storageConfig.GCInterval - // only if we have a metaDB already in place - if c.Config.IsRetentionEnabled() { - subPathConfig.Retention = storageConfig.Retention - } - - c.Config.Storage.SubPaths[subPath] = subPathConfig - } - } - - // reload background tasks - if newConfig.Extensions != nil { - if c.Config.Extensions == nil { - c.Config.Extensions = &extconf.ExtensionConfig{} - } - - // reload sync extension - c.Config.Extensions.Sync = newConfig.Extensions.Sync - - // reload only if search is enabled and reloaded config has search extension (can't setup routes at this stage) - if c.Config.Extensions.Search != nil && *c.Config.Extensions.Search.Enable { - if newConfig.Extensions.Search != nil { - c.Config.Extensions.Search.CVE = newConfig.Extensions.Search.CVE - } - } - - // reload scrub extension - c.Config.Extensions.Scrub = newConfig.Extensions.Scrub - } else { - c.Config.Extensions = nil + if c.LDAPClient != nil && authConfig.IsLdapAuthEnabled() { + c.LDAPClient.lock.Lock() + c.LDAPClient.BindDN = authConfig.LDAP.BindDN() + c.LDAPClient.BindPassword = authConfig.LDAP.BindPassword() + c.LDAPClient.lock.Unlock() } c.InitCVEInfo() @@ -463,6 +423,11 @@ func (c *Controller) Shutdown() { // stop all background tasks c.StopBackgroundTasks() + // Stop metrics server to prevent resource leaks (only during full shutdown) + if c.Metrics != nil { + c.Metrics.Stop() + } + if c.Server != nil { ctx := context.Background() _ = c.Server.Shutdown(ctx) @@ -479,73 +444,90 @@ func (c *Controller) StopBackgroundTasks() { if c.taskScheduler != nil { c.taskScheduler.Shutdown() } + + // Close HTPasswdWatcher to prevent resource leaks + if c.HTPasswdWatcher != nil { + _ = c.HTPasswdWatcher.Close() + } } func (c *Controller) StartBackgroundTasks() { c.taskScheduler = scheduler.NewScheduler(c.Config, c.Metrics, c.Log) c.taskScheduler.RunScheduler() + // Start HTPasswdWatcher goroutine + if c.HTPasswdWatcher != nil { + c.HTPasswdWatcher.Run() + } + // Enable running garbage-collect periodically for DefaultStore - if c.Config.Storage.GC { + storageConfig := c.Config.CopyStorageConfig() + if storageConfig.GC { gc := gc.NewGarbageCollect(c.StoreController.DefaultStore, c.MetaDB, gc.Options{ - Delay: c.Config.Storage.GCDelay, - ImageRetention: c.Config.Storage.Retention, + Delay: storageConfig.GCDelay, + ImageRetention: storageConfig.Retention, }, c.Audit, c.Log) - gc.CleanImageStorePeriodically(c.Config.Storage.GCInterval, c.taskScheduler) + gc.CleanImageStorePeriodically(storageConfig.GCInterval, c.taskScheduler) } // Enable running dedupe blobs both ways (dedupe or restore deduped blobs) c.StoreController.DefaultStore.RunDedupeBlobs(time.Duration(0), c.taskScheduler) // Enable extensions if extension config is provided for DefaultStore - if c.Config != nil && c.Config.Extensions != nil { - ext.EnableMetricsExtension(c.Config, c.Log, c.Config.Storage.RootDirectory) - ext.EnableSearchExtension(c.Config, c.StoreController, c.MetaDB, c.taskScheduler, c.CveScanner, c.Log) - } + extensionsConfig := c.Config.CopyExtensionsConfig() + + // Always call EnableSearchExtension to ensure proper logging, even when search is disabled + ext.EnableSearchExtension(c.Config, c.StoreController, c.MetaDB, c.taskScheduler, c.CveScanner, c.Log) + + // Always call EnableMetricsExtension to ensure proper logging, even when metrics is disabled + ext.EnableMetricsExtension(c.Config, c.Log, storageConfig.RootDirectory) + // runs once if metrics are enabled & imagestore is local - if c.Config.IsMetricsEnabled() && c.Config.Storage.StorageDriver == nil { + if extensionsConfig.IsMetricsEnabled() && storageConfig.StorageDriver == nil { c.StoreController.DefaultStore.PopulateStorageMetrics(time.Duration(0), c.taskScheduler) } - if c.Config.Storage.SubPaths != nil { - for route, storageConfig := range c.Config.Storage.SubPaths { + if storageConfig.SubPaths != nil { + for route, subStorageConfig := range storageConfig.SubPaths { // Enable running garbage-collect periodically for subImageStore - if storageConfig.GC { + if subStorageConfig.GC { gc := gc.NewGarbageCollect(c.StoreController.SubStore[route], c.MetaDB, gc.Options{ - Delay: storageConfig.GCDelay, - ImageRetention: storageConfig.Retention, + Delay: subStorageConfig.GCDelay, + ImageRetention: subStorageConfig.Retention, }, c.Audit, c.Log) - gc.CleanImageStorePeriodically(storageConfig.GCInterval, c.taskScheduler) + gc.CleanImageStorePeriodically(subStorageConfig.GCInterval, c.taskScheduler) } // Enable extensions if extension config is provided for subImageStore - if c.Config != nil && c.Config.Extensions != nil { - ext.EnableMetricsExtension(c.Config, c.Log, storageConfig.RootDirectory) - } + ext.EnableMetricsExtension(c.Config, c.Log, subStorageConfig.RootDirectory) // Enable running dedupe blobs both ways (dedupe or restore deduped blobs) for subpaths substore := c.StoreController.SubStore[route] if substore != nil { substore.RunDedupeBlobs(time.Duration(0), c.taskScheduler) - if c.Config.IsMetricsEnabled() && c.Config.Storage.StorageDriver == nil { + if extensionsConfig.IsMetricsEnabled() && storageConfig.StorageDriver == nil { substore.PopulateStorageMetrics(time.Duration(0), c.taskScheduler) } } } } - if c.Config.Extensions != nil { - ext.EnableScrubExtension(c.Config, c.Log, c.StoreController, c.taskScheduler) - //nolint: contextcheck - syncOnDemand, err := ext.EnableSyncExtension(c.Config, c.MetaDB, c.StoreController, c.taskScheduler, c.Log) - if err != nil { - c.Log.Error().Err(err).Msg("failed to start sync extension") - } + // Always call EnableScrubExtension to ensure proper logging, even when scrub is disabled + ext.EnableScrubExtension(c.Config, c.Log, c.StoreController, c.taskScheduler) + // Always call EnableSyncExtension to ensure proper logging, even when sync is disabled + //nolint: contextcheck + syncOnDemand, err := ext.EnableSyncExtension(c.Config, c.MetaDB, c.StoreController, c.taskScheduler, c.Log) + if err != nil { + c.Log.Error().Err(err).Msg("failed to start sync extension") + } + + // Only set SyncOnDemand if sync is actually enabled + if extensionsConfig.IsSyncEnabled() { c.SyncOnDemand = syncOnDemand } diff --git a/pkg/api/controller_test.go b/pkg/api/controller_test.go index cef57cc5..03c40dd6 100644 --- a/pkg/api/controller_test.go +++ b/pkg/api/controller_test.go @@ -1296,7 +1296,9 @@ func TestScaleOutRequestProxy(t *testing.T) { cm := test.NewControllerManager(ctrlr) cm.StartAndWait(port) - defer cm.StopServer() + defer func(cm test.ControllerManager) { + cm.StopServer() + }(cm) } Convey("All 3 controllers should start up and respond without error", func() { @@ -1393,7 +1395,9 @@ func TestScaleOutRequestProxy(t *testing.T) { cm := test.NewControllerManager(ctrlr) cm.StartAndWait(port) - defer cm.StopServer() + defer func(cm test.ControllerManager) { + cm.StopServer() + }(cm) } Convey("All 3 controllers should start up and respond without error", func() { @@ -1470,7 +1474,9 @@ func TestScaleOutRequestProxy(t *testing.T) { cm := test.NewControllerManager(ctrlr) cm.StartAndWait(port) - defer cm.StopServer() + defer func(cm test.ControllerManager) { + cm.StopServer() + }(cm) } caCert, err := os.ReadFile(CACert) @@ -1534,7 +1540,9 @@ func TestScaleOutRequestProxy(t *testing.T) { cm := test.NewControllerManager(ctrlr) cm.StartAndWait(port) - defer cm.StopServer() + defer func(cm test.ControllerManager) { + cm.StopServer() + }(cm) } caCert, err := os.ReadFile(CACert) @@ -1599,7 +1607,9 @@ func TestScaleOutRequestProxy(t *testing.T) { cm := test.NewControllerManager(ctrlr) cm.StartAndWait(port) - defer cm.StopServer() + defer func(cm test.ControllerManager) { + cm.StopServer() + }(cm) } caCert, err := os.ReadFile(CACert) @@ -7171,6 +7181,8 @@ func TestCrossRepoMount(t *testing.T) { ctlr.Config.Storage.Dedupe = true ctlr.Config.Storage.GC = false ctlr.Config.Storage.RootDirectory = newDir + + ctlr = api.NewController(ctlr.Config) cm = test.NewControllerManager(ctlr) //nolint: varnamelen cm.StartAndWait(port) diff --git a/pkg/api/htpasswd.go b/pkg/api/htpasswd.go index 6a6a6236..5bdbcbdb 100644 --- a/pkg/api/htpasswd.go +++ b/pkg/api/htpasswd.go @@ -100,76 +100,129 @@ func (s *HTPasswd) Authenticate(username, passphrase string) (ok, present bool) // HTPasswdWatcher helper which triggers htpasswd reload on file change event. // -// Cannot be restarted. +// Can be restarted by calling Run() again after Close(). type HTPasswdWatcher struct { htp *HTPasswd filePath string watcher *fsnotify.Watcher + ctx context.Context //nolint:containedctx // Context is needed for watcher lifecycle management cancel context.CancelFunc log log.Logger + mu sync.Mutex } -// NewHTPasswdWatcher create and start watcher. +// NewHTPasswdWatcher creates a new watcher instance. func NewHTPasswdWatcher(htp *HTPasswd, filePath string) (*HTPasswdWatcher, error) { - watcher, err := fsnotify.NewWatcher() - if err != nil { - return nil, err - } - - if filePath != "" { - err = watcher.Add(filePath) - if err != nil { - return nil, errors.Join(err, watcher.Close()) - } - } - - // background event processor job context - ctx, cancel := signal.NotifyContext(context.Background(), syscall.SIGTERM, syscall.SIGINT) - ret := &HTPasswdWatcher{ htp: htp, filePath: filePath, - watcher: watcher, - cancel: cancel, log: htp.log, } - go func() { - defer ret.watcher.Close() //nolint: errcheck - - for { - select { - case ev := <-ret.watcher.Events: - if ev.Op != fsnotify.Write { - continue - } - - ret.log.Info().Str("htpasswd-file", ret.filePath).Msg("htpasswd file changed, trying to reload config") - - err := ret.htp.Reload(ret.filePath) - if err != nil { - ret.log.Warn().Err(err).Str("htpasswd-file", ret.filePath).Msg("failed to reload file") - } - - case err := <-ret.watcher.Errors: - ret.log.Error().Err(err).Str("htpasswd-file", ret.filePath).Msg("failed to fsnotfy, got error while watching file") - - case <-ctx.Done(): - ret.log.Debug().Msg("htpasswd watcher terminating...") - - return - } - } - }() - return ret, nil } +// Run starts the watcher goroutine. +func (s *HTPasswdWatcher) Run() { + s.mu.Lock() + defer s.mu.Unlock() + + if s.ctx != nil { + return // Already running + } + + // Create fresh fsnotify watcher for this run + watcher, err := fsnotify.NewWatcher() + if err != nil { + s.log.Error().Err(err).Msg("failed to create fsnotify watcher") + + return + } + + // Only add file to watcher if we have a file to watch + if s.filePath != "" { + err = watcher.Add(s.filePath) + if err != nil { + s.log.Error().Err(err).Str("htpasswd-file", s.filePath).Msg("failed to add file to watcher") + watcher.Close() //nolint: errcheck + + return + } + } + + // Create context and start goroutine + ctx, cancel := signal.NotifyContext(context.Background(), syscall.SIGTERM, syscall.SIGINT) + s.ctx = ctx + s.cancel = cancel + s.watcher = watcher + + go func() { + defer func() { + s.mu.Lock() + defer s.mu.Unlock() + + // Clean up watcher + if s.watcher != nil { + s.watcher.Close() //nolint: errcheck + s.watcher = nil + } + + // Clear context to indicate not running + s.ctx = nil + s.cancel = nil + }() + + for { + select { + case <-ctx.Done(): + s.log.Debug().Msg("htpasswd watcher terminating...") + + return + + case ev := <-watcher.Events: + if ev.Op != fsnotify.Write { + continue + } + + s.log.Info().Str("htpasswd-file", s.filePath).Msg("htpasswd file changed, trying to reload config") + + err := s.htp.Reload(s.filePath) + if err != nil { + s.log.Warn().Err(err).Str("htpasswd-file", s.filePath).Msg("failed to reload file") + } + + case err := <-watcher.Errors: + // Only log errors if we're actually watching a file + if s.filePath != "" { + s.log.Error().Err(err).Str("htpasswd-file", s.filePath).Msg("failed to fsnotfy, got error while watching file") + } + } + } + }() +} + // ChangeFile changes monitored file. Empty string clears store. func (s *HTPasswdWatcher) ChangeFile(filePath string) error { + s.mu.Lock() + defer s.mu.Unlock() + + // If watcher is not running, just update the filePath for when Run() is called + if s.watcher == nil { + s.filePath = filePath + if filePath == "" { + s.htp.Clear() + } else { + return s.htp.Reload(filePath) + } + + return nil + } + + // Remove old file if it exists if s.filePath != "" { - err := s.watcher.Remove(s.filePath) - if err != nil { + if err := s.watcher.Remove(s.filePath); err != nil && !errors.Is(err, fsnotify.ErrNonExistentWatch) { + // Ignore "can't remove non-existent watch" errors as they can happen + // due to race conditions or files being removed externally return err } } @@ -192,7 +245,21 @@ func (s *HTPasswdWatcher) ChangeFile(filePath string) error { } func (s *HTPasswdWatcher) Close() error { - s.cancel() + s.mu.Lock() + defer s.mu.Unlock() + + if s.ctx == nil { + return nil // Already closed/not running + } + + // Cancel context to signal goroutine to stop + if s.cancel != nil { + s.cancel() + } + + // The goroutine will clean up s.ctx, s.cancel, and s.watcher in its defer + // We just need to wait for it to finish by checking if s.ctx becomes nil + // This is safe because the goroutine sets s.ctx = nil in its defer return nil } diff --git a/pkg/api/htpasswd_test.go b/pkg/api/htpasswd_test.go index f37bd2ff..fdd7f84d 100644 --- a/pkg/api/htpasswd_test.go +++ b/pkg/api/htpasswd_test.go @@ -12,7 +12,7 @@ import ( test "zotregistry.dev/zot/v2/pkg/test/common" ) -func TestHTPasswdWatcher(t *testing.T) { +func TestHTPasswdWatcherOriginal(t *testing.T) { logger := log.NewLogger("DEBUG", "") Convey("reload htpasswd", t, func(c C) { @@ -28,6 +28,9 @@ func TestHTPasswdWatcher(t *testing.T) { htw, err := api.NewHTPasswdWatcher(htp, "") So(err, ShouldBeNil) + // Start the watcher goroutine + htw.Run() + defer htw.Close() //nolint: errcheck _, present := htp.Get(username) @@ -62,3 +65,491 @@ func TestHTPasswdWatcher(t *testing.T) { So(present, ShouldBeTrue) }) } + +func TestHTPasswdWatcher(t *testing.T) { + logger := log.NewLogger("DEBUG", "") + + Convey("Test HTPasswdWatcher comprehensive functionality", t, func() { + Convey("Test basic operations and lifecycle", func() { + // Create a buffer to capture log output + logBuffer, multiWriter := test.CreateLogCapturingWriter(os.Stdout) + capturingLogger := log.NewLoggerWithWriter("debug", multiWriter) + + htp := api.NewHTPasswd(capturingLogger) + htw, err := api.NewHTPasswdWatcher(htp, "") + So(err, ShouldBeNil) + + // Test Run() and Close() operations + So(func() { htw.Run() }, ShouldNotPanic) + time.Sleep(10 * time.Millisecond) + So(func() { htw.Run() }, ShouldNotPanic) // Idempotent + time.Sleep(10 * time.Millisecond) + So(func() { htw.Close() }, ShouldNotPanic) + time.Sleep(10 * time.Millisecond) + So(htw.Close(), ShouldBeNil) // Idempotent + + // Verify goroutine termination + So(test.WaitForLogMessages(logBuffer, "htpasswd watcher terminating...", 1, 5*time.Second), ShouldBeTrue) + }) + + Convey("Test ChangeFile() operations and file watching", func() { + username1, _ := test.GenerateRandomString() + password1, _ := test.GenerateRandomString() + username2, _ := test.GenerateRandomString() + password2, _ := test.GenerateRandomString() + + htpasswdPath1 := test.MakeHtpasswdFileFromString(test.GetCredString(username1, password1)) + htpasswdPath2 := test.MakeHtpasswdFileFromString(test.GetCredString(username2, password2)) + + defer os.Remove(htpasswdPath1) + defer os.Remove(htpasswdPath2) + + htp := api.NewHTPasswd(logger) + htw, err := api.NewHTPasswdWatcher(htp, "") + So(err, ShouldBeNil) + + // Test ChangeFile() when not running + err = htw.ChangeFile(htpasswdPath1) + So(err, ShouldBeNil) + ok, present := htp.Authenticate(username1, password1) + So(ok, ShouldBeTrue) + So(present, ShouldBeTrue) + + // Start watcher and test ChangeFile() when running + htw.Run() + defer htw.Close() + time.Sleep(10 * time.Millisecond) + + // Change to second file + err = htw.ChangeFile(htpasswdPath2) + So(err, ShouldBeNil) + time.Sleep(10 * time.Millisecond) + ok, present = htp.Authenticate(username2, password2) + So(ok, ShouldBeTrue) + So(present, ShouldBeTrue) + _, present = htp.Authenticate(username1, password1) + So(present, ShouldBeFalse) + + // Test ChangeFile() to empty string (clear store) + err = htw.ChangeFile("") + So(err, ShouldBeNil) + time.Sleep(10 * time.Millisecond) + _, present = htp.Authenticate(username2, password2) + So(present, ShouldBeFalse) + + // Test ChangeFile() with non-existent file + err = htw.ChangeFile("/non/existent/path") + So(err, ShouldNotBeNil) + + // Test file change detection and reload + err = htw.ChangeFile(htpasswdPath1) + So(err, ShouldBeNil) + time.Sleep(10 * time.Millisecond) + ok, present = htp.Authenticate(username1, password1) + So(ok, ShouldBeTrue) + So(present, ShouldBeTrue) + + // Change file content and verify automatic reload + err = os.WriteFile(htpasswdPath1, []byte(test.GetCredString(username1, password2)), 0o600) + So(err, ShouldBeNil) + time.Sleep(100 * time.Millisecond) + ok, present = htp.Authenticate(username1, password2) + So(ok, ShouldBeTrue) + So(present, ShouldBeTrue) + + // Test multiple users + multiUserContent := test.GetCredString(username1, password1) + "\n" + test.GetCredString(username2, password2) + err = os.WriteFile(htpasswdPath1, []byte(multiUserContent), 0o600) + So(err, ShouldBeNil) + time.Sleep(100 * time.Millisecond) + ok, present = htp.Authenticate(username1, password1) + So(ok, ShouldBeTrue) + So(present, ShouldBeTrue) + ok, present = htp.Authenticate(username2, password2) + So(ok, ShouldBeTrue) + So(present, ShouldBeTrue) + + // Test invalid content (clears store) + err = os.WriteFile(htpasswdPath1, []byte("invalid-content"), 0o600) + So(err, ShouldBeNil) + time.Sleep(100 * time.Millisecond) + _, present = htp.Authenticate(username1, password1) + So(present, ShouldBeFalse) + + // Test empty file (clears store) + err = os.WriteFile(htpasswdPath1, []byte(""), 0o600) + So(err, ShouldBeNil) + time.Sleep(100 * time.Millisecond) + _, present = htp.Authenticate(username2, password2) + So(present, ShouldBeFalse) + }) + + Convey("Test restart capability, edge cases, and file operations", func() { + // Create a buffer to capture log output + logBuffer, multiWriter := test.CreateLogCapturingWriter(os.Stdout) + capturingLogger := log.NewLoggerWithWriter("debug", multiWriter) + + username1, _ := test.GenerateRandomString() + password1, _ := test.GenerateRandomString() + username2, _ := test.GenerateRandomString() + password2, _ := test.GenerateRandomString() + + htpasswdPath1 := test.MakeHtpasswdFileFromString(test.GetCredString(username1, password1)) + htpasswdPath2 := test.MakeHtpasswdFileFromString(test.GetCredString(username2, password2)) + + defer os.Remove(htpasswdPath1) + defer os.Remove(htpasswdPath2) + + htp := api.NewHTPasswd(capturingLogger) + htw, err := api.NewHTPasswdWatcher(htp, htpasswdPath1) + So(err, ShouldBeNil) + + // Test restart capability + htw.Run() + time.Sleep(10 * time.Millisecond) + err = htw.ChangeFile(htpasswdPath1) + So(err, ShouldBeNil) + time.Sleep(10 * time.Millisecond) + ok, present := htp.Authenticate(username1, password1) + So(ok, ShouldBeTrue) + So(present, ShouldBeTrue) + + // Close and restart + So(htw.Close(), ShouldBeNil) + So(test.WaitForLogMessages(logBuffer, "htpasswd watcher terminating...", 1, 5*time.Second), ShouldBeTrue) + htw.Run() + time.Sleep(10 * time.Millisecond) + + // Change file after restart + err = htw.ChangeFile(htpasswdPath2) + So(err, ShouldBeNil) + time.Sleep(10 * time.Millisecond) + ok, present = htp.Authenticate(username2, password2) + So(ok, ShouldBeTrue) + So(present, ShouldBeTrue) + + // Test file becomes inaccessible + os.Remove(htpasswdPath2) + time.Sleep(100 * time.Millisecond) + ok, present = htp.Authenticate(username2, password2) + So(ok, ShouldBeTrue) // User should still be present + So(present, ShouldBeTrue) + + // Test file rename (should not trigger reload) + htpasswdPath3 := test.MakeHtpasswdFileFromString(test.GetCredString(username1, password1)) + defer os.Remove(htpasswdPath3) + err = htw.ChangeFile(htpasswdPath3) + So(err, ShouldBeNil) + time.Sleep(10 * time.Millisecond) + ok, present = htp.Authenticate(username1, password1) + So(ok, ShouldBeTrue) + So(present, ShouldBeTrue) + + newPath := htpasswdPath3 + ".new" + err = os.Rename(htpasswdPath3, newPath) + So(err, ShouldBeNil) + + defer os.Remove(newPath) + time.Sleep(100 * time.Millisecond) + ok, _ = htp.Authenticate(username1, password1) + So(ok, ShouldBeTrue) // User should still be present + + // Test file permission change (should not trigger reload) + err = os.Chmod(newPath, 0o000) + So(err, ShouldBeNil) + + defer func() { _ = os.Chmod(newPath, 0o644) }() + time.Sleep(100 * time.Millisecond) + ok, _ = htp.Authenticate(username1, password1) + So(ok, ShouldBeTrue) // User should still be present + + // Test with non-existent directory + htw2, err := api.NewHTPasswdWatcher(htp, "/non/existent/dir/htpasswd") + So(err, ShouldBeNil) + So(func() { htw2.Run() }, ShouldNotPanic) + time.Sleep(10 * time.Millisecond) + So(htw2.Close(), ShouldBeNil) + // 1 termination message + So(test.WaitForLogMessages(logBuffer, "htpasswd watcher terminating...", 1, 5*time.Second), ShouldBeTrue) + + // Test with very long file path + longPath := "/tmp/" + for i := 0; i < 100; i++ { + longPath += "verylongdirname" + } + longPath += "/htpasswd" + htw3, err := api.NewHTPasswdWatcher(htp, longPath) + So(err, ShouldBeNil) + So(func() { htw3.Run() }, ShouldNotPanic) + time.Sleep(10 * time.Millisecond) + So(htw3.Close(), ShouldBeNil) + // 1 termination message + So(test.WaitForLogMessages(logBuffer, "htpasswd watcher terminating...", 1, 5*time.Second), ShouldBeTrue) + + // Clean up + So(htw.Close(), ShouldBeNil) + // 1 termination message + So(test.WaitForLogMessages(logBuffer, "htpasswd watcher terminating...", 1, 5*time.Second), ShouldBeTrue) + }) + + Convey("Test concurrent operations and goroutine cleanup", func() { + // Create a buffer to capture log output + logBuffer, multiWriter := test.CreateLogCapturingWriter(os.Stdout) + capturingLogger := log.NewLoggerWithWriter("debug", multiWriter) + + username1, _ := test.GenerateRandomString() + password1, _ := test.GenerateRandomString() + username2, _ := test.GenerateRandomString() + password2, _ := test.GenerateRandomString() + + htpasswdPath1 := test.MakeHtpasswdFileFromString(test.GetCredString(username1, password1)) + htpasswdPath2 := test.MakeHtpasswdFileFromString(test.GetCredString(username2, password2)) + + defer os.Remove(htpasswdPath1) + defer os.Remove(htpasswdPath2) + + htp := api.NewHTPasswd(capturingLogger) + htw, err := api.NewHTPasswdWatcher(htp, "") + So(err, ShouldBeNil) + + // Test concurrent Run() and Close() + go func() { + for i := 0; i < 5; i++ { + htw.Run() + time.Sleep(1 * time.Millisecond) + } + }() + + go func() { + for i := 0; i < 5; i++ { + htw.Close() + time.Sleep(1 * time.Millisecond) + } + }() + + time.Sleep(50 * time.Millisecond) + So(func() { htw.Close() }, ShouldNotPanic) + So(test.WaitForLogMessages(logBuffer, "htpasswd watcher terminating...", 1, 5*time.Second), ShouldBeTrue) + + // Test concurrent ChangeFile() operations + htw.Run() + defer htw.Close() + + go func() { + for i := 0; i < 3; i++ { + _ = htw.ChangeFile(htpasswdPath1) + + time.Sleep(1 * time.Millisecond) + } + }() + + go func() { + for i := 0; i < 3; i++ { + _ = htw.ChangeFile(htpasswdPath2) + + time.Sleep(1 * time.Millisecond) + } + }() + + time.Sleep(50 * time.Millisecond) + + // At least one user should be present + ok1, present1 := htp.Authenticate(username1, password1) + ok2, present2 := htp.Authenticate(username2, password2) + So(present1 || present2, ShouldBeTrue) + So(ok1 || ok2, ShouldBeTrue) + + // Test goroutine cleanup with multiple verification methods + htw2, err := api.NewHTPasswdWatcher(htp, "") + So(err, ShouldBeNil) + + // Start watcher + htw2.Run() + time.Sleep(10 * time.Millisecond) + + // Close watcher + So(htw2.Close(), ShouldBeNil) + + // Wait for goroutine to terminate (check log messages) + // 1 termination message + So(test.WaitForLogMessages(logBuffer, "htpasswd watcher terminating...", 1, 5*time.Second), ShouldBeTrue) + + // Verify we can restart the watcher (indicates proper cleanup) + htw2.Run() + time.Sleep(10 * time.Millisecond) + So(htw2.Close(), ShouldBeNil) + // 1 termination message + So(test.WaitForLogMessages(logBuffer, "htpasswd watcher terminating...", 1, 5*time.Second), ShouldBeTrue) + + // Test multiple Run/Close cycles + for i := 0; i < 3; i++ { + htw2.Run() + time.Sleep(10 * time.Millisecond) + So(htw2.Close(), ShouldBeNil) + time.Sleep(50 * time.Millisecond) // Give time for termination + } + }) + + Convey("Test goroutine termination with comprehensive log verification", func() { + // Create a buffer to capture log output + logBuffer, multiWriter := test.CreateLogCapturingWriter(os.Stdout) + capturingLogger := log.NewLoggerWithWriter("debug", multiWriter) + + // Test 1: Basic termination verification (no file watching) + htp1 := api.NewHTPasswd(capturingLogger) + htw1, err := api.NewHTPasswdWatcher(htp1, "") + So(err, ShouldBeNil) + + htw1.Run() + time.Sleep(10 * time.Millisecond) + So(htw1.Close(), ShouldBeNil) + So(test.WaitForLogMessages(logBuffer, "htpasswd watcher terminating...", 1, 5*time.Second), ShouldBeTrue) + + // Test 2: File watching with fsnotify resources cleanup + username, _ := test.GenerateRandomString() + password, _ := test.GenerateRandomString() + htpasswdPath := test.MakeHtpasswdFileFromString(test.GetCredString(username, password)) + + defer os.Remove(htpasswdPath) + + htp2 := api.NewHTPasswd(capturingLogger) + htw2, err := api.NewHTPasswdWatcher(htp2, htpasswdPath) + So(err, ShouldBeNil) + + // Start watcher with file + htw2.Run() + time.Sleep(10 * time.Millisecond) + + // Load file to ensure watcher is active + err = htw2.ChangeFile(htpasswdPath) + So(err, ShouldBeNil) + time.Sleep(10 * time.Millisecond) + + // Close watcher and verify termination + So(htw2.Close(), ShouldBeNil) + // 1 + 1 = 2 + So(test.WaitForLogMessages(logBuffer, "htpasswd watcher terminating...", 2, 5*time.Second), ShouldBeTrue) + + // Test 3: Multiple termination cycles with file watching + for i := 0; i < 3; i++ { + htw2.Run() + time.Sleep(10 * time.Millisecond) + So(htw2.Close(), ShouldBeNil) + time.Sleep(50 * time.Millisecond) // Give time for termination + } + + // Verify we have at least 3 termination messages so far (2 previous + 1 cycle = 3) + So(test.WaitForLogMessages(logBuffer, "htpasswd watcher terminating...", 3, 5*time.Second), ShouldBeTrue) + + // Test 4: Stress test with rapid cycles + for i := 0; i < 5; i++ { + htw2.Run() + time.Sleep(5 * time.Millisecond) + So(htw2.Close(), ShouldBeNil) + time.Sleep(20 * time.Millisecond) // Give time for termination + } + + // Verify we have at least 8 termination messages so far (3+5 = 8) + So(test.WaitForLogMessages(logBuffer, "htpasswd watcher terminating...", 8, 5*time.Second), ShouldBeTrue) + + // Final verification: watcher should still work after all cycles + htw2.Run() + time.Sleep(10 * time.Millisecond) + So(htw2.Close(), ShouldBeNil) + + // Final verification of all termination messages with timeout + So(test.WaitForLogMessages(logBuffer, "htpasswd watcher terminating...", 9, 5*time.Second), ShouldBeTrue) // 8+1 = 9 + }) + + Convey("Test malformed htpasswd files", func() { + // Create a buffer to capture log output + logBuffer, multiWriter := test.CreateLogCapturingWriter(os.Stdout) + capturingLogger := log.NewLoggerWithWriter("debug", multiWriter) + + username, _ := test.GenerateRandomString() + password, _ := test.GenerateRandomString() + + htp := api.NewHTPasswd(capturingLogger) + + // Test file with only colons (malformed) + colonPath := test.MakeHtpasswdFileFromString(":::") + defer os.Remove(colonPath) + htw1, err := api.NewHTPasswdWatcher(htp, colonPath) + So(err, ShouldBeNil) + htw1.Run() + time.Sleep(10 * time.Millisecond) + _ = htw1.ChangeFile(colonPath) + + time.Sleep(10 * time.Millisecond) + // The malformed file creates an entry with empty username, so test that + _, present := htp.Authenticate("", "anything") + So(present, ShouldBeTrue) // Empty username entry exists but auth fails + ok, _ := htp.Authenticate("", "anything") + So(ok, ShouldBeFalse) // But authentication should fail + So(htw1.Close(), ShouldBeNil) + So(test.WaitForLogMessages(logBuffer, "htpasswd watcher terminating...", 1, 5*time.Second), ShouldBeTrue) + + // Test file with empty lines and comments + content := "\n\n" + test.GetCredString(username, password) + "\n# comment\n" + commentedPath := test.MakeHtpasswdFileFromString(content) + + defer os.Remove(commentedPath) + htw2, err := api.NewHTPasswdWatcher(htp, commentedPath) + So(err, ShouldBeNil) + htw2.Run() + time.Sleep(10 * time.Millisecond) + _ = htw2.ChangeFile(commentedPath) + + time.Sleep(10 * time.Millisecond) + ok, _ = htp.Authenticate(username, password) + So(ok, ShouldBeTrue) // User should be loaded (comments/empty lines ignored) + So(htw2.Close(), ShouldBeNil) + // 1 termination message + So(test.WaitForLogMessages(logBuffer, "htpasswd watcher terminating...", 1, 5*time.Second), ShouldBeTrue) + }) + + Convey("Test ChangeFile with nil watcher and empty filepath", func() { + // Create a logger (no need for log capture since we're not testing goroutine termination) + capturingLogger := log.NewLogger("debug", "") + + username, _ := test.GenerateRandomString() + password, _ := test.GenerateRandomString() + + htp := api.NewHTPasswd(capturingLogger) + htw, err := api.NewHTPasswdWatcher(htp, "") + So(err, ShouldBeNil) + + // Load some initial data + htpasswdPath := test.MakeHtpasswdFileFromString(test.GetCredString(username, password)) + defer os.Remove(htpasswdPath) + + // Load initial file (this will populate the store) + err = htw.ChangeFile(htpasswdPath) + So(err, ShouldBeNil) + + // Verify user is loaded + ok, present := htp.Authenticate(username, password) + So(ok, ShouldBeTrue) + So(present, ShouldBeTrue) + + // Now test the edge case: ChangeFile with empty string when watcher is nil + // (watcher is nil because we haven't called Run() yet) + err = htw.ChangeFile("") + So(err, ShouldBeNil) // Should not return an error + + // Verify that the store was cleared + ok, present = htp.Authenticate(username, password) + So(ok, ShouldBeFalse) // Authentication should fail + So(present, ShouldBeFalse) // User should not be present + + // Test that we can still load a file after clearing + err = htw.ChangeFile(htpasswdPath) + So(err, ShouldBeNil) + + // Verify user is loaded again + ok, present = htp.Authenticate(username, password) + So(ok, ShouldBeTrue) + So(present, ShouldBeTrue) + }) + }) +} diff --git a/pkg/api/proxy.go b/pkg/api/proxy.go index 407231ab..5d2ff7e6 100644 --- a/pkg/api/proxy.go +++ b/pkg/api/proxy.go @@ -22,11 +22,12 @@ import ( func ClusterProxy(ctrlr *Controller) func(http.HandlerFunc) http.HandlerFunc { return func(next http.HandlerFunc) http.HandlerFunc { return http.HandlerFunc(func(response http.ResponseWriter, request *http.Request) { - config := ctrlr.Config + // Get cluster config safely + clusterConfig := ctrlr.Config.CopyClusterConfig() logger := ctrlr.Log // if no cluster or single-node cluster, handle locally. - if config.Cluster == nil || len(config.Cluster.Members) == 1 { + if !clusterConfig.IsClustered() { next.ServeHTTP(response, request) return @@ -45,13 +46,13 @@ func ClusterProxy(ctrlr *Controller) func(http.HandlerFunc) http.HandlerFunc { // the target member is the only one which should do read/write for the dist-spec APIs // for the given repository. - targetMemberIndex, targetMember := cluster.ComputeTargetMember(config.Cluster.HashKey, config.Cluster.Members, name) + targetMemberIndex, targetMember := cluster.ComputeTargetMember(clusterConfig.HashKey, clusterConfig.Members, name) logger.Debug().Str(constants.RepositoryLogKey, name). Msg(fmt.Sprintf("target member socket: %s index: %d", targetMember, targetMemberIndex)) // if the target member is the same as the local member, the current member should handle the request. // since the instances have the same config, a quick index lookup is sufficient - if targetMemberIndex == config.Cluster.Proxy.LocalMemberClusterSocketIndex { + if targetMemberIndex == clusterConfig.Proxy.LocalMemberClusterSocketIndex { logger.Debug().Str(constants.RepositoryLogKey, name).Msg("handling the request locally") next.ServeHTTP(response, request) @@ -119,8 +120,11 @@ func proxyHTTPRequest(ctx context.Context, req *http.Request, ) (*http.Response, error) { cloneURL := *req.URL + // Get HTTP TLS config safely + httpTLSConfig := ctrlr.Config.CopyTLSConfig() + proxyQueryScheme := "http" - if ctrlr.Config.HTTP.TLS != nil { + if httpTLSConfig != nil { proxyQueryScheme = "https" } @@ -142,12 +146,15 @@ func proxyHTTPRequest(ctx context.Context, req *http.Request, fwdRequest.Header.Set(constants.ScaleOutHopCountHeader, "1") clientOpts := common.HTTPClientOptions{ - TLSEnabled: ctrlr.Config.HTTP.TLS != nil, - VerifyTLS: ctrlr.Config.HTTP.TLS != nil, // for now, always verify TLS when TLS mode is enabled + TLSEnabled: httpTLSConfig != nil, + VerifyTLS: httpTLSConfig != nil, // for now, always verify TLS when TLS mode is enabled Host: targetMember, } - tlsConfig := ctrlr.Config.Cluster.TLS + // Get cluster config safely + clusterConfig := ctrlr.Config.CopyClusterConfig() + tlsConfig := clusterConfig.TLS + if tlsConfig != nil { clientOpts.CertOptions.ClientCertFile = tlsConfig.Cert clientOpts.CertOptions.ClientKeyFile = tlsConfig.Key diff --git a/pkg/api/routes.go b/pkg/api/routes.go index a0d24ad8..1655e634 100644 --- a/pkg/api/routes.go +++ b/pkg/api/routes.go @@ -73,9 +73,13 @@ func (rh *RouteHandler) SetupRoutes() { // first get Auth middleware in order to first setup openid/ldap/htpasswd, before oidc provider routes are setup authHandler := AuthHandler(rh.c) - applyCORSHeaders := getCORSHeadersHandler(rh.c.Config.HTTP.AllowOrigin) + // Get CORS config safely + allowOrigin := rh.c.Config.GetAllowOrigin() + applyCORSHeaders := getCORSHeadersHandler(allowOrigin) - if rh.c.Config.IsOpenIDAuthEnabled() { + // Get auth config for OpenID checks + authConfig := rh.c.Config.CopyAuthConfig() + if authConfig.IsOpenIDAuthEnabled() { // login path for openID rh.c.Router.HandleFunc(constants.LoginPath, rh.AuthURLHandler()) @@ -91,14 +95,15 @@ func (rh *RouteHandler) SetupRoutes() { } } - if rh.c.Config.IsAPIKeyEnabled() { + // Get auth config for API key checks + if authConfig.IsAPIKeyEnabled() { // enable api key management urls apiKeyRouter := rh.c.Router.PathPrefix(constants.APIKeyPath).Subrouter() apiKeyRouter.Use(authHandler) apiKeyRouter.Use(BaseAuthzHandler(rh.c)) // Always use CORSHeadersMiddleware before ACHeadersMiddleware - apiKeyRouter.Use(zcommon.CORSHeadersMiddleware(rh.c.Config.HTTP.AllowOrigin)) + apiKeyRouter.Use(zcommon.CORSHeadersMiddleware(rh.c.Config.GetAllowOrigin())) apiKeyRouter.Use(zcommon.ACHeadersMiddleware(rh.c.Config, http.MethodGet, http.MethodPost, http.MethodDelete, http.MethodOptions)) @@ -109,7 +114,7 @@ func (rh *RouteHandler) SetupRoutes() { /* on every route which may be used by UI we set OPTIONS as allowed METHOD to enable preflight request from UI to backend */ - if rh.c.Config.IsBasicAuthnEnabled() { + if authConfig.IsBasicAuthnEnabled() { // logout path for openID rh.c.Router.HandleFunc(constants.LogoutPath, getUIHeadersHandler(rh.c.Config, http.MethodPost, http.MethodOptions)(applyCORSHeaders(rh.Logout))). @@ -122,8 +127,9 @@ func (rh *RouteHandler) SetupRoutes() { prefixedDistSpecRouter := prefixedRouter.NewRoute().Subrouter() // authz is being enabled if AccessControl is specified // if Authn is not present AccessControl will have only default policies - if rh.c.Config.HTTP.AccessControl != nil { - if rh.c.Config.IsBasicAuthnEnabled() { + accessControlConfig := rh.c.Config.CopyAccessControlConfig() + if accessControlConfig != nil { + if authConfig.IsBasicAuthnEnabled() { rh.c.Log.Info().Msg("access control is being enabled") } else { rh.c.Log.Info().Msg("anonymous policy only access control is being enabled") @@ -241,7 +247,9 @@ func getUIHeadersHandler(config *config.Config, allowedMethods ...string) func(h response.Header().Set("Access-Control-Allow-Headers", "Authorization,content-type,"+constants.SessionClientHeaderName) - if config.IsBasicAuthnEnabled() { + // Get auth config safely + authConfig := config.CopyAuthConfig() + if authConfig.IsBasicAuthnEnabled() { response.Header().Set("Access-Control-Allow-Credentials", "true") } @@ -267,13 +275,17 @@ func (rh *RouteHandler) CheckVersionSupport(response http.ResponseWriter, reques response.Header().Set(constants.DistAPIVersion, "registry/2.0") // NOTE: compatibility workaround - return this header in "allowed-read" mode to allow for clients to // work correctly - if rh.c.Config.IsBasicAuthnEnabled() || rh.c.Config.IsBearerAuthEnabled() { + // Get auth config safely + authConfig := rh.c.Config.CopyAuthConfig() + if authConfig.IsBasicAuthnEnabled() || authConfig.IsBearerAuthEnabled() { // don't send auth headers if request is coming from UI if request.Header.Get(constants.SessionClientHeaderName) != constants.SessionClientHeaderValue { - if rh.c.Config.HTTP.Auth.Bearer != nil { - response.Header().Set("WWW-Authenticate", "bearer realm="+rh.c.Config.HTTP.Auth.Bearer.Realm) + if authConfig.Bearer != nil { + realm := authConfig.Bearer.Realm + response.Header().Set("WWW-Authenticate", "bearer realm="+realm) } else { - response.Header().Set("WWW-Authenticate", "basic realm="+rh.c.Config.HTTP.Realm) + realm := rh.c.Config.GetRealm() + response.Header().Set("WWW-Authenticate", "basic realm="+realm) } } } @@ -502,7 +514,9 @@ type ExtensionList struct { // @Failure 500 {string} string "internal server error" // @Router /v2/{name}/manifests/{reference} [get]. func (rh *RouteHandler) GetManifest(response http.ResponseWriter, request *http.Request) { - if rh.c.Config.IsBasicAuthnEnabled() { + // Get auth config safely + authConfig := rh.c.Config.CopyAuthConfig() + if authConfig.IsBasicAuthnEnabled() { response.Header().Set("Access-Control-Allow-Credentials", "true") } @@ -694,7 +708,9 @@ func (rh *RouteHandler) UpdateManifest(response http.ResponseWriter, request *ht } mediaType := request.Header.Get("Content-Type") - if !storageCommon.IsSupportedMediaType(rh.c.Config.HTTP.Compat, mediaType) { + compatConfig := rh.c.Config.GetCompat() + + if !storageCommon.IsSupportedMediaType(compatConfig, mediaType) { err := apiErr.NewError(apiErr.MANIFEST_INVALID).AddDetail(map[string]string{"mediaType": mediaType}) zcommon.WriteJSON(response, http.StatusUnsupportedMediaType, apiErr.NewErrorList(err)) @@ -949,7 +965,9 @@ func (rh *RouteHandler) CheckBlob(response http.ResponseWriter, request *http.Re } userCanMount := true - if rh.c.Config.IsAuthzEnabled() { + accessControlConfig := rh.c.Config.CopyAccessControlConfig() + + if accessControlConfig.IsAuthzEnabled() { userCanMount, err = canMount(userAc, imgStore, digest) if err != nil { rh.c.Log.Error().Err(err).Msg("unexpected error") @@ -1265,7 +1283,9 @@ func (rh *RouteHandler) CreateBlobUpload(response http.ResponseWriter, request * } userCanMount := true - if rh.c.Config.IsAuthzEnabled() { + accessControlConfig := rh.c.Config.CopyAccessControlConfig() + + if accessControlConfig.IsAuthzEnabled() { userCanMount, err = canMount(userAc, imgStore, mountDigest) if err != nil { rh.c.Log.Error().Err(err).Msg("unexpected error") @@ -2323,7 +2343,8 @@ func getBlobUploadLocation(url *url.URL, name string, digest godigest.Digest) st } func isSyncOnDemandEnabled(ctlr Controller) bool { - if ctlr.Config.IsSyncEnabled() && + extensionsConfig := ctlr.Config.CopyExtensionsConfig() + if extensionsConfig.IsSyncEnabled() && fmt.Sprintf("%v", ctlr.SyncOnDemand) != fmt.Sprintf("%v", nil) { return true } diff --git a/pkg/cli/server/config_reloader.go b/pkg/cli/server/config_reloader.go index 3ddb9aed..9121030d 100644 --- a/pkg/cli/server/config_reloader.go +++ b/pkg/cli/server/config_reloader.go @@ -85,9 +85,10 @@ func (hr *HotReloader) Start() { continue } - if hr.ctlr.Config.HTTP.Auth != nil && hr.ctlr.Config.HTTP.Auth.LDAP != nil && - hr.ctlr.Config.HTTP.Auth.LDAP.CredentialsFile != newConfig.HTTP.Auth.LDAP.CredentialsFile { - err = hr.watcher.Remove(hr.ctlr.Config.HTTP.Auth.LDAP.CredentialsFile) + authConfig := hr.ctlr.Config.CopyAuthConfig() + if authConfig.IsLdapAuthEnabled() && + authConfig.LDAP.CredentialsFile != newConfig.HTTP.Auth.LDAP.CredentialsFile { + err = hr.watcher.Remove(authConfig.LDAP.CredentialsFile) if err != nil && !errors.Is(err, fsnotify.ErrNonExistentWatch) { hr.logger.Error().Err(err).Msg("failed to remove old watch for the credentials file") } diff --git a/pkg/cli/server/extensions_test.go b/pkg/cli/server/extensions_test.go index 385a9107..ecb0798d 100644 --- a/pkg/cli/server/extensions_test.go +++ b/pkg/cli/server/extensions_test.go @@ -883,7 +883,7 @@ func TestServeScrubExtension(t *testing.T) { // Even if in config we specified scrub interval=1h, the minimum interval is 2h dataStr := string(data) - So(dataStr, ShouldContainSubstring, "\"Scrub\":{\"Enable\":true,\"Interval\":3600000000000}") + So(dataStr, ShouldContainSubstring, "\"Scrub\":{\"Enable\":true,\"Interval\":7200000000000}") So(dataStr, ShouldContainSubstring, "scrub interval set to too-short interval < 2h, changing scrub duration to 2 hours and continuing.") }) diff --git a/pkg/cli/server/root.go b/pkg/cli/server/root.go index 23263b33..a9b1316d 100644 --- a/pkg/cli/server/root.go +++ b/pkg/cli/server/root.go @@ -202,8 +202,9 @@ func NewServerRootCmd() *cobra.Command { logger := zlog.NewLogger("info", "") if showVersion { - logger.Info().Str("distribution-spec", distspec.Version).Str("commit", config.Commit). - Str("binary-type", config.BinaryType).Str("go version", config.GoVersion).Msg("version") + commit, binaryType, goVersion, _ := conf.GetVersionInfo() + logger.Info().Str("distribution-spec", distspec.Version).Str("commit", commit). + Str("binary-type", binaryType).Str("go version", goVersion).Msg("version") } else { _ = cmd.Usage() cmd.SilenceErrors = false @@ -228,19 +229,20 @@ func NewServerRootCmd() *cobra.Command { func validateStorageConfig(cfg *config.Config, logger zlog.Logger) error { expConfigMap := make(map[string]config.StorageConfig, 0) - defaultRootDir := cfg.Storage.RootDirectory + storageConfig := cfg.CopyStorageConfig() + defaultRootDir := storageConfig.RootDirectory - for _, storageConfig := range cfg.Storage.SubPaths { - if strings.EqualFold(defaultRootDir, storageConfig.RootDirectory) { + for _, subStorageConfig := range storageConfig.SubPaths { + if strings.EqualFold(defaultRootDir, subStorageConfig.RootDirectory) { msg := "invalid storage config, storage subpaths cannot use default storage root directory" logger.Error().Err(zerr.ErrBadConfig).Msg(msg) return fmt.Errorf("%w: %s", zerr.ErrBadConfig, msg) } - expConfig, ok := expConfigMap[storageConfig.RootDirectory] + expConfig, ok := expConfigMap[subStorageConfig.RootDirectory] if ok { - equal := expConfig.ParamsEqual(storageConfig) + equal := expConfig.ParamsEqual(subStorageConfig) if !equal { msg := "invalid storage config, storage config with same root directory should have same parameters" logger.Error().Err(zerr.ErrBadConfig).Msg(msg) @@ -248,7 +250,7 @@ func validateStorageConfig(cfg *config.Config, logger zlog.Logger) error { return fmt.Errorf("%w: %s", zerr.ErrBadConfig, msg) } } else { - expConfigMap[storageConfig.RootDirectory] = storageConfig + expConfigMap[subStorageConfig.RootDirectory] = subStorageConfig } } @@ -257,20 +259,21 @@ func validateStorageConfig(cfg *config.Config, logger zlog.Logger) error { func validateCacheConfig(cfg *config.Config, logger zlog.Logger) error { // global + storageConfig := cfg.CopyStorageConfig() // dedupe true, remote storage, remoteCache true, but no cacheDriver (remote) //nolint: lll - if cfg.Storage.Dedupe && cfg.Storage.StorageDriver != nil && cfg.Storage.RemoteCache && cfg.Storage.CacheDriver == nil { + if storageConfig.Dedupe && storageConfig.StorageDriver != nil && storageConfig.RemoteCache && storageConfig.CacheDriver == nil { msg := "invalid database config, dedupe set to true with remote storage and database, but no remote database configured" logger.Error().Err(zerr.ErrBadConfig).Msg(msg) return fmt.Errorf("%w: %s", zerr.ErrBadConfig, msg) } - if cfg.Storage.CacheDriver != nil && cfg.Storage.RemoteCache { + if storageConfig.CacheDriver != nil && storageConfig.RemoteCache { // local storage with remote database // redis is supported with both local and S3 storage, while dynamodb is only supported with S3 // redis is only supported with local storage in a non-clustering scenario with a single zot instance, - if cfg.Storage.StorageDriver == nil && cfg.Storage.CacheDriver["name"] != storageConstants.RedisDriverName { + if storageConfig.StorageDriver == nil && storageConfig.CacheDriver["name"] != storageConstants.RedisDriverName { msg := "invalid database config, cannot have local storage driver with remote database!" logger.Error().Err(zerr.ErrBadConfig).Msg(msg) @@ -278,23 +281,23 @@ func validateCacheConfig(cfg *config.Config, logger zlog.Logger) error { } // unsupported database driver - if cfg.Storage.CacheDriver["name"] != storageConstants.DynamoDBDriverName && - cfg.Storage.CacheDriver["name"] != storageConstants.RedisDriverName { + if storageConfig.CacheDriver["name"] != storageConstants.DynamoDBDriverName && + storageConfig.CacheDriver["name"] != storageConstants.RedisDriverName { msg := "invalid database config, unsupported database driver" - logger.Error().Err(zerr.ErrBadConfig).Interface("cacheDriver", cfg.Storage.CacheDriver["name"]).Msg(msg) + logger.Error().Err(zerr.ErrBadConfig).Interface("cacheDriver", storageConfig.CacheDriver["name"]).Msg(msg) return fmt.Errorf("%w: %s", zerr.ErrBadConfig, msg) } } - if !cfg.Storage.RemoteCache && cfg.Storage.CacheDriver != nil { - logger.Warn().Err(zerr.ErrBadConfig).Str("directory", cfg.Storage.RootDirectory). + if !storageConfig.RemoteCache && storageConfig.CacheDriver != nil { + logger.Warn().Err(zerr.ErrBadConfig).Str("directory", storageConfig.RootDirectory). Msg("invalid database config, remoteCache set to false but cacheDriver config (remote database)" + " provided for directory will ignore and use local database") } // subpaths - for _, subPath := range cfg.Storage.SubPaths { + for _, subPath := range storageConfig.SubPaths { // dedupe true, remote storage, remoteCache true, but no cacheDriver (remote) //nolint: lll if subPath.Dedupe && subPath.StorageDriver != nil && subPath.RemoteCache && subPath.CacheDriver == nil { @@ -316,14 +319,14 @@ func validateCacheConfig(cfg *config.Config, logger zlog.Logger) error { // unsupported cache driver if subPath.CacheDriver["name"] != storageConstants.DynamoDBDriverName { msg := "invalid database config, unsupported database driver" - logger.Error().Err(zerr.ErrBadConfig).Interface("cacheDriver", cfg.Storage.CacheDriver["name"]).Msg(msg) + logger.Error().Err(zerr.ErrBadConfig).Interface("cacheDriver", subPath.CacheDriver["name"]).Msg(msg) return fmt.Errorf("%w: %s", zerr.ErrBadConfig, msg) } } if !subPath.RemoteCache && subPath.CacheDriver != nil { - logger.Warn().Err(zerr.ErrBadConfig).Str("directory", cfg.Storage.RootDirectory). + logger.Warn().Err(zerr.ErrBadConfig).Str("directory", subPath.RootDirectory). Msg("invalid database config, remoteCache set to false but cacheDriver config (remote database)" + "provided for directory, will ignore and use local database") } @@ -335,11 +338,12 @@ func validateCacheConfig(cfg *config.Config, logger zlog.Logger) error { func validateRemoteSessionStoreConfig(cfg *config.Config, logger zlog.Logger) error { // it is okay for the session driver config to be nil // this is backwards compatible for older configs - if cfg.HTTP.Auth.SessionDriver == nil { + authConfig := cfg.CopyAuthConfig() + if authConfig == nil || authConfig.SessionDriver == nil { return nil } - sessionDriverName, ok := cfg.HTTP.Auth.SessionDriver["name"] + sessionDriverName, ok := authConfig.SessionDriver["name"] if !ok { msg := "must provide session driver name!" logger.Error().Err(zerr.ErrBadConfig).Msg(msg) @@ -373,7 +377,7 @@ func validateRemoteSessionStoreConfig(cfg *config.Config, logger zlog.Logger) er // as redis session store does not support these yet. if sessionDriverName == storageConstants.RedisDriverName { - if cfg.HTTP.Auth.SessionKeysFile != "" { + if authConfig.SessionKeysFile != "" { msg := "session keys not supported when redis session driver is used!" logger.Error().Err(zerr.ErrBadConfig).Msg(msg) @@ -385,19 +389,20 @@ func validateRemoteSessionStoreConfig(cfg *config.Config, logger zlog.Logger) er } func validateExtensionsConfig(cfg *config.Config, logger zlog.Logger) error { - if cfg.Extensions != nil && cfg.Extensions.Mgmt != nil { + extensionsConfig := cfg.CopyExtensionsConfig() + if extensionsConfig != nil && extensionsConfig.Mgmt != nil { logger.Warn().Msg("mgmt extensions configuration option has been made redundant and will be ignored.") } - if cfg.Extensions != nil && cfg.Extensions.APIKey != nil { + if extensionsConfig != nil && extensionsConfig.APIKey != nil { logger.Warn().Msg("apikey extension configuration will be ignored as API keys " + "are now configurable in the HTTP settings.") } - if cfg.Extensions != nil && cfg.Extensions.UI != nil && cfg.Extensions.UI.Enable != nil && *cfg.Extensions.UI.Enable { + if extensionsConfig.IsUIEnabled() { // it would make sense to also check for mgmt and user prefs to be enabled, // but those are both enabled by having the search and ui extensions enabled - if cfg.Extensions.Search == nil || !*cfg.Extensions.Search.Enable { + if !extensionsConfig.IsSearchEnabled() { msg := "failed to enable ui, search extension must be enabled" logger.Error().Err(zerr.ErrBadConfig).Msg(msg) @@ -406,18 +411,17 @@ func validateExtensionsConfig(cfg *config.Config, logger zlog.Logger) error { } //nolint:lll - if cfg.Storage.StorageDriver != nil && cfg.Extensions != nil && cfg.Extensions.Search != nil && - cfg.Extensions.Search.Enable != nil && *cfg.Extensions.Search.Enable && cfg.Extensions.Search.CVE != nil { + storageConfig := cfg.CopyStorageConfig() + if storageConfig.StorageDriver != nil && extensionsConfig.IsCveScanningEnabled() { msg := "failed to enable cve scanning due to incompatibility with remote storage, please disable cve" logger.Error().Err(zerr.ErrBadConfig).Msg(msg) return fmt.Errorf("%w: %s", zerr.ErrBadConfig, msg) } - for _, subPath := range cfg.Storage.SubPaths { + for _, subPath := range storageConfig.SubPaths { //nolint:lll - if subPath.StorageDriver != nil && cfg.Extensions != nil && cfg.Extensions.Search != nil && - cfg.Extensions.Search.Enable != nil && *cfg.Extensions.Search.Enable && cfg.Extensions.Search.CVE != nil { + if subPath.StorageDriver != nil && extensionsConfig.IsCveScanningEnabled() { msg := "failed to enable cve scanning due to incompatibility with remote storage, please disable cve" logger.Error().Err(zerr.ErrBadConfig).Msg(msg) @@ -466,24 +470,27 @@ func validateConfiguration(config *config.Config, logger zlog.Logger) error { } // check authorization config, it should have basic auth enabled or ldap, api keys or OpenID - if config.HTTP.AccessControl != nil { + accessControlConfig := config.CopyAccessControlConfig() + if accessControlConfig != nil { // checking for anonymous policy only authorization config: no users, no policies but anonymous policy if err := validateAuthzPolicies(config, logger); err != nil { return err } } - if len(config.Storage.StorageDriver) != 0 { + storageConfig := config.CopyStorageConfig() + if len(storageConfig.StorageDriver) != 0 { // enforce s3 driver in case of using storage driver - if config.Storage.StorageDriver["name"] != storageConstants.S3StorageDriverName { + if storageConfig.StorageDriver["name"] != storageConstants.S3StorageDriverName { msg := "unsupported storage driver" - logger.Error().Err(zerr.ErrBadConfig).Interface("cacheDriver", config.Storage.StorageDriver["name"]).Msg(msg) + logger.Error().Err(zerr.ErrBadConfig).Interface("cacheDriver", storageConfig.StorageDriver["name"]).Msg(msg) return fmt.Errorf("%w: %s", zerr.ErrBadConfig, msg) } // enforce tmpDir in case sync + s3 - if config.Extensions != nil && config.Extensions.Sync != nil && config.Extensions.Sync.DownloadDir == "" { + extensionsConfig := config.CopyExtensionsConfig() + if extensionsConfig.IsSyncEnabled() && extensionsConfig.Sync.DownloadDir == "" { msg := "using both sync and remote storage features needs config.Extensions.Sync.DownloadDir to be specified" logger.Error().Err(zerr.ErrBadConfig).Msg(msg) @@ -492,22 +499,23 @@ func validateConfiguration(config *config.Config, logger zlog.Logger) error { } // enforce s3 driver on subpaths in case of using storage driver - if config.Storage.SubPaths != nil { - if len(config.Storage.SubPaths) > 0 { - subPaths := config.Storage.SubPaths + if storageConfig.SubPaths != nil { + if len(storageConfig.SubPaths) > 0 { + subPaths := storageConfig.SubPaths - for route, storageConfig := range subPaths { - if len(storageConfig.StorageDriver) != 0 { - if storageConfig.StorageDriver["name"] != storageConstants.S3StorageDriverName { + for route, subStorageConfig := range subPaths { + if len(subStorageConfig.StorageDriver) != 0 { + if subStorageConfig.StorageDriver["name"] != storageConstants.S3StorageDriverName { msg := "unsupported storage driver" logger.Error().Err(zerr.ErrBadConfig).Str("subpath", route).Interface("storageDriver", - storageConfig.StorageDriver["name"]).Msg(msg) + subStorageConfig.StorageDriver["name"]).Msg(msg) return fmt.Errorf("%w: %s", zerr.ErrBadConfig, msg) } // enforce tmpDir in case sync + s3 - if config.Extensions != nil && config.Extensions.Sync != nil && config.Extensions.Sync.DownloadDir == "" { + extensionsConfig := config.CopyExtensionsConfig() + if extensionsConfig.IsSyncEnabled() && extensionsConfig.Sync.DownloadDir == "" { msg := "using both sync and remote storage features needs config.Extensions.Sync.DownloadDir to be specified" logger.Error().Err(zerr.ErrBadConfig).Msg(msg) @@ -519,8 +527,8 @@ func validateConfiguration(config *config.Config, logger zlog.Logger) error { } // check glob patterns in authz config are compilable - if config.HTTP.AccessControl != nil { - for pattern := range config.HTTP.AccessControl.Repositories { + if accessControlConfig != nil { + for pattern := range accessControlConfig.Repositories { ok := glob.ValidatePattern(pattern) if !ok { msg := "failed to compile authorization pattern" @@ -540,8 +548,10 @@ func validateConfiguration(config *config.Config, logger zlog.Logger) error { } func validateOpenIDConfig(cfg *config.Config, logger zlog.Logger) error { - if cfg.HTTP.Auth != nil && cfg.HTTP.Auth.OpenID != nil { - for provider, providerConfig := range cfg.HTTP.Auth.OpenID.Providers { + authConfig := cfg.CopyAuthConfig() + // can't check with IsOpenIDAuthEnabled(), because it can't test invalid providers + if authConfig != nil && authConfig.OpenID != nil && len(authConfig.OpenID.Providers) > 0 { + for provider, providerConfig := range authConfig.OpenID.Providers { //nolint: gocritic if config.IsOpenIDSupported(provider) { if providerConfig.ClientID == "" || providerConfig.Issuer == "" || @@ -571,8 +581,12 @@ func validateOpenIDConfig(cfg *config.Config, logger zlog.Logger) error { } func validateAuthzPolicies(config *config.Config, logger zlog.Logger) error { - if (config.HTTP.Auth == nil || (config.HTTP.Auth.HTPasswd.Path == "" && config.HTTP.Auth.LDAP == nil && - config.HTTP.Auth.OpenID == nil)) && !authzContainsOnlyAnonymousPolicy(config) { + authConfig := config.CopyAuthConfig() + accessControlConfig := config.CopyAccessControlConfig() + + logger.Info().Msg("checking if anonymous authorization is the only type of authorization policy configured") + + if !authConfig.IsBasicAuthnEnabled() && !accessControlConfig.ContainsOnlyAnonymousPolicy() { msg := "access control config requires one of httpasswd, ldap or openid authentication " + "or using only 'anonymousPolicy' policies" logger.Error().Err(zerr.ErrBadConfig).Msg(msg) @@ -696,6 +710,15 @@ func applyDefaultValues(config *config.Config, viperInstance *viper.Viper, logge if config.Extensions.Scrub.Interval == 0 { config.Extensions.Scrub.Interval = 24 * time.Hour //nolint:mnd } + + // Validate minimum scrub interval + minScrubInterval, _ := time.ParseDuration("2h") + if config.Extensions.Scrub.Interval < minScrubInterval { + config.Extensions.Scrub.Interval = minScrubInterval + + logger.Warn().Msg("scrub interval set to too-short interval < 2h, " + + "changing scrub duration to 2 hours and continuing.") + } } if config.Extensions.UI != nil { @@ -1076,52 +1099,11 @@ func readSecretFile(path string, v any, checkUnsetFields bool) error { //nolint: return nil } -func authzContainsOnlyAnonymousPolicy(cfg *config.Config) bool { - logger := zlog.NewLogger("info", "") - - adminPolicy := cfg.HTTP.AccessControl.AdminPolicy - anonymousPolicyPresent := false - - logger.Info().Msg("checking if anonymous authorization is the only type of authorization policy configured") - - if len(adminPolicy.Actions)+len(adminPolicy.Users) > 0 { - logger.Info().Msg("admin policy detected, anonymous authorization is not the only authorization policy configured") - - return false - } - - for _, repository := range cfg.HTTP.AccessControl.Repositories { - if len(repository.DefaultPolicy) > 0 { - logger.Info().Interface("repository", repository). - Msg("default policy detected, anonymous authorization is not the only authorization policy configured") - - return false - } - - if len(repository.AnonymousPolicy) > 0 { - logger.Info().Msg("anonymous authorization detected") - - anonymousPolicyPresent = true - } - - for _, policy := range repository.Policies { - if len(policy.Actions)+len(policy.Users) > 0 { - logger.Info().Interface("repository", repository). - Msg("repository with non-empty policy detected, " + - "anonymous authorization is not the only authorization policy configured") - - return false - } - } - } - - return anonymousPolicyPresent -} - func validateLDAP(config *config.Config, logger zlog.Logger) error { // LDAP mandatory configuration - if config.HTTP.Auth != nil && config.HTTP.Auth.LDAP != nil { - ldap := config.HTTP.Auth.LDAP + authConfig := config.CopyAuthConfig() + if authConfig.IsLdapAuthEnabled() { + ldap := authConfig.LDAP if ldap.UserAttribute == "" { msg := "invalid LDAP configuration, missing mandatory key: userAttribute" logger.Error().Str("userAttribute", ldap.UserAttribute).Msg(msg) @@ -1148,12 +1130,13 @@ func validateLDAP(config *config.Config, logger zlog.Logger) error { } func validateHTTP(config *config.Config, logger zlog.Logger) error { - if config.HTTP.Port != "" { - port, err := strconv.ParseInt(config.HTTP.Port, 10, 64) - if err != nil || (port < 0 || port > 65535) { - logger.Error().Str("port", config.HTTP.Port).Msg("invalid port") + port := config.GetHTTPPort() + if port != "" { + portInt, err := strconv.ParseInt(port, 10, 64) + if err != nil || (portInt < 0 || portInt > 65535) { + logger.Error().Str("port", port).Msg("invalid port") - return fmt.Errorf("%w: invalid port %s", zerr.ErrBadConfig, config.HTTP.Port) + return fmt.Errorf("%w: invalid port %s", zerr.ErrBadConfig, port) } } @@ -1162,40 +1145,41 @@ func validateHTTP(config *config.Config, logger zlog.Logger) error { func validateGC(config *config.Config, logger zlog.Logger) error { // enforce GC params - if config.Storage.GCDelay < 0 { - logger.Error().Err(zerr.ErrBadConfig).Dur("delay", config.Storage.GCDelay). + storageConfig := config.CopyStorageConfig() + if storageConfig.GCDelay < 0 { + logger.Error().Err(zerr.ErrBadConfig).Dur("delay", storageConfig.GCDelay). Msg("invalid garbage-collect delay specified") return fmt.Errorf("%w: invalid garbage-collect delay specified %s", - zerr.ErrBadConfig, config.Storage.GCDelay) + zerr.ErrBadConfig, storageConfig.GCDelay) } - if config.Storage.GCInterval < 0 { - logger.Error().Err(zerr.ErrBadConfig).Dur("interval", config.Storage.GCInterval). + if storageConfig.GCInterval < 0 { + logger.Error().Err(zerr.ErrBadConfig).Dur("interval", storageConfig.GCInterval). Msg("invalid garbage-collect interval specified") return fmt.Errorf("%w: invalid garbage-collect interval specified %s", - zerr.ErrBadConfig, config.Storage.GCInterval) + zerr.ErrBadConfig, storageConfig.GCInterval) } - if !config.Storage.GC { - if config.Storage.GCDelay != 0 { + if !storageConfig.GC { + if storageConfig.GCDelay != 0 { logger.Warn().Err(zerr.ErrBadConfig). Msg("garbage-collect delay specified without enabling garbage-collect, will be ignored") } - if config.Storage.GCInterval != 0 { + if storageConfig.GCInterval != 0 { logger.Warn().Err(zerr.ErrBadConfig). Msg("periodic garbage-collect interval specified without enabling garbage-collect, will be ignored") } } - if err := validateGCRules(config.Storage.Retention, logger); err != nil { + if err := validateGCRules(storageConfig.Retention, logger); err != nil { return err } // subpaths - for name, subPath := range config.Storage.SubPaths { + for name, subPath := range storageConfig.SubPaths { if subPath.GC && subPath.GCDelay <= 0 { logger.Error().Err(zerr.ErrBadConfig). Str("subPath", name). @@ -1245,13 +1229,15 @@ func validateGCRules(retention config.ImageRetention, logger zlog.Logger) error func validateSync(config *config.Config, logger zlog.Logger) error { // check glob patterns in sync config are compilable - if config.Extensions != nil && config.Extensions.Sync != nil { - for regID, regCfg := range config.Extensions.Sync.Registries { + extensionsConfig := config.CopyExtensionsConfig() + // can't check with IsSyncEnabled(), because it can't test invalid sync configs + if extensionsConfig != nil && extensionsConfig.Sync != nil && len(extensionsConfig.Sync.Registries) > 0 { + for regID, regCfg := range extensionsConfig.Sync.Registries { // check retry options are configured for sync if regCfg.MaxRetries != nil && regCfg.RetryDelay == nil { msg := "retryDelay is required when using maxRetries" logger.Error().Err(zerr.ErrBadConfig).Int("id", regID).Interface("extensions.sync.registries[id]", - config.Extensions.Sync.Registries[regID]).Msg(msg) + extensionsConfig.Sync.Registries[regID]).Msg(msg) return fmt.Errorf("%w: %s", zerr.ErrBadConfig, msg) } @@ -1260,7 +1246,7 @@ func validateSync(config *config.Config, logger zlog.Logger) error { if regCfg.PreserveDigest && !config.IsCompatEnabled() { msg := "can not use PreserveDigest option without enabling http.Compat" logger.Error().Err(zerr.ErrBadConfig).Int("id", regID).Interface("extensions.sync.registries[id]", - config.Extensions.Sync.Registries[regID]).Msg(msg) + extensionsConfig.Sync.Registries[regID]).Msg(msg) return fmt.Errorf("%w: %s", zerr.ErrBadConfig, msg) } @@ -1314,8 +1300,9 @@ func validateSync(config *config.Config, logger zlog.Logger) error { } func validateClusterConfig(config *config.Config, logger zlog.Logger) error { - if config.Cluster != nil { - if len(config.Cluster.Members) == 0 { + clusterConfig := config.CopyClusterConfig() + if clusterConfig != nil { + if len(clusterConfig.Members) == 0 { msg := "cannot have 0 members in a scale out cluster" logger.Error().Err(zerr.ErrBadConfig).Msg(msg) @@ -1325,10 +1312,10 @@ func validateClusterConfig(config *config.Config, logger zlog.Logger) error { // the allowed length is 16 as the siphash requires a 128 bit key. // that translates to 16 characters * 8 bits each. allowedHashKeyLength := 16 - if len(config.Cluster.HashKey) != allowedHashKeyLength { + if len(clusterConfig.HashKey) != allowedHashKeyLength { msg := fmt.Sprintf("hashKey for scale out cluster must have %d characters", allowedHashKeyLength) logger.Error().Err(zerr.ErrBadConfig). - Str("hashkey", config.Cluster.HashKey). + Str("hashkey", clusterConfig.HashKey). Msg(msg) return fmt.Errorf("%w: %s", zerr.ErrBadConfig, msg) diff --git a/pkg/common/http_server.go b/pkg/common/http_server.go index 69b57fca..e29edf0d 100644 --- a/pkg/common/http_server.go +++ b/pkg/common/http_server.go @@ -38,7 +38,9 @@ func ACHeadersMiddleware(config *config.Config, allowedMethods ...string) mux.Mi resp.Header().Set("Access-Control-Allow-Methods", allowedMethodsValue) resp.Header().Set("Access-Control-Allow-Headers", "Authorization,content-type,"+constants.SessionClientHeaderName) - if config.IsBasicAuthnEnabled() { + // Get auth config safely + authConfig := config.CopyAuthConfig() + if authConfig.IsBasicAuthnEnabled() { resp.Header().Set("Access-Control-Allow-Credentials", "true") } @@ -73,23 +75,28 @@ func AddCORSHeaders(allowOrigin string, response http.ResponseWriter) { func AuthzOnlyAdminsMiddleware(conf *config.Config) mux.MiddlewareFunc { return func(next http.Handler) http.Handler { return http.HandlerFunc(func(response http.ResponseWriter, request *http.Request) { - if !conf.IsBasicAuthnEnabled() { + // Get auth config safely + authConfig := conf.CopyAuthConfig() + if !authConfig.IsBasicAuthnEnabled() { next.ServeHTTP(response, request) return } + realm := conf.GetRealm() + failDelay := authConfig.GetFailDelay() + // get userAccessControl built in previous authn/authz middlewares userAc, err := reqCtx.UserAcFromContext(request.Context()) if err != nil { // should not happen as this has been previously checked for errors - AuthzFail(response, request, userAc.GetUsername(), conf.HTTP.Realm, conf.HTTP.Auth.FailDelay) + AuthzFail(response, request, userAc.GetUsername(), realm, failDelay) return } // reject non-admin access if authentication is enabled if userAc != nil && !userAc.IsAdmin() { - AuthzFail(response, request, userAc.GetUsername(), conf.HTTP.Realm, conf.HTTP.Auth.FailDelay) + AuthzFail(response, request, userAc.GetUsername(), realm, failDelay) return } diff --git a/pkg/extensions/config/config.go b/pkg/extensions/config/config.go index 8f139433..20bf8b7f 100644 --- a/pkg/extensions/config/config.go +++ b/pkg/extensions/config/config.go @@ -46,12 +46,11 @@ type LintConfig struct { type SearchConfig struct { BaseConfig `mapstructure:",squash"` - // CVE search - CVE *CVEConfig + CVE *CVEConfig } type CVEConfig struct { - UpdateInterval time.Duration // should be 2 hours or more, if not specified default be kept as 24 hours + UpdateInterval time.Duration // should be 2 hours or more, if not specified default be kept as 2 hours Trivy *TrivyConfig } @@ -77,3 +76,166 @@ type ScrubConfig struct { type UIConfig struct { BaseConfig `mapstructure:",squash"` } + +// isSearchEnabledInternal checks if search is enabled (internal use only). +func (e *ExtensionConfig) isSearchEnabledInternal() bool { + return e != nil && e.Search != nil && e.Search.Enable != nil && *e.Search.Enable +} + +// isUIEnabledInternal checks if UI is enabled (internal use only). +func (e *ExtensionConfig) isUIEnabledInternal() bool { + return e != nil && e.UI != nil && e.UI.Enable != nil && *e.UI.Enable +} + +// IsCveScanningEnabled checks if CVE scanning is enabled in this extensions config. +func (e *ExtensionConfig) IsCveScanningEnabled() bool { + if e == nil { + return false + } + + return e.Search != nil && e.Search.Enable != nil && *e.Search.Enable && + e.Search.CVE != nil && e.Search.CVE.Trivy != nil +} + +// IsEventRecorderEnabled checks if event recording is enabled in this extensions config. +func (e *ExtensionConfig) IsEventRecorderEnabled() bool { + if e == nil { + return false + } + + return e.Events != nil && e.Events.Enable != nil && *e.Events.Enable +} + +// IsSearchEnabled checks if search is enabled in this extensions config. +func (e *ExtensionConfig) IsSearchEnabled() bool { + return e.isSearchEnabledInternal() +} + +// IsSyncEnabled checks if sync is enabled in this extensions config. +func (e *ExtensionConfig) IsSyncEnabled() bool { + if e == nil { + return false + } + + // Sync is enabled if either: + // 1. Explicitly enabled (Enable == true), OR + // 2. There are registries configured (enabled by default when registries exist) + // This matches the behavior in root.go where Sync.Enable defaults to true when registries are present + return e.Sync != nil && ((e.Sync.Enable != nil && *e.Sync.Enable) || len(e.Sync.Registries) > 0) +} + +// IsScrubEnabled checks if scrub is enabled in this extensions config. +func (e *ExtensionConfig) IsScrubEnabled() bool { + if e == nil { + return false + } + + return e.Scrub != nil && e.Scrub.Enable != nil && *e.Scrub.Enable +} + +// IsMetricsEnabled checks if metrics are enabled in this extensions config. +func (e *ExtensionConfig) IsMetricsEnabled() bool { + if e == nil { + return false + } + + return e.Metrics != nil && e.Metrics.Enable != nil && *e.Metrics.Enable +} + +// IsCosignEnabled checks if Cosign is enabled in this extensions config. +func (e *ExtensionConfig) IsCosignEnabled() bool { + if e == nil { + return false + } + + return e.Trust != nil && e.Trust.Enable != nil && *e.Trust.Enable && e.Trust.Cosign +} + +// IsNotationEnabled checks if Notation is enabled in this extensions config. +func (e *ExtensionConfig) IsNotationEnabled() bool { + if e == nil { + return false + } + + return e.Trust != nil && e.Trust.Enable != nil && *e.Trust.Enable && e.Trust.Notation +} + +// IsImageTrustEnabled checks if image trust is enabled in this extensions config. +func (e *ExtensionConfig) IsImageTrustEnabled() bool { + if e == nil { + return false + } + + return e.Trust != nil && e.Trust.Enable != nil && *e.Trust.Enable +} + +// IsUIEnabled checks if UI is enabled in this extensions config. +func (e *ExtensionConfig) IsUIEnabled() bool { + return e.isUIEnabledInternal() +} + +// AreUserPrefsEnabled checks if user preferences are enabled in this extensions config. +func (e *ExtensionConfig) AreUserPrefsEnabled() bool { + if e == nil { + return false + } + + return e.isSearchEnabledInternal() && e.isUIEnabledInternal() +} + +// GetSearchCVEConfig returns the search CVE config. +func (e *ExtensionConfig) GetSearchCVEConfig() *CVEConfig { + if e == nil { + return nil + } + + if e.Search != nil { + return e.Search.CVE + } + + return nil +} + +// GetScrubInterval returns the scrub interval. +func (e *ExtensionConfig) GetScrubInterval() time.Duration { + if e == nil { + return 0 + } + + if e.Scrub != nil { + return e.Scrub.Interval + } + + return 0 +} + +// GetSyncConfig returns the sync config. +func (e *ExtensionConfig) GetSyncConfig() *sync.Config { + if e == nil { + return nil + } + + return e.Sync +} + +// GetMetricsPrometheusConfig returns the metrics prometheus config. +func (e *ExtensionConfig) GetMetricsPrometheusConfig() *PrometheusConfig { + if e == nil { + return nil + } + + if e.Metrics != nil { + return e.Metrics.Prometheus + } + + return nil +} + +// GetEventsConfig returns the events config. +func (e *ExtensionConfig) GetEventsConfig() *events.Config { + if e == nil { + return nil + } + + return e.Events +} diff --git a/pkg/extensions/config/config_test.go b/pkg/extensions/config/config_test.go new file mode 100644 index 00000000..a43069dd --- /dev/null +++ b/pkg/extensions/config/config_test.go @@ -0,0 +1,564 @@ +package config_test + +import ( + "errors" + "testing" + + . "github.com/smartystreets/goconvey/convey" + + "zotregistry.dev/zot/v2/pkg/extensions/config" + "zotregistry.dev/zot/v2/pkg/extensions/config/events" + "zotregistry.dev/zot/v2/pkg/extensions/config/sync" +) + +var ( + errIsSearchEnabledExpectedTrue = errors.New("expected IsSearchEnabled to return true, got false") + errIsUIEnabledExpectedTrue = errors.New("expected IsUIEnabled to return true, got false") + errAreUserPrefsEnabledExpectedTrue = errors.New("expected AreUserPrefsEnabled to return true, got false") + errPanicRecovered = errors.New("panic recovered") +) + +// Config builder functions for different extension types. +func buildSearchConfig(enabled bool) *config.ExtensionConfig { + ext := &config.ExtensionConfig{} + ext.Search = &config.SearchConfig{ + BaseConfig: config.BaseConfig{ + Enable: &enabled, + }, + } + + return ext +} + +func buildSearchConfigWithCVE(enabled bool) *config.ExtensionConfig { + ext := &config.ExtensionConfig{} + ext.Search = &config.SearchConfig{ + BaseConfig: config.BaseConfig{ + Enable: &enabled, + }, + CVE: &config.CVEConfig{ + Trivy: &config.TrivyConfig{}, + }, + } + + return ext +} + +func buildEventsConfig(enabled bool) *config.ExtensionConfig { + ext := &config.ExtensionConfig{} + ext.Events = &events.Config{ + Enable: &enabled, + } + + return ext +} + +func buildSyncConfig(enabled bool) *config.ExtensionConfig { + ext := &config.ExtensionConfig{} + ext.Sync = &sync.Config{ + Enable: &enabled, + } + + return ext +} + +func buildScrubConfig(enabled bool) *config.ExtensionConfig { + ext := &config.ExtensionConfig{} + ext.Scrub = &config.ScrubConfig{ + BaseConfig: config.BaseConfig{ + Enable: &enabled, + }, + } + + return ext +} + +func buildMetricsConfig(enabled bool) *config.ExtensionConfig { + ext := &config.ExtensionConfig{} + ext.Metrics = &config.MetricsConfig{ + BaseConfig: config.BaseConfig{ + Enable: &enabled, + }, + Prometheus: &config.PrometheusConfig{ + Path: "/metrics", + }, + } + + return ext +} + +func buildTrustConfig(enabled bool) *config.ExtensionConfig { + ext := &config.ExtensionConfig{} + ext.Trust = &config.ImageTrustConfig{ + BaseConfig: config.BaseConfig{ + Enable: &enabled, + }, + } + + return ext +} + +func buildUIConfig(enabled bool) *config.ExtensionConfig { + ext := &config.ExtensionConfig{} + ext.UI = &config.UIConfig{ + BaseConfig: config.BaseConfig{ + Enable: &enabled, + }, + } + + return ext +} + +func buildSearchAndUIConfig(enabled bool) *config.ExtensionConfig { + ext := &config.ExtensionConfig{} + ext.Search = &config.SearchConfig{ + BaseConfig: config.BaseConfig{ + Enable: &enabled, + }, + } + ext.UI = &config.UIConfig{ + BaseConfig: config.BaseConfig{ + Enable: &enabled, + }, + } + + return ext +} + +func buildTrustConfigWithCosign(enabled bool) *config.ExtensionConfig { + ext := &config.ExtensionConfig{} + ext.Trust = &config.ImageTrustConfig{ + BaseConfig: config.BaseConfig{ + Enable: &enabled, + }, + Cosign: true, + } + + return ext +} + +func buildTrustConfigWithNotation(enabled bool) *config.ExtensionConfig { + ext := &config.ExtensionConfig{} + ext.Trust = &config.ImageTrustConfig{ + BaseConfig: config.BaseConfig{ + Enable: &enabled, + }, + Notation: true, + } + + return ext +} + +// Test helper functions to reduce code duplication + +// testMethodWithNilConfig tests a method with nil ExtensionConfig. +func testMethodWithNilConfig(testFunc func(*config.ExtensionConfig) bool) { + Convey("Test with nil ExtensionConfig", func() { + var extensionConfig *config.ExtensionConfig = nil + + So(testFunc(extensionConfig), ShouldBeFalse) + }) +} + +// testMethodWithNilSubConfig tests a method when ExtensionConfig exists but the relevant sub-config is nil. +func testMethodWithNilSubConfig(subConfigName string, testFunc func(*config.ExtensionConfig) bool) { + Convey("Test with ExtensionConfig but nil "+subConfigName, func() { + extensionConfig := &config.ExtensionConfig{} + + So(testFunc(extensionConfig), ShouldBeFalse) + }) +} + +// testMethodWithNilEnable tests a method when ExtensionConfig and sub-config exist but Enable is nil. +func testMethodWithNilEnable(subConfigName string, testFunc func(*config.ExtensionConfig) bool) { + Convey("Test with ExtensionConfig and "+subConfigName+" but nil Enable", func() { + extensionConfig := &config.ExtensionConfig{} + + So(testFunc(extensionConfig), ShouldBeFalse) + }) +} + +// testMethodWithDisabledEnable tests a method when Enable is explicitly set to false. +func testMethodWithDisabledEnable( + subConfigName string, + testFunc func(*config.ExtensionConfig) bool, + configBuilder func(bool) *config.ExtensionConfig, +) { + Convey("Test with ExtensionConfig and "+subConfigName+" and Enable but disabled", func() { + disabled := false + extensionConfig := configBuilder(disabled) + So(testFunc(extensionConfig), ShouldBeFalse) + }) +} + +// testMethodWithEnabledEnable tests a method when Enable is explicitly set to true. +func testMethodWithEnabledEnable( + subConfigName string, + testFunc func(*config.ExtensionConfig) bool, + configBuilder func(bool) *config.ExtensionConfig, +) { + Convey("Test with ExtensionConfig and "+subConfigName+" and Enable enabled", func() { + enabled := true + extensionConfig := configBuilder(enabled) + So(testFunc(extensionConfig), ShouldBeTrue) + }) +} + +// testConcurrentAccessWithConfig tests concurrent access to a method with a properly configured ExtensionConfig. +func testConcurrentAccessWithConfig( + methodName string, + testFunc func(*config.ExtensionConfig) bool, + expectedError error, + extensionConfig *config.ExtensionConfig, +) { + Convey("Test concurrent access to "+methodName, func() { + // Test concurrent access to verify thread-safety + done := make(chan bool, 10) + errors := make(chan error, 10) + + for i := 0; i < 10; i++ { + go func() { + defer func() { + if r := recover(); r != nil { + if err, ok := r.(error); ok { + errors <- err + } else { + errors <- errPanicRecovered + } + } + done <- true + }() + + for j := 0; j < 100; j++ { + result := testFunc(extensionConfig) + if !result { + errors <- expectedError + + return + } + } + }() + } + + // Wait for all goroutines to complete + for i := 0; i < 10; i++ { + <-done + } + + // Check for errors + close(errors) + + for err := range errors { + So(err, ShouldBeNil) + } + }) +} + +// testGetterWithNilConfig tests a getter method with nil ExtensionConfig. +func testGetterWithNilConfig[T any](testFunc func(*config.ExtensionConfig) T, expected T) { + Convey("Test with nil ExtensionConfig", func() { + var extensionConfig *config.ExtensionConfig = nil + + result := testFunc(extensionConfig) + So(result, ShouldEqual, expected) + }) +} + +// testGetterWithNilSubConfig tests a getter method when ExtensionConfig exists but the relevant sub-config is nil. +func testGetterWithNilSubConfig[T any](subConfigName string, testFunc func(*config.ExtensionConfig) T, expected T) { + Convey("Test with ExtensionConfig but nil "+subConfigName, func() { + extensionConfig := &config.ExtensionConfig{} + + result := testFunc(extensionConfig) + So(result, ShouldEqual, expected) + }) +} + +// testGetterWithValidConfig tests a getter method with valid configuration. +func testGetterWithValidConfig[T any]( + subConfigName string, + testFunc func(*config.ExtensionConfig) T, + configBuilder func(bool) *config.ExtensionConfig, +) { + Convey("Test with valid "+subConfigName+" configuration", func() { + enabled := true + extensionConfig := configBuilder(enabled) + + result := testFunc(extensionConfig) + So(result, ShouldNotBeNil) + }) +} + +func TestExtensionConfig(t *testing.T) { + Convey("Test public methods", t, func() { + Convey("Test IsCveScanningEnabled()", func() { + testMethodWithNilConfig((*config.ExtensionConfig).IsCveScanningEnabled) + testMethodWithNilSubConfig("Search", (*config.ExtensionConfig).IsCveScanningEnabled) + testMethodWithNilEnable("Search", (*config.ExtensionConfig).IsCveScanningEnabled) + testMethodWithDisabledEnable("Search", (*config.ExtensionConfig).IsCveScanningEnabled, buildSearchConfig) + testMethodWithEnabledEnable("Search", (*config.ExtensionConfig).IsCveScanningEnabled, buildSearchConfigWithCVE) + }) + + Convey("Test IsEventRecorderEnabled()", func() { + testMethodWithNilConfig((*config.ExtensionConfig).IsEventRecorderEnabled) + testMethodWithNilSubConfig("Events", (*config.ExtensionConfig).IsEventRecorderEnabled) + testMethodWithNilEnable("Events", (*config.ExtensionConfig).IsEventRecorderEnabled) + testMethodWithDisabledEnable("Events", (*config.ExtensionConfig).IsEventRecorderEnabled, buildEventsConfig) + testMethodWithEnabledEnable("Events", (*config.ExtensionConfig).IsEventRecorderEnabled, buildEventsConfig) + }) + + Convey("Test IsSearchEnabled()", func() { + testMethodWithNilConfig((*config.ExtensionConfig).IsSearchEnabled) + testMethodWithNilSubConfig("Search", (*config.ExtensionConfig).IsSearchEnabled) + testMethodWithNilEnable("Search", (*config.ExtensionConfig).IsSearchEnabled) + testMethodWithDisabledEnable("Search", (*config.ExtensionConfig).IsSearchEnabled, buildSearchConfig) + testMethodWithEnabledEnable("Search", (*config.ExtensionConfig).IsSearchEnabled, buildSearchConfig) + }) + + Convey("Test IsSyncEnabled()", func() { + testMethodWithNilConfig((*config.ExtensionConfig).IsSyncEnabled) + testMethodWithNilSubConfig("Sync", (*config.ExtensionConfig).IsSyncEnabled) + testMethodWithNilEnable("Sync", (*config.ExtensionConfig).IsSyncEnabled) + testMethodWithDisabledEnable("Sync", (*config.ExtensionConfig).IsSyncEnabled, buildSyncConfig) + testMethodWithEnabledEnable("Sync", (*config.ExtensionConfig).IsSyncEnabled, buildSyncConfig) + }) + + Convey("Test IsScrubEnabled()", func() { + testMethodWithNilConfig((*config.ExtensionConfig).IsScrubEnabled) + testMethodWithNilSubConfig("Scrub", (*config.ExtensionConfig).IsScrubEnabled) + testMethodWithNilEnable("Scrub", (*config.ExtensionConfig).IsScrubEnabled) + testMethodWithDisabledEnable("Scrub", (*config.ExtensionConfig).IsScrubEnabled, buildScrubConfig) + testMethodWithEnabledEnable("Scrub", (*config.ExtensionConfig).IsScrubEnabled, buildScrubConfig) + }) + + Convey("Test IsMetricsEnabled()", func() { + testMethodWithNilConfig((*config.ExtensionConfig).IsMetricsEnabled) + testMethodWithNilSubConfig("Metrics", (*config.ExtensionConfig).IsMetricsEnabled) + testMethodWithNilEnable("Metrics", (*config.ExtensionConfig).IsMetricsEnabled) + testMethodWithDisabledEnable("Metrics", (*config.ExtensionConfig).IsMetricsEnabled, buildMetricsConfig) + testMethodWithEnabledEnable("Metrics", (*config.ExtensionConfig).IsMetricsEnabled, buildMetricsConfig) + }) + + Convey("Test IsCosignEnabled()", func() { + testMethodWithNilConfig((*config.ExtensionConfig).IsCosignEnabled) + testMethodWithNilSubConfig("Trust", (*config.ExtensionConfig).IsCosignEnabled) + testMethodWithNilEnable("Trust", (*config.ExtensionConfig).IsCosignEnabled) + testMethodWithDisabledEnable("Trust", (*config.ExtensionConfig).IsCosignEnabled, buildTrustConfig) + testMethodWithEnabledEnable("Trust", (*config.ExtensionConfig).IsCosignEnabled, buildTrustConfigWithCosign) + }) + + Convey("Test IsNotationEnabled()", func() { + testMethodWithNilConfig((*config.ExtensionConfig).IsNotationEnabled) + testMethodWithNilSubConfig("Trust", (*config.ExtensionConfig).IsNotationEnabled) + testMethodWithNilEnable("Trust", (*config.ExtensionConfig).IsNotationEnabled) + testMethodWithDisabledEnable("Trust", (*config.ExtensionConfig).IsNotationEnabled, buildTrustConfig) + testMethodWithEnabledEnable("Trust", (*config.ExtensionConfig).IsNotationEnabled, buildTrustConfigWithNotation) + }) + + Convey("Test IsImageTrustEnabled()", func() { + testMethodWithNilConfig((*config.ExtensionConfig).IsImageTrustEnabled) + testMethodWithNilSubConfig("Trust", (*config.ExtensionConfig).IsImageTrustEnabled) + testMethodWithNilEnable("Trust", (*config.ExtensionConfig).IsImageTrustEnabled) + testMethodWithDisabledEnable("Trust", (*config.ExtensionConfig).IsImageTrustEnabled, buildTrustConfig) + testMethodWithEnabledEnable("Trust", (*config.ExtensionConfig).IsImageTrustEnabled, buildTrustConfig) + }) + + Convey("Test IsUIEnabled()", func() { + testMethodWithNilConfig((*config.ExtensionConfig).IsUIEnabled) + testMethodWithNilSubConfig("UI", (*config.ExtensionConfig).IsUIEnabled) + testMethodWithNilEnable("UI", (*config.ExtensionConfig).IsUIEnabled) + testMethodWithDisabledEnable("UI", (*config.ExtensionConfig).IsUIEnabled, buildUIConfig) + testMethodWithEnabledEnable("UI", (*config.ExtensionConfig).IsUIEnabled, buildUIConfig) + }) + + Convey("Test AreUserPrefsEnabled()", func() { + testMethodWithNilConfig((*config.ExtensionConfig).AreUserPrefsEnabled) + testMethodWithNilSubConfig("Search", (*config.ExtensionConfig).AreUserPrefsEnabled) + testMethodWithNilEnable("UI", (*config.ExtensionConfig).AreUserPrefsEnabled) + testMethodWithDisabledEnable("Search", (*config.ExtensionConfig).AreUserPrefsEnabled, buildSearchConfig) + testMethodWithEnabledEnable("Search", (*config.ExtensionConfig).AreUserPrefsEnabled, buildSearchAndUIConfig) + }) + }) + + // Additional tests to verify thread-safety and internal method behavior + Convey("Test thread-safety and internal method coverage", t, func() { + // Create properly configured ExtensionConfigs for concurrent testing + searchEnabled := true + uiEnabled := true + searchConfig := &config.ExtensionConfig{} + searchConfig.Search = &config.SearchConfig{ + BaseConfig: config.BaseConfig{ + Enable: &searchEnabled, + }, + } + uiConfig := &config.ExtensionConfig{} + uiConfig.UI = &config.UIConfig{ + BaseConfig: config.BaseConfig{ + Enable: &uiEnabled, + }, + } + searchAndUIConfig := &config.ExtensionConfig{} + searchAndUIConfig.Search = &config.SearchConfig{ + BaseConfig: config.BaseConfig{ + Enable: &searchEnabled, + }, + } + searchAndUIConfig.UI = &config.UIConfig{ + BaseConfig: config.BaseConfig{ + Enable: &uiEnabled, + }, + } + + testConcurrentAccessWithConfig( + "IsSearchEnabled", + (*config.ExtensionConfig).IsSearchEnabled, + errIsSearchEnabledExpectedTrue, + searchConfig, + ) + testConcurrentAccessWithConfig( + "IsUIEnabled", + (*config.ExtensionConfig).IsUIEnabled, + errIsUIEnabledExpectedTrue, + uiConfig, + ) + testConcurrentAccessWithConfig( + "AreUserPrefsEnabled", + (*config.ExtensionConfig).AreUserPrefsEnabled, + errAreUserPrefsEnabledExpectedTrue, + searchAndUIConfig, + ) + + Convey("Test mixed concurrent access to all methods", func() { + searchEnabled := true + uiEnabled := true + extensionConfig := &config.ExtensionConfig{} + extensionConfig.Search = &config.SearchConfig{ + BaseConfig: config.BaseConfig{ + Enable: &searchEnabled, + }, + } + extensionConfig.UI = &config.UIConfig{ + BaseConfig: config.BaseConfig{ + Enable: &uiEnabled, + }, + } + + // Test mixed concurrent access to verify thread-safety across all methods + done := make(chan bool, 15) + errors := make(chan error, 15) + + // Launch goroutines for each method + for i := 0; i < 5; i++ { + go func() { + defer func() { + if r := recover(); r != nil { + if err, ok := r.(error); ok { + errors <- err + } else { + errors <- errPanicRecovered + } + } + done <- true + }() + + for j := 0; j < 50; j++ { + result := extensionConfig.IsSearchEnabled() + if !result { + errors <- errIsSearchEnabledExpectedTrue + + return + } + } + }() + } + + for i := 0; i < 5; i++ { + go func() { + defer func() { + if r := recover(); r != nil { + if err, ok := r.(error); ok { + errors <- err + } else { + errors <- errPanicRecovered + } + } + done <- true + }() + + for j := 0; j < 50; j++ { + result := extensionConfig.IsUIEnabled() + if !result { + errors <- errIsUIEnabledExpectedTrue + + return + } + } + }() + } + + for i := 0; i < 5; i++ { + go func() { + defer func() { + if r := recover(); r != nil { + if err, ok := r.(error); ok { + errors <- err + } else { + errors <- errPanicRecovered + } + } + done <- true + }() + + for j := 0; j < 50; j++ { + result := extensionConfig.AreUserPrefsEnabled() + if !result { + errors <- errAreUserPrefsEnabledExpectedTrue + + return + } + } + }() + } + + // Wait for all goroutines to complete + for i := 0; i < 15; i++ { + <-done + } + + // Check for errors + close(errors) + + for err := range errors { + So(err, ShouldBeNil) + } + }) + + Convey("Test GetSearchCVEConfig()", func() { + testGetterWithNilConfig((*config.ExtensionConfig).GetSearchCVEConfig, nil) + testGetterWithNilSubConfig("Search", (*config.ExtensionConfig).GetSearchCVEConfig, nil) + testGetterWithValidConfig("Search", (*config.ExtensionConfig).GetSearchCVEConfig, buildSearchConfigWithCVE) + }) + + Convey("Test GetScrubInterval()", func() { + testGetterWithNilConfig((*config.ExtensionConfig).GetScrubInterval, 0) + testGetterWithNilSubConfig("Scrub", (*config.ExtensionConfig).GetScrubInterval, 0) + testGetterWithValidConfig("Scrub", (*config.ExtensionConfig).GetScrubInterval, buildScrubConfig) + }) + + Convey("Test GetSyncConfig()", func() { + testGetterWithNilConfig((*config.ExtensionConfig).GetSyncConfig, nil) + testGetterWithValidConfig("Sync", (*config.ExtensionConfig).GetSyncConfig, buildSyncConfig) + }) + + Convey("Test GetMetricsPrometheusConfig()", func() { + testGetterWithNilConfig((*config.ExtensionConfig).GetMetricsPrometheusConfig, nil) + testGetterWithNilSubConfig("Metrics", (*config.ExtensionConfig).GetMetricsPrometheusConfig, nil) + testGetterWithValidConfig("Metrics", (*config.ExtensionConfig).GetMetricsPrometheusConfig, buildMetricsConfig) + }) + + Convey("Test GetEventsConfig()", func() { + testGetterWithNilConfig((*config.ExtensionConfig).GetEventsConfig, nil) + testGetterWithValidConfig("Events", (*config.ExtensionConfig).GetEventsConfig, buildEventsConfig) + }) + }) +} diff --git a/pkg/extensions/extension_events.go b/pkg/extensions/extension_events.go index fffca7b0..4e58ec95 100644 --- a/pkg/extensions/extension_events.go +++ b/pkg/extensions/extension_events.go @@ -12,15 +12,16 @@ import ( ) func NewEventRecorder(config *config.Config, log log.Logger) (events.Recorder, error) { - if !config.IsEventRecorderEnabled() { + // Get extensions config safely + extensionsConfig := config.CopyExtensionsConfig() + if !extensionsConfig.IsEventRecorderEnabled() { log.Info().Msg("events disabled in configuration") return nil, zerr.ErrExtensionNotEnabled } - eventConfig := config.Extensions.Events - - if eventConfig.Sinks == nil || len(eventConfig.Sinks) == 0 { + eventConfig := extensionsConfig.GetEventsConfig() + if eventConfig == nil || eventConfig.Sinks == nil || len(eventConfig.Sinks) == 0 { log.Info().Msg("no sinks provided, skipping events extension setup") return nil, zerr.ErrExtensionNotEnabled diff --git a/pkg/extensions/extension_events_disabled.go b/pkg/extensions/extension_events_disabled.go index 9c5bbf06..81d73d14 100644 --- a/pkg/extensions/extension_events_disabled.go +++ b/pkg/extensions/extension_events_disabled.go @@ -11,7 +11,9 @@ import ( ) func NewEventRecorder(config *config.Config, log log.Logger) (events.Recorder, error) { - if !config.IsEventRecorderEnabled() { + // Get extensions config safely + extensionsConfig := config.CopyExtensionsConfig() + if !extensionsConfig.IsEventRecorderEnabled() { log.Info().Msg("events disabled in configuration") return nil, zerr.ErrExtensionNotEnabled diff --git a/pkg/extensions/extension_image_trust.go b/pkg/extensions/extension_image_trust.go index 6713edf4..1301fe46 100644 --- a/pkg/extensions/extension_image_trust.go +++ b/pkg/extensions/extension_image_trust.go @@ -27,7 +27,9 @@ func IsBuiltWithImageTrustExtension() bool { } func SetupImageTrustRoutes(conf *config.Config, router *mux.Router, metaDB mTypes.MetaDB, log log.Logger) { - if !conf.IsImageTrustEnabled() || (!conf.IsCosignEnabled() && !conf.IsNotationEnabled()) { + extensionsConfig := conf.CopyExtensionsConfig() + if !extensionsConfig.IsImageTrustEnabled() || + (!extensionsConfig.IsCosignEnabled() && !extensionsConfig.IsNotationEnabled()) { log.Info().Msg("skip enabling the image trust routes as the config prerequisites are not met") return @@ -39,7 +41,7 @@ func SetupImageTrustRoutes(conf *config.Config, router *mux.Router, metaDB mType trust := ImageTrust{Conf: conf, ImageTrustStore: imgTrustStore, Log: log} allowedMethods := zcommon.AllowedMethods(http.MethodPost) - if conf.IsNotationEnabled() { + if extensionsConfig.IsNotationEnabled() { log.Info().Msg("setting up notation route") notationRouter := router.PathPrefix(constants.ExtNotation).Subrouter() @@ -51,7 +53,7 @@ func SetupImageTrustRoutes(conf *config.Config, router *mux.Router, metaDB mType notationRouter.Methods(allowedMethods...).HandlerFunc(trust.HandleNotationCertificateUpload) } - if conf.IsCosignEnabled() { + if extensionsConfig.IsCosignEnabled() { log.Info().Msg("setting up cosign route") cosignRouter := router.PathPrefix(constants.ExtCosign).Subrouter() @@ -153,7 +155,8 @@ func (trust *ImageTrust) HandleNotationCertificateUpload(response http.ResponseW func EnableImageTrustVerification(conf *config.Config, taskScheduler *scheduler.Scheduler, metaDB mTypes.MetaDB, log log.Logger, ) { - if !conf.IsImageTrustEnabled() { + extensionsConfig := conf.CopyExtensionsConfig() + if !extensionsConfig.IsImageTrustEnabled() { return } @@ -165,7 +168,8 @@ func EnableImageTrustVerification(conf *config.Config, taskScheduler *scheduler. } func SetupImageTrustExtension(conf *config.Config, metaDB mTypes.MetaDB, log log.Logger) error { - if !conf.IsImageTrustEnabled() { + extensionsConfig := conf.CopyExtensionsConfig() + if !extensionsConfig.IsImageTrustEnabled() { return nil } diff --git a/pkg/extensions/extension_metrics.go b/pkg/extensions/extension_metrics.go index 64beb284..8de32158 100644 --- a/pkg/extensions/extension_metrics.go +++ b/pkg/extensions/extension_metrics.go @@ -13,13 +13,10 @@ import ( ) func EnableMetricsExtension(config *config.Config, log log.Logger, rootDir string) { - if config.IsMetricsEnabled() && - config.Extensions.Metrics.Prometheus != nil { - if config.Extensions.Metrics.Prometheus.Path == "" { - config.Extensions.Metrics.Prometheus.Path = "/metrics" - - log.Warn().Msg("prometheus instrumentation path not set, changing to '/metrics'.") - } + // Get extensions config safely + extensionsConfig := config.CopyExtensionsConfig() + if extensionsConfig.IsMetricsEnabled() { + log.Info().Msg("metrics extension enabled") } else { log.Info().Msg("metrics config not provided, skipping metrics config update") } @@ -30,10 +27,15 @@ func SetupMetricsRoutes(config *config.Config, router *mux.Router, ) { log.Info().Msg("setting up metrics routes") - if config.IsMetricsEnabled() { - extRouter := router.PathPrefix(config.Extensions.Metrics.Prometheus.Path).Subrouter() - extRouter.Use(authnFunc) - extRouter.Use(authzFunc) - extRouter.Methods("GET").Handler(promhttp.Handler()) + // Get extensions config safely + extensionsConfig := config.CopyExtensionsConfig() + if extensionsConfig.IsMetricsEnabled() { + prometheusConfig := extensionsConfig.GetMetricsPrometheusConfig() + if prometheusConfig != nil { + extRouter := router.PathPrefix(prometheusConfig.Path).Subrouter() + extRouter.Use(authnFunc) + extRouter.Use(authzFunc) + extRouter.Methods("GET").Handler(promhttp.Handler()) + } } } diff --git a/pkg/extensions/extension_mgmt.go b/pkg/extensions/extension_mgmt.go index 3da8afb6..db99a87c 100644 --- a/pkg/extensions/extension_mgmt.go +++ b/pkg/extensions/extension_mgmt.go @@ -85,7 +85,8 @@ func (auth Auth) MarshalJSON() ([]byte, error) { } func SetupMgmtRoutes(conf *config.Config, router *mux.Router, log log.Logger) { - if !conf.IsMgmtEnabled() { + extensionsConfig := conf.CopyExtensionsConfig() + if !extensionsConfig.IsSearchEnabled() { log.Info().Msg("skip enabling the mgmt route as the config prerequisites are not met") return diff --git a/pkg/extensions/extension_scrub.go b/pkg/extensions/extension_scrub.go index 6998eee6..9811d451 100644 --- a/pkg/extensions/extension_scrub.go +++ b/pkg/extensions/extension_scrub.go @@ -4,8 +4,6 @@ package extensions import ( - "time" - "zotregistry.dev/zot/v2/pkg/api/config" "zotregistry.dev/zot/v2/pkg/extensions/scrub" "zotregistry.dev/zot/v2/pkg/log" @@ -18,15 +16,10 @@ import ( func EnableScrubExtension(config *config.Config, log log.Logger, storeController storage.StoreController, sch *scheduler.Scheduler, ) { - if config.Extensions.Scrub != nil && - *config.Extensions.Scrub.Enable { - minScrubInterval, _ := time.ParseDuration("2h") - - if config.Extensions.Scrub.Interval < minScrubInterval { - config.Extensions.Scrub.Interval = minScrubInterval - - log.Warn().Msg("scrub interval set to too-short interval < 2h, changing scrub duration to 2 hours and continuing.") //nolint:lll // gofumpt conflicts with lll - } + // Get extensions config safely + extensionsConfig := config.CopyExtensionsConfig() + if extensionsConfig.IsScrubEnabled() { + scrubInterval := extensionsConfig.GetScrubInterval() processedRepos := make(map[string]struct{}) @@ -36,10 +29,12 @@ func EnableScrubExtension(config *config.Config, log log.Logger, storeController processedRepos: processedRepos, } - sch.SubmitGenerator(generator, config.Extensions.Scrub.Interval, scheduler.LowPriority) + sch.SubmitGenerator(generator, scrubInterval, scheduler.LowPriority) - if config.Storage.SubPaths != nil { - for route := range config.Storage.SubPaths { + // Get storage config safely + storageConfig := config.CopyStorageConfig() + if storageConfig.SubPaths != nil { + for route := range storageConfig.SubPaths { processedRepos := make(map[string]struct{}) generator := &taskGenerator{ @@ -48,7 +43,7 @@ func EnableScrubExtension(config *config.Config, log log.Logger, storeController processedRepos: processedRepos, } - sch.SubmitGenerator(generator, config.Extensions.Scrub.Interval, scheduler.LowPriority) + sch.SubmitGenerator(generator, scrubInterval, scheduler.LowPriority) } } } else { diff --git a/pkg/extensions/extension_search.go b/pkg/extensions/extension_search.go index d45d22c0..8c5458ae 100644 --- a/pkg/extensions/extension_search.go +++ b/pkg/extensions/extension_search.go @@ -33,12 +33,15 @@ func IsBuiltWithSearchExtension() bool { func GetCveScanner(conf *config.Config, storeController storage.StoreController, metaDB mTypes.MetaDB, log log.Logger, ) CveScanner { - if !conf.IsCveScanningEnabled() { + // Get extensions config safely + extensionsConfig := conf.CopyExtensionsConfig() + if !extensionsConfig.IsCveScanningEnabled() { return nil } - dbRepository := conf.Extensions.Search.CVE.Trivy.DBRepository - javaDBRepository := conf.Extensions.Search.CVE.Trivy.JavaDBRepository + cveConfig := extensionsConfig.GetSearchCVEConfig() + dbRepository := cveConfig.Trivy.DBRepository + javaDBRepository := cveConfig.Trivy.JavaDBRepository return cveinfo.NewScanner(storeController, metaDB, dbRepository, javaDBRepository, log) } @@ -46,8 +49,11 @@ func GetCveScanner(conf *config.Config, storeController storage.StoreController, func EnableSearchExtension(conf *config.Config, storeController storage.StoreController, metaDB mTypes.MetaDB, taskScheduler *scheduler.Scheduler, cveScanner CveScanner, log log.Logger, ) { - if conf.IsCveScanningEnabled() { - updateInterval := conf.Extensions.Search.CVE.UpdateInterval + // Get extensions config safely + extensionsConfig := conf.CopyExtensionsConfig() + if extensionsConfig.IsCveScanningEnabled() { + cveConfig := extensionsConfig.GetSearchCVEConfig() + updateInterval := cveConfig.UpdateInterval downloadTrivyDB(updateInterval, taskScheduler, cveScanner, log) startScanner(scanInterval, metaDB, taskScheduler, cveScanner, log) @@ -75,7 +81,8 @@ func startScanner(interval time.Duration, metaDB mTypes.MetaDB, sch *scheduler.S func SetupSearchRoutes(conf *config.Config, router *mux.Router, storeController storage.StoreController, metaDB mTypes.MetaDB, cveScanner CveScanner, log log.Logger, ) { - if !conf.IsSearchEnabled() { + extensionsConfig := conf.CopyExtensionsConfig() + if !extensionsConfig.IsSearchEnabled() { log.Info().Msg("skip enabling the search route as the config prerequisites are not met") return @@ -84,7 +91,7 @@ func SetupSearchRoutes(conf *config.Config, router *mux.Router, storeController log.Info().Msg("setting up search routes") var cveInfo cveinfo.CveInfo - if conf.IsCveScanningEnabled() { + if extensionsConfig.IsCveScanningEnabled() { cveInfo = cveinfo.NewCVEInfo(cveScanner, metaDB, log) } else { cveInfo = nil diff --git a/pkg/extensions/extension_sync.go b/pkg/extensions/extension_sync.go index 17619b39..9ffaf29f 100644 --- a/pkg/extensions/extension_sync.go +++ b/pkg/extensions/extension_sync.go @@ -21,13 +21,18 @@ import ( func EnableSyncExtension(config *config.Config, metaDB mTypes.MetaDB, storeController storage.StoreController, sch *scheduler.Scheduler, log log.Logger, ) (*sync.BaseOnDemand, error) { - if config.Extensions.Sync != nil && *config.Extensions.Sync.Enable { - onDemand := sync.NewOnDemand(log) + // Get extensions config safely + extensionsConfig := config.CopyExtensionsConfig() + httpAddress := config.GetHTTPAddress() + httpPort := config.GetHTTPPort() - for _, registryConfig := range config.Extensions.Sync.Registries { - registryConfig := registryConfig + if extensionsConfig.IsSyncEnabled() { + onDemand := sync.NewOnDemand(log) + syncConfig := extensionsConfig.GetSyncConfig() + + for _, registryConfig := range syncConfig.Registries { if len(registryConfig.URLs) > 1 { - if err := removeSelfURLs(config, ®istryConfig, log); err != nil { + if err := removeSelfURLs(httpAddress, httpPort, ®istryConfig, log); err != nil { return nil, err } } @@ -45,11 +50,12 @@ func EnableSyncExtension(config *config.Config, metaDB mTypes.MetaDB, continue } - tmpDir := config.Extensions.Sync.DownloadDir - credsPath := config.Extensions.Sync.CredentialsFile - clusterCfg := config.Cluster + tmpDir := syncConfig.DownloadDir + credsPath := syncConfig.CredentialsFile + // Get cluster config safely + clusterConfig := config.CopyClusterConfig() - service, err := sync.New(registryConfig, credsPath, clusterCfg, tmpDir, storeController, metaDB, log) + service, err := sync.New(registryConfig, credsPath, clusterConfig, tmpDir, storeController, metaDB, log) if err != nil { log.Error().Err(err).Msg("failed to initialize sync extension") @@ -102,10 +108,9 @@ func getLocalIPs() ([]string, error) { return localIPs, nil } -func removeSelfURLs(config *config.Config, registryConfig *syncconf.RegistryConfig, log log.Logger) error { +func removeSelfURLs(httpAddress, httpPort string, registryConfig *syncconf.RegistryConfig, log log.Logger) error { // get IP from config - port := config.HTTP.Port - selfAddress := net.JoinHostPort(config.HTTP.Address, port) + selfAddress := net.JoinHostPort(httpAddress, httpPort) // get all local IPs from interfaces localIPs, err := getLocalIPs() @@ -148,8 +153,8 @@ func removeSelfURLs(config *config.Config, registryConfig *syncconf.RegistryConf for _, localIP := range localIPs { // if ip resolved from hostname/dns is equal with any local ip for _, ip := range ips { - if (ip.IsLoopback() && (url.Port() == port)) || - (net.JoinHostPort(ip.String(), url.Port()) == net.JoinHostPort(localIP, port)) { + if (ip.IsLoopback() && (url.Port() == httpPort)) || + (net.JoinHostPort(ip.String(), url.Port()) == net.JoinHostPort(localIP, httpPort)) { registryConfig.URLs = append(registryConfig.URLs[:idx], registryConfig.URLs[idx+1:]...) removed = true diff --git a/pkg/extensions/extension_ui.go b/pkg/extensions/extension_ui.go index a4aa2741..cf4880e7 100644 --- a/pkg/extensions/extension_ui.go +++ b/pkg/extensions/extension_ui.go @@ -63,7 +63,8 @@ func addUISecurityHeaders(h http.Handler) http.HandlerFunc { //nolint:varnamelen func SetupUIRoutes(conf *config.Config, router *mux.Router, log log.Logger, ) { - if !conf.IsUIEnabled() { + extensionsConfig := conf.CopyExtensionsConfig() + if !extensionsConfig.IsUIEnabled() { log.Info().Msg("skip enabling the ui route as the config prerequisites are not met") return diff --git a/pkg/extensions/extension_userprefs.go b/pkg/extensions/extension_userprefs.go index 7823e135..bad09662 100644 --- a/pkg/extensions/extension_userprefs.go +++ b/pkg/extensions/extension_userprefs.go @@ -29,7 +29,8 @@ func IsBuiltWithUserPrefsExtension() bool { func SetupUserPreferencesRoutes(conf *config.Config, router *mux.Router, metaDB mTypes.MetaDB, log log.Logger, ) { - if !conf.AreUserPrefsEnabled() { + extensionsConfig := conf.CopyExtensionsConfig() + if !extensionsConfig.AreUserPrefsEnabled() { log.Info().Msg("skip enabling the user preferences route as the config prerequisites are not met") return diff --git a/pkg/extensions/extensions_test.go b/pkg/extensions/extensions_test.go index 401f600d..e0c7c667 100644 --- a/pkg/extensions/extensions_test.go +++ b/pkg/extensions/extensions_test.go @@ -107,8 +107,7 @@ func TestMetricsExtension(t *testing.T) { data, _ := os.ReadFile(logFile.Name()) - So(string(data), ShouldContainSubstring, - "prometheus instrumentation path not set, changing to '/metrics'.") + So(string(data), ShouldContainSubstring, "metrics extension enabled") }) } diff --git a/pkg/extensions/get_extensions.go b/pkg/extensions/get_extensions.go index 1fe54c43..f59894af 100644 --- a/pkg/extensions/get_extensions.go +++ b/pkg/extensions/get_extensions.go @@ -16,23 +16,24 @@ func GetExtensions(config *config.Config) distext.ExtensionList { endpoints := []string{} extensions := []distext.Extension{} - if config.IsNotationEnabled() && IsBuiltWithImageTrustExtension() { + extensionsConfig := config.CopyExtensionsConfig() + if extensionsConfig.IsNotationEnabled() && IsBuiltWithImageTrustExtension() { endpoints = append(endpoints, constants.FullNotation) } - if config.IsCosignEnabled() && IsBuiltWithImageTrustExtension() { + if extensionsConfig.IsCosignEnabled() && IsBuiltWithImageTrustExtension() { endpoints = append(endpoints, constants.FullCosign) } - if config.IsSearchEnabled() && IsBuiltWithSearchExtension() { + if extensionsConfig.IsSearchEnabled() && IsBuiltWithSearchExtension() { endpoints = append(endpoints, constants.FullSearchPrefix) } - if config.AreUserPrefsEnabled() && IsBuiltWithUserPrefsExtension() { + if extensionsConfig.AreUserPrefsEnabled() && IsBuiltWithUserPrefsExtension() { endpoints = append(endpoints, constants.FullUserPrefs) } - if config.IsMgmtEnabled() && IsBuiltWithMGMTExtension() { + if extensionsConfig.IsSearchEnabled() && IsBuiltWithMGMTExtension() { endpoints = append(endpoints, constants.FullMgmt) } diff --git a/pkg/extensions/monitoring/common.go b/pkg/extensions/monitoring/common.go index 5968e43f..e9be85d4 100644 --- a/pkg/extensions/monitoring/common.go +++ b/pkg/extensions/monitoring/common.go @@ -14,6 +14,8 @@ type MetricServer interface { ForceSendMetric(interface{}) ReceiveMetrics() interface{} IsEnabled() bool + // Stop gracefully shuts down the metrics server + Stop() } func GetDirSize(path string) (int64, error) { diff --git a/pkg/extensions/monitoring/extension.go b/pkg/extensions/monitoring/extension.go index b05d1f6c..d9a9b66f 100644 --- a/pkg/extensions/monitoring/extension.go +++ b/pkg/extensions/monitoring/extension.go @@ -137,6 +137,11 @@ type metricServer struct { log log.Logger } +// Stop gracefully shuts down the metrics server (no-op for this implementation). +func (ms *metricServer) Stop() { + // This is a no-op implementation for the disabled metrics server +} + func GetDefaultBuckets() []float64 { return []float64{.05, .5, 1, 5, 30, 60, 600} } diff --git a/pkg/extensions/monitoring/minimal.go b/pkg/extensions/monitoring/minimal.go index 4b90554c..5b8b2ec2 100644 --- a/pkg/extensions/monitoring/minimal.go +++ b/pkg/extensions/monitoring/minimal.go @@ -49,6 +49,7 @@ type metricServer struct { bucketsF2S map[float64]string // float64 to string conversion of buckets label log log.Logger lock *sync.RWMutex + stopChan chan struct{} // Channel to signal shutdown } type MetricsInfo struct { @@ -142,19 +143,35 @@ func (ms *metricServer) IsEnabled() bool { return ms.enabled } +// Stop gracefully shuts down the metrics server. +func (ms *metricServer) Stop() { + close(ms.stopChan) +} + func (ms *metricServer) Run() { sendAfter := make(chan time.Duration, 1) // periodically send a notification to the metric server to check if we can disable metrics go func() { for { - t := metricsScrapeCheckInterval - time.Sleep(t) - sendAfter <- t + select { + case <-ms.stopChan: + return + default: + t := metricsScrapeCheckInterval + time.Sleep(t) + select { + case sendAfter <- t: + case <-ms.stopChan: + return + } + } } }() for { select { + case <-ms.stopChan: + return case <-ms.cacheChan: ms.lastCheck = time.Now() // make a copy of cache values to prevent data race @@ -239,6 +256,7 @@ func NewMetricsServer(enabled bool, log log.Logger) MetricServer { bucketsF2S: bucketsFloat2String, log: log, lock: &sync.RWMutex{}, + stopChan: make(chan struct{}), } go ms.Run() diff --git a/pkg/extensions/monitoring/monitoring_test.go b/pkg/extensions/monitoring/monitoring_test.go index d179717b..20698d4f 100644 --- a/pkg/extensions/monitoring/monitoring_test.go +++ b/pkg/extensions/monitoring/monitoring_test.go @@ -48,6 +48,11 @@ func TestExtensionMetrics(t *testing.T) { ctlr := api.NewController(conf) So(ctlr, ShouldNotBeNil) + // Write image before starting controller to avoid race condition with garbage collection + srcStorageCtlr := ociutils.GetDefaultStoreController(rootDir, ctlr.Log) + err := WriteImageToFileSystem(CreateDefaultImage(), "alpine", "0.0.1", srcStorageCtlr) + So(err, ShouldBeNil) + cm := test.NewControllerManager(ctlr) cm.StartAndWait(port) defer cm.StopServer() @@ -64,10 +69,6 @@ func TestExtensionMetrics(t *testing.T) { monitoring.IncDownloadCounter(ctlr.Metrics, "alpine") monitoring.IncUploadCounter(ctlr.Metrics, "alpine") - srcStorageCtlr := ociutils.GetDefaultStoreController(rootDir, ctlr.Log) - err := WriteImageToFileSystem(CreateDefaultImage(), "alpine", "0.0.1", srcStorageCtlr) - So(err, ShouldBeNil) - monitoring.SetStorageUsage(ctlr.Metrics, rootDir, "alpine") monitoring.ObserveStorageLockLatency(ctlr.Metrics, time.Millisecond, rootDir, "RWLock") diff --git a/pkg/extensions/search/search_test.go b/pkg/extensions/search/search_test.go index 7ae097bc..4c76d527 100644 --- a/pkg/extensions/search/search_test.go +++ b/pkg/extensions/search/search_test.go @@ -602,7 +602,7 @@ func TestRepoListWithNewestImage(t *testing.T) { // Delete config blob and try. err = os.Remove(path.Join(subRootDir, "a/zot-test/blobs/sha256", configDigest.Encoded())) - if err != nil { + if err != nil && !os.IsNotExist(err) { panic(err) } @@ -614,7 +614,7 @@ func TestRepoListWithNewestImage(t *testing.T) { err = os.Remove(path.Join(subRootDir, "a/zot-test/blobs/sha256", manifestDigest.Encoded())) - if err != nil { + if err != nil && !os.IsNotExist(err) { panic(err) } @@ -625,7 +625,7 @@ func TestRepoListWithNewestImage(t *testing.T) { So(resp.StatusCode(), ShouldEqual, 200) err = os.Remove(path.Join(rootDir, "zot-test/blobs/sha256", configDigest.Encoded())) - if err != nil { + if err != nil && !os.IsNotExist(err) { panic(err) } @@ -637,7 +637,7 @@ func TestRepoListWithNewestImage(t *testing.T) { // Delete manifest blob also and try err = os.Remove(path.Join(rootDir, "zot-test/blobs/sha256", manifestDigest.Encoded())) - if err != nil { + if err != nil && !os.IsNotExist(err) { panic(err) } diff --git a/pkg/extensions/sync/sync_internal_test.go b/pkg/extensions/sync/sync_internal_test.go index b8773d0b..fbc285b7 100644 --- a/pkg/extensions/sync/sync_internal_test.go +++ b/pkg/extensions/sync/sync_internal_test.go @@ -175,8 +175,6 @@ func TestService(t *testing.T) { onDemand.Add(service) ctx := context.Background() - // === IMAGE SYNC CONTINUE PATH TEST === - // Step 1: Verify empty requestStore initially initialImageCount := 0 onDemand.requestStore.Range(func(key, value interface{}) bool { @@ -219,8 +217,6 @@ func TestService(t *testing.T) { So(exists, ShouldBeTrue) So(value, ShouldEqual, struct{}{}) // Should still be pre-populated value - // === REFERRER SYNC CONTINUE PATH TEST === - // Step 7: Verify current state before referrer test - we should have 1 request initialReferrerCount := 0 onDemand.requestStore.Range(func(key, value interface{}) bool { diff --git a/pkg/extensions/sync/sync_test.go b/pkg/extensions/sync/sync_test.go index 997dbe4c..603cc657 100644 --- a/pkg/extensions/sync/sync_test.go +++ b/pkg/extensions/sync/sync_test.go @@ -814,11 +814,14 @@ func TestOnDemandWithScaleOutCluster(t *testing.T) { Registries: []syncconf.RegistryConfig{syncRegistryConfig}, } + // Get dynamic ports for cluster members + clusterPorts := test.GetFreePorts(2) + // cluster config for member 1. clusterCfgDownstream1 := config.ClusterConfig{ Members: []string{ - "127.0.0.1:43222", - "127.0.0.1:43223", + "127.0.0.1:" + clusterPorts[0], + "127.0.0.1:" + clusterPorts[1], }, HashKey: "loremipsumdolors", } @@ -827,11 +830,11 @@ func TestOnDemandWithScaleOutCluster(t *testing.T) { clusterCfgDownstream2 := clusterCfgDownstream1 dctrl1, dctrl1BaseURL, destDir1, dstClient1 := makeInsecureDownstreamServerFixedPort( - t, "43222", syncConfig, &clusterCfgDownstream1) + t, clusterPorts[0], syncConfig, &clusterCfgDownstream1) dctrl1Scm := test.NewControllerManager(dctrl1) dctrl2, dctrl2BaseURL, destDir2, dstClient2 := makeInsecureDownstreamServerFixedPort( - t, "43223", syncConfig, &clusterCfgDownstream2) + t, clusterPorts[1], syncConfig, &clusterCfgDownstream2) dctrl2Scm := test.NewControllerManager(dctrl2) dctrl1Scm.StartAndWait(dctrl1.Config.HTTP.Port) @@ -983,11 +986,14 @@ func TestOnDemandWithScaleOutClusterWithReposNotAddedForSync(t *testing.T) { Registries: []syncconf.RegistryConfig{syncRegistryConfig}, } + // Get dynamic ports for cluster members + clusterPorts := test.GetFreePorts(2) + // cluster config for member 1. clusterCfgDownstream1 := config.ClusterConfig{ Members: []string{ - "127.0.0.1:43222", - "127.0.0.1:43223", + "127.0.0.1:" + clusterPorts[0], + "127.0.0.1:" + clusterPorts[1], }, HashKey: "loremipsumdolors", } @@ -996,11 +1002,11 @@ func TestOnDemandWithScaleOutClusterWithReposNotAddedForSync(t *testing.T) { clusterCfgDownstream2 := clusterCfgDownstream1 dctrl1, dctrl1BaseURL, destDir1, dstClient1 := makeInsecureDownstreamServerFixedPort( - t, "43222", syncConfig, &clusterCfgDownstream1) + t, clusterPorts[0], syncConfig, &clusterCfgDownstream1) dctrl1Scm := test.NewControllerManager(dctrl1) dctrl2, dctrl2BaseURL, destDir2, dstClient2 := makeInsecureDownstreamServerFixedPort( - t, "43223", syncConfig, &clusterCfgDownstream2) + t, clusterPorts[1], syncConfig, &clusterCfgDownstream2) dctrl2Scm := test.NewControllerManager(dctrl2) dctrl1Scm.StartAndWait(dctrl1.Config.HTTP.Port) @@ -1878,15 +1884,20 @@ func TestPeriodicallyWithScaleOutCluster(t *testing.T) { // zot-test is managed by member index 1. // zot-cve-test is managed by member index 0. // zot-alpine-test is managed by member index 1. + + // Get dynamic ports for cluster members + clusterPorts := test.GetFreePorts(2) + clusterCfg := config.ClusterConfig{ Members: []string{ - "127.0.0.1:100", - "127.0.0.1:42000", + "127.0.0.1:" + clusterPorts[0], + "127.0.0.1:" + clusterPorts[1], }, HashKey: "loremipsumdolors", } - dctlr, destBaseURL, destDir, destClient := makeInsecureDownstreamServerFixedPort(t, "42000", syncConfig, &clusterCfg) + dctlr, destBaseURL, destDir, destClient := makeInsecureDownstreamServerFixedPort(t, + clusterPorts[1], syncConfig, &clusterCfg) dcm := test.NewControllerManager(dctlr) dcm.StartAndWait(dctlr.Config.HTTP.Port) diff --git a/pkg/scheduler/scheduler.go b/pkg/scheduler/scheduler.go index ea2ab9ed..08efc18d 100644 --- a/pkg/scheduler/scheduler.go +++ b/pkg/scheduler/scheduler.go @@ -549,8 +549,9 @@ func (scheduler *Scheduler) SubmitGenerator(taskGenerator TaskGenerator, interva } func getNumWorkers(cfg *config.Config) int { - if cfg.Scheduler != nil && cfg.Scheduler.NumWorkers != 0 { - return cfg.Scheduler.NumWorkers + schedulerConfig := cfg.CopySchedulerConfig() + if schedulerConfig != nil && schedulerConfig.NumWorkers != 0 { + return schedulerConfig.NumWorkers } return runtime.NumCPU() * NumWorkersMultiplier diff --git a/pkg/storage/gc/gc_test.go b/pkg/storage/gc/gc_test.go index 592a5fb5..bb2ddc85 100644 --- a/pkg/storage/gc/gc_test.go +++ b/pkg/storage/gc/gc_test.go @@ -63,6 +63,7 @@ func TestGarbageCollectAndRetentionMetaDB(t *testing.T) { audit := zlog.NewAuditLogger("debug", "/dev/null") metrics := monitoring.NewMetricsServer(false, log) + defer metrics.Stop() // Clean up metrics server to prevent resource leaks trueVal := true @@ -1317,6 +1318,7 @@ func TestGarbageCollectDeletion(t *testing.T) { audit := zlog.NewAuditLogger("debug", "/dev/null") metrics := monitoring.NewMetricsServer(false, log) + defer metrics.Stop() // Clean up metrics server to prevent resource leaks trueVal := true falseVal := false @@ -1755,6 +1757,7 @@ func TestGarbageCollectAndRetentionNoMetaDB(t *testing.T) { audit := zlog.NewAuditLogger("debug", "/dev/null") metrics := monitoring.NewMetricsServer(false, log) + defer metrics.Stop() // Clean up metrics server to prevent resource leaks trueVal := true diff --git a/pkg/test/common/utils.go b/pkg/test/common/utils.go index 507043d9..5f9ff390 100644 --- a/pkg/test/common/utils.go +++ b/pkg/test/common/utils.go @@ -1,14 +1,18 @@ package common import ( + "bytes" "errors" "fmt" + "io" "math/rand" "net/http" "net/url" "os" "path" "strconv" + "strings" + "sync" "time" "github.com/phayes/freeport" @@ -143,6 +147,23 @@ func GetFreePort() string { return strconv.Itoa(port) } +// GetFreePorts returns multiple unique free ports, useful for cluster tests. +func GetFreePorts(count int) []string { + // Use the freeport library's GetFreePorts function which guarantees uniqueness + intPorts, err := freeport.GetFreePorts(count) + if err != nil { + panic(err) + } + + // Convert to strings + ports := make([]string, count) + for i, port := range intPorts { + ports[i] = strconv.Itoa(port) + } + + return ports +} + func GetBaseURL(port string) string { return fmt.Sprintf(BaseURL, port) } @@ -233,3 +254,78 @@ func ContainSameElements[T comparable](list1, list2 []T) bool { return true } + +// ThreadSafeLogBuffer is a thread-safe wrapper around bytes.Buffer for concurrent log capture. +type ThreadSafeLogBuffer struct { + buffer *bytes.Buffer + mutex sync.RWMutex +} + +// NewThreadSafeLogBuffer creates a new thread-safe log buffer. +func NewThreadSafeLogBuffer() *ThreadSafeLogBuffer { + return &ThreadSafeLogBuffer{ + buffer: &bytes.Buffer{}, + } +} + +// Write implements io.Writer interface with thread safety. +func (tsb *ThreadSafeLogBuffer) Write(p []byte) (int, error) { + tsb.mutex.Lock() + defer tsb.mutex.Unlock() + + return tsb.buffer.Write(p) +} + +// String returns the buffer contents as a string with thread safety. +func (tsb *ThreadSafeLogBuffer) String() string { + tsb.mutex.RLock() + defer tsb.mutex.RUnlock() + + return tsb.buffer.String() +} + +// WaitForLogMessages waits for a specific number of log messages to appear in the log buffer +// within the given timeout. This is useful for verifying goroutine termination or other +// asynchronous operations that log specific messages. +// +// Parameters: +// - logBuffer: A ThreadSafeLogBuffer that captures log output +// - message: The log message to search for (e.g., "htpasswd watcher terminating...") +// - minCount: Minimum number of occurrences to wait for +// - timeout: Maximum time to wait for the messages +// +// Returns: +// - true if at least minCount messages were found within the timeout +// - false if the timeout was reached before finding enough messages +func WaitForLogMessages(logBuffer *ThreadSafeLogBuffer, message string, minCount int, timeout time.Duration) bool { + deadline := time.Now().Add(timeout) + for time.Now().Before(deadline) { + logOutput := logBuffer.String() + actualCount := strings.Count(logOutput, message) + + if actualCount >= minCount { + return true + } + + time.Sleep(10 * time.Millisecond) + } + + return false +} + +// CreateLogCapturingWriter creates a multi-writer that captures log output to a thread-safe buffer +// while also writing to the original writer (typically os.Stdout). This is useful for +// tests that need to programmatically verify log messages. +// +// Parameters: +// - originalWriter: The original writer to continue writing to (e.g., os.Stdout) +// +// Returns: +// - A ThreadSafeLogBuffer that captures the log output +// - An io.Writer that writes to both the original writer and the buffer +func CreateLogCapturingWriter(originalWriter io.Writer) (*ThreadSafeLogBuffer, io.Writer) { + logBuffer := NewThreadSafeLogBuffer() + multiWriter := io.MultiWriter(originalWriter, logBuffer) + + return logBuffer, multiWriter +} diff --git a/pkg/test/common/utils_test.go b/pkg/test/common/utils_test.go index 28df644f..f90558db 100644 --- a/pkg/test/common/utils_test.go +++ b/pkg/test/common/utils_test.go @@ -58,3 +58,143 @@ func TestControllerManager(t *testing.T) { So(func() { ctlrManager.RunServer() }, ShouldPanic) }) } + +func TestWaitForLogMessages(t *testing.T) { + Convey("Test WaitForLogMessages", t, func() { + Convey("should return true when message count reaches minimum", func() { + logBuffer := tcommon.NewThreadSafeLogBuffer() + + // Write some log messages + _, _ = logBuffer.Write([]byte("Starting server\n")) + _, _ = logBuffer.Write([]byte("Server started successfully\n")) + _, _ = logBuffer.Write([]byte("Starting server\n")) + _, _ = logBuffer.Write([]byte("Processing request\n")) + _, _ = logBuffer.Write([]byte("Starting server\n")) + + // Wait for "Starting server" message to appear at least 3 times + result := tcommon.WaitForLogMessages(logBuffer, "Starting server", 3, 100*time.Millisecond) + + So(result, ShouldBeTrue) + }) + + Convey("should return false when message count never reaches minimum", func() { + logBuffer := tcommon.NewThreadSafeLogBuffer() + + // Write some log messages (only 1 occurrence of target message) + _, _ = logBuffer.Write([]byte("Starting server\n")) + _, _ = logBuffer.Write([]byte("Server started successfully\n")) + _, _ = logBuffer.Write([]byte("Processing request\n")) + + // Wait for "Starting server" message to appear at least 3 times + result := tcommon.WaitForLogMessages(logBuffer, "Starting server", 3, 50*time.Millisecond) + + So(result, ShouldBeFalse) + }) + + Convey("should return true immediately when count already meets requirement", func() { + logBuffer := tcommon.NewThreadSafeLogBuffer() + + // Write messages before calling WaitForLogMessages + _, _ = logBuffer.Write([]byte("Starting server\n")) + _, _ = logBuffer.Write([]byte("Starting server\n")) + _, _ = logBuffer.Write([]byte("Starting server\n")) + + // Wait for "Starting server" message to appear at least 3 times + result := tcommon.WaitForLogMessages(logBuffer, "Starting server", 3, 100*time.Millisecond) + + So(result, ShouldBeTrue) + }) + + Convey("should handle empty log buffer", func() { + logBuffer := tcommon.NewThreadSafeLogBuffer() + + // Wait for any message in empty buffer + result := tcommon.WaitForLogMessages(logBuffer, "Starting server", 1, 50*time.Millisecond) + + So(result, ShouldBeFalse) + }) + + Convey("should handle partial message matches", func() { + logBuffer := tcommon.NewThreadSafeLogBuffer() + + // Write messages with partial matches + _, _ = logBuffer.Write([]byte("Starting server process\n")) + _, _ = logBuffer.Write([]byte("Starting server\n")) + _, _ = logBuffer.Write([]byte("Starting server instance\n")) + + // Wait for exact "Starting server" message (not partial matches) + result := tcommon.WaitForLogMessages(logBuffer, "Starting server", 2, 100*time.Millisecond) + + So(result, ShouldBeTrue) + }) + + Convey("should timeout after specified duration", func() { + logBuffer := tcommon.NewThreadSafeLogBuffer() + + // Write only one message + _, _ = logBuffer.Write([]byte("Starting server\n")) + + // Wait for 3 occurrences with short timeout + start := time.Now() + result := tcommon.WaitForLogMessages(logBuffer, "Starting server", 3, 10*time.Millisecond) + duration := time.Since(start) + + So(result, ShouldBeFalse) + So(duration, ShouldBeGreaterThanOrEqualTo, 10*time.Millisecond) + So(duration, ShouldBeLessThan, 50*time.Millisecond) // Should timeout quickly + }) + + Convey("should handle concurrent writes", func() { + logBuffer := tcommon.NewThreadSafeLogBuffer() + + // Simulate concurrent writes + go func() { + for i := 0; i < 5; i++ { + _, _ = logBuffer.Write([]byte("Starting server\n")) + + time.Sleep(5 * time.Millisecond) + } + }() + + // Wait for messages to appear + result := tcommon.WaitForLogMessages(logBuffer, "Starting server", 3, 100*time.Millisecond) + + So(result, ShouldBeTrue) + }) + + Convey("should handle case-sensitive message matching", func() { + logBuffer := tcommon.NewThreadSafeLogBuffer() + + // Write messages with different cases + _, _ = logBuffer.Write([]byte("Starting server\n")) + _, _ = logBuffer.Write([]byte("starting server\n")) + _, _ = logBuffer.Write([]byte("STARTING SERVER\n")) + + // Wait for exact case match + result := tcommon.WaitForLogMessages(logBuffer, "Starting server", 2, 100*time.Millisecond) + + So(result, ShouldBeFalse) // Only 1 exact match + }) + + Convey("should handle zero minimum count", func() { + logBuffer := tcommon.NewThreadSafeLogBuffer() + + // Wait for 0 occurrences (should always return true) + result := tcommon.WaitForLogMessages(logBuffer, "Starting server", 0, 100*time.Millisecond) + + So(result, ShouldBeTrue) + }) + + Convey("should handle very short timeout", func() { + logBuffer := tcommon.NewThreadSafeLogBuffer() + + // Write a message + _, _ = logBuffer.Write([]byte("Starting server\n")) + + // Wait with very short timeout + result := tcommon.WaitForLogMessages(logBuffer, "Starting server", 1, 1*time.Millisecond) + + So(result, ShouldBeTrue) // Should find it immediately + }) + }) +}