diff --git a/go.mod b/go.mod index 74e91f05..b4448533 100644 --- a/go.mod +++ b/go.mod @@ -69,6 +69,7 @@ require ( github.com/stretchr/testify v1.11.1 github.com/swaggo/http-swagger v1.3.4 github.com/swaggo/swag v1.16.6 + github.com/tiendc/go-deepcopy v1.7.1 github.com/vektah/gqlparser/v2 v2.5.30 github.com/zitadel/oidc/v3 v3.45.0 go.etcd.io/bbolt v1.4.3 diff --git a/go.sum b/go.sum index 14e82bef..026ec3b3 100644 --- a/go.sum +++ b/go.sum @@ -2063,6 +2063,8 @@ github.com/theupdateframework/go-tuf/v2 v2.1.1 h1:OWcoHItwsGO+7m0wLa7FDWPR4oB1cj github.com/theupdateframework/go-tuf/v2 v2.1.1/go.mod h1:V675cQGhZONR0OGQ8r1feO0uwtsTBYPDWHzAAPn5rjE= github.com/theupdateframework/notary v0.7.0 h1:QyagRZ7wlSpjT5N2qQAh/pN+DVqgekv4DzbAiAiEL3c= github.com/theupdateframework/notary v0.7.0/go.mod h1:c9DRxcmhHmVLDay4/2fUYdISnHqbFDGRSlXPO0AhYWw= +github.com/tiendc/go-deepcopy v1.7.1 h1:LnubftI6nYaaMOcaz0LphzwraqN8jiWTwm416sitff4= +github.com/tiendc/go-deepcopy v1.7.1/go.mod h1:4bKjNC2r7boYOkD2IOuZpYjmlDdzjbpTRyCx+goBCJQ= github.com/tink-crypto/tink-go-awskms/v2 v2.1.0 h1:N9UxlsOzu5mttdjhxkDLbzwtEecuXmlxZVo/ds7JKJI= github.com/tink-crypto/tink-go-awskms/v2 v2.1.0/go.mod h1:PxSp9GlOkKL9rlybW804uspnHuO9nbD98V/fDX4uSis= github.com/tink-crypto/tink-go-gcpkms/v2 v2.2.0 h1:3B9i6XBXNTRspfkTC0asN5W0K6GhOSgcujNiECNRNb0= diff --git a/pkg/api/authn.go b/pkg/api/authn.go index 7a6c7aed..cb1b3a84 100644 --- a/pkg/api/authn.go +++ b/pkg/api/authn.go @@ -55,7 +55,8 @@ func AuthHandler(ctlr *Controller) mux.MiddlewareFunc { log: ctlr.Log, } - if ctlr.Config.IsBearerAuthEnabled() { + authConfig := ctlr.Config.CopyAuthConfig() + if authConfig.IsBearerAuthEnabled() { return bearerAuthHandler(ctlr) } @@ -103,6 +104,12 @@ func (amw *AuthnMiddleware) basicAuthn(ctlr *Controller, userAc *reqCtx.UserAcce ) (bool, error) { cookieStore := ctlr.CookieStore + // Get auth config once to avoid multiple calls + authConfig := ctlr.Config.CopyAuthConfig() + if authConfig == nil { + return false, nil + } + identity, passphrase, err := getUsernamePasswordBasicAuth(request) if err != nil { ctlr.Log.Error().Err(err).Msg("failed to parse authorization header") @@ -116,7 +123,8 @@ func (amw *AuthnMiddleware) basicAuthn(ctlr *Controller, userAc *reqCtx.UserAcce // Process request var groups []string - if ctlr.Config.HTTP.AccessControl != nil { + accessControl := ctlr.Config.CopyAccessControlConfig() + if accessControl != nil { ac := NewAccessController(ctlr.Config) groups = ac.getUserGroups(identity) } @@ -145,13 +153,14 @@ func (amw *AuthnMiddleware) basicAuthn(ctlr *Controller, userAc *reqCtx.UserAcce } // next, LDAP if configured (network-based which can lose connectivity) - if ctlr.Config.HTTP.Auth != nil && ctlr.Config.HTTP.Auth.LDAP != nil { + if authConfig.IsLdapAuthEnabled() { ok, _, ldapgroups, err := amw.ldapClient.Authenticate(identity, passphrase) if ok && err == nil { // Process request var groups []string - if ctlr.Config.HTTP.AccessControl != nil { + accessControl := ctlr.Config.CopyAccessControlConfig() + if accessControl != nil { ac := NewAccessController(ctlr.Config) groups = ac.getUserGroups(identity) } @@ -181,7 +190,7 @@ func (amw *AuthnMiddleware) basicAuthn(ctlr *Controller, userAc *reqCtx.UserAcce } // last try API keys - if ctlr.Config.IsAPIKeyEnabled() { + if authConfig.IsAPIKeyEnabled() { apiKey := passphrase if !strings.HasPrefix(apiKey, constants.APIKeysPrefix) { @@ -248,16 +257,20 @@ func (amw *AuthnMiddleware) basicAuthn(ctlr *Controller, userAc *reqCtx.UserAcce } func (amw *AuthnMiddleware) tryAuthnHandlers(ctlr *Controller) mux.MiddlewareFunc { //nolint: gocyclo + // Get auth config once to avoid multiple calls + authConfig := ctlr.Config.CopyAuthConfig() + // no password based authN, if neither LDAP nor HTTP BASIC is enabled - if !ctlr.Config.IsBasicAuthnEnabled() { + if !authConfig.IsBasicAuthnEnabled() { return noPasswdAuth(ctlr) } - delay := ctlr.Config.HTTP.Auth.FailDelay + delay := authConfig.GetFailDelay() + realm := ctlr.Config.GetRealm() // ldap and htpasswd based authN - if ctlr.Config.IsLdapAuthEnabled() { - ldapConfig := ctlr.Config.HTTP.Auth.LDAP + if authConfig.IsLdapAuthEnabled() { + ldapConfig := authConfig.LDAP ctlr.LDAPClient = &LDAPClient{ Host: ldapConfig.Address, @@ -278,17 +291,17 @@ func (amw *AuthnMiddleware) tryAuthnHandlers(ctlr *Controller) mux.MiddlewareFun amw.ldapClient = ctlr.LDAPClient - if ctlr.Config.HTTP.Auth.LDAP.CACert != "" { - caCert, err := os.ReadFile(ctlr.Config.HTTP.Auth.LDAP.CACert) + if authConfig.LDAP.CACert != "" { + caCert, err := os.ReadFile(authConfig.LDAP.CACert) if err != nil { - amw.log.Panic().Err(err).Str("caCert", ctlr.Config.HTTP.Auth.LDAP.CACert). + amw.log.Panic().Err(err).Str("caCert", authConfig.LDAP.CACert). Msg("failed to read caCert") } caCertPool := x509.NewCertPool() if !caCertPool.AppendCertsFromPEM(caCert) { - amw.log.Panic().Err(zerr.ErrBadCACert).Str("caCert", ctlr.Config.HTTP.Auth.LDAP.CACert). + amw.log.Panic().Err(zerr.ErrBadCACert).Str("caCert", authConfig.LDAP.CACert). Msg("failed to read caCert") } @@ -297,7 +310,7 @@ func (amw *AuthnMiddleware) tryAuthnHandlers(ctlr *Controller) mux.MiddlewareFun // default to system cert pool caCertPool, err := x509.SystemCertPool() if err != nil { - amw.log.Panic().Err(zerr.ErrBadCACert).Str("caCert", ctlr.Config.HTTP.Auth.LDAP.CACert). + amw.log.Panic().Err(zerr.ErrBadCACert).Str("caCert", authConfig.LDAP.CACert). Msg("failed to get system cert pool") } @@ -305,26 +318,26 @@ func (amw *AuthnMiddleware) tryAuthnHandlers(ctlr *Controller) mux.MiddlewareFun } } - if ctlr.Config.IsHtpasswdAuthEnabled() { - err := amw.htpasswd.Reload(ctlr.Config.HTTP.Auth.HTPasswd.Path) + if authConfig.IsHtpasswdAuthEnabled() { + err := amw.htpasswd.Reload(authConfig.HTPasswd.Path) if err != nil { - amw.log.Panic().Err(err).Str("credsFile", ctlr.Config.HTTP.Auth.HTPasswd.Path). + amw.log.Panic().Err(err).Str("credsFile", authConfig.HTPasswd.Path). Msg("failed to open creds-file") } } // openid based authN - if ctlr.Config.IsOpenIDAuthEnabled() { + if authConfig.IsOpenIDAuthEnabled() { ctlr.RelyingParties = make(map[string]rp.RelyingParty) - for provider := range ctlr.Config.HTTP.Auth.OpenID.Providers { + for provider := range authConfig.OpenID.Providers { if config.IsOpenIDSupported(provider) { - rp := NewRelyingPartyOIDC(context.TODO(), ctlr.Config, provider, ctlr.Config.HTTP.Auth.SessionHashKey, - ctlr.Config.HTTP.Auth.SessionEncryptKey, ctlr.Log) + rp := NewRelyingPartyOIDC(context.TODO(), ctlr.Config, provider, authConfig.SessionHashKey, + authConfig.SessionEncryptKey, ctlr.Log) ctlr.RelyingParties[provider] = rp } else if config.IsOauth2Supported(provider) { - rp := NewRelyingPartyGithub(ctlr.Config, provider, ctlr.Config.HTTP.Auth.SessionHashKey, - ctlr.Config.HTTP.Auth.SessionEncryptKey, ctlr.Log) + rp := NewRelyingPartyGithub(ctlr.Config, provider, authConfig.SessionHashKey, + authConfig.SessionEncryptKey, ctlr.Log) ctlr.RelyingParties[provider] = rp } } @@ -340,7 +353,10 @@ func (amw *AuthnMiddleware) tryAuthnHandlers(ctlr *Controller) mux.MiddlewareFun } isMgmtRequested := request.RequestURI == constants.FullMgmt - allowAnonymous := ctlr.Config.HTTP.AccessControl.AnonymousPolicyExists() + + // Get access control config safely + accessControlConfig := ctlr.Config.CopyAccessControlConfig() + allowAnonymous := accessControlConfig != nil && accessControlConfig.AnonymousPolicyExists() // build user access control info userAc := reqCtx.NewUserAccessControl() @@ -370,7 +386,7 @@ func (amw *AuthnMiddleware) tryAuthnHandlers(ctlr *Controller) mux.MiddlewareFun if errors.Is(err, zerr.ErrUserDataNotFound) { ctlr.Log.Err(err).Msg("failed to find user profile in DB") - authFail(response, request, ctlr.Config.HTTP.Realm, delay) + authFail(response, request, realm, delay) } response.WriteHeader(http.StatusInternalServerError) @@ -397,22 +413,25 @@ func (amw *AuthnMiddleware) tryAuthnHandlers(ctlr *Controller) mux.MiddlewareFun return } - authFail(response, request, ctlr.Config.HTTP.Realm, delay) + authFail(response, request, realm, delay) }) } } func bearerAuthHandler(ctlr *Controller) mux.MiddlewareFunc { + // Get auth config safely + authConfig := ctlr.Config.CopyAuthConfig() + // although the configuration option is called 'cert', this function will also parse a public key directly // see https://github.com/project-zot/zot/issues/3173 for info - publicKey, err := loadPublicKeyFromFile(ctlr.Config.HTTP.Auth.Bearer.Cert) + publicKey, err := loadPublicKeyFromFile(authConfig.Bearer.Cert) if err != nil { ctlr.Log.Panic().Err(err).Msg("failed to load public key for bearer authentication") } authorizer := NewBearerAuthorizer( - ctlr.Config.HTTP.Auth.Bearer.Realm, - ctlr.Config.HTTP.Auth.Bearer.Service, + authConfig.Bearer.Realm, + authConfig.Bearer.Service, publicKey, ) @@ -509,7 +528,11 @@ func noPasswdAuth(ctlr *Controller) mux.MiddlewareFunc { } if ctlr.Config.IsMTLSAuthEnabled() && userAc.IsAnonymous() { - authFail(response, request, ctlr.Config.HTTP.Realm, ctlr.Config.HTTP.Auth.FailDelay) + authConfig := ctlr.Config.CopyAuthConfig() + failDelay := authConfig.GetFailDelay() + realm := ctlr.Config.GetRealm() + + authFail(response, request, realm, failDelay) return } diff --git a/pkg/api/authz.go b/pkg/api/authz.go index 81292664..0c6475fa 100644 --- a/pkg/api/authz.go +++ b/pkg/api/authz.go @@ -42,16 +42,20 @@ type AccessController struct { } func NewAccessController(conf *config.Config) *AccessController { - if conf.HTTP.AccessControl == nil { + // Get access control config safely + accessControlConfig := conf.CopyAccessControlConfig() + logConfig := conf.CopyLogConfig() + + if accessControlConfig == nil { return &AccessController{ Config: &config.AccessControlConfig{}, - Log: log.NewLogger(conf.Log.Level, conf.Log.Output), + Log: log.NewLogger(logConfig.Level, logConfig.Output), } } return &AccessController{ - Config: conf.HTTP.AccessControl, - Log: log.NewLogger(conf.Log.Level, conf.Log.Output), + Config: accessControlConfig, + Log: log.NewLogger(logConfig.Level, logConfig.Output), } } @@ -117,14 +121,17 @@ func (ac *AccessController) can(userAc *reqCtx.UserAccessControl, action, reposi username := userAc.GetUsername() // check matched repo based policy - pg, ok := ac.Config.Repositories[longestMatchedPattern] + repositories := ac.Config.GetRepositories() + pg, ok := repositories[longestMatchedPattern] + if ok { can = ac.isPermitted(userGroups, username, action, pg) } // check admins based policy if !can { - if ac.isAdmin(username, userGroups) && common.Contains(ac.Config.AdminPolicy.Actions, action) { + adminPolicy := ac.Config.GetAdminPolicy() + if ac.isAdmin(username, userGroups) && common.Contains(adminPolicy.Actions, action) { can = true } } @@ -134,7 +141,8 @@ func (ac *AccessController) can(userAc *reqCtx.UserAccessControl, action, reposi // isAdmin . func (ac *AccessController) isAdmin(username string, userGroups []string) bool { - if common.Contains(ac.Config.AdminPolicy.Users, username) || ac.isAnyGroupInAdminPolicy(userGroups) { + adminPolicy := ac.Config.GetAdminPolicy() + if common.Contains(adminPolicy.Users, username) || ac.isAnyGroupInAdminPolicy(userGroups) { return true } @@ -142,8 +150,9 @@ func (ac *AccessController) isAdmin(username string, userGroups []string) bool { } func (ac *AccessController) isAnyGroupInAdminPolicy(userGroups []string) bool { + adminPolicy := ac.Config.GetAdminPolicy() for _, group := range userGroups { - if common.Contains(ac.Config.AdminPolicy.Groups, group) { + if common.Contains(adminPolicy.Groups, group) { return true } } @@ -154,7 +163,8 @@ func (ac *AccessController) isAnyGroupInAdminPolicy(userGroups []string) bool { func (ac *AccessController) getUserGroups(username string) []string { var groupNames []string - for groupName, group := range ac.Config.Groups { + groups := ac.Config.GetGroups() + for groupName, group := range groups { for _, user := range group.Users { // find if the user is part of any groups if user == username { @@ -241,6 +251,11 @@ func (ac *AccessController) isPermitted(userGroups []string, username, action st func BaseAuthzHandler(ctlr *Controller) mux.MiddlewareFunc { return func(next http.Handler) http.Handler { return http.HandlerFunc(func(response http.ResponseWriter, request *http.Request) { + // Get configs safely + authConfig := ctlr.Config.CopyAuthConfig() + realm := ctlr.Config.GetRealm() + failDelay := authConfig.GetFailDelay() + /* NOTE: since we only do READ actions in extensions, this middleware is enough for them because it populates the context with user relevant data to be processed by each individual extension @@ -271,7 +286,7 @@ func BaseAuthzHandler(ctlr *Controller) mux.MiddlewareFunc { // get access control context made in authn.go userAc, err := reqCtx.UserAcFromContext(request.Context()) if err != nil { // should never happen - authFail(response, request, ctlr.Config.HTTP.Realm, ctlr.Config.HTTP.Auth.FailDelay) + authFail(response, request, realm, failDelay) return } @@ -287,6 +302,11 @@ func BaseAuthzHandler(ctlr *Controller) mux.MiddlewareFunc { func DistSpecAuthzHandler(ctlr *Controller) mux.MiddlewareFunc { return func(next http.Handler) http.Handler { return http.HandlerFunc(func(response http.ResponseWriter, request *http.Request) { + // Get configs safely + authConfig := ctlr.Config.CopyAuthConfig() + realm := ctlr.Config.GetRealm() + failDelay := authConfig.GetFailDelay() + if request.Method == http.MethodOptions { next.ServeHTTP(response, request) @@ -310,7 +330,7 @@ func DistSpecAuthzHandler(ctlr *Controller) mux.MiddlewareFunc { // get userAc built in authn and previous authz middlewares userAc, err := reqCtx.UserAcFromContext(request.Context()) if err != nil { // should never happen - authFail(response, request, ctlr.Config.HTTP.Realm, ctlr.Config.HTTP.Auth.FailDelay) + authFail(response, request, realm, failDelay) return } @@ -341,7 +361,7 @@ func DistSpecAuthzHandler(ctlr *Controller) mux.MiddlewareFunc { can := acCtrlr.can(userAc, action, resource) //nolint:contextcheck if !can { - common.AuthzFail(response, request, userAc.GetUsername(), ctlr.Config.HTTP.Realm, ctlr.Config.HTTP.Auth.FailDelay) + common.AuthzFail(response, request, userAc.GetUsername(), realm, failDelay) } else { next.ServeHTTP(response, request) //nolint:contextcheck } @@ -352,17 +372,25 @@ func DistSpecAuthzHandler(ctlr *Controller) mux.MiddlewareFunc { func MetricsAuthzHandler(ctlr *Controller) mux.MiddlewareFunc { return func(next http.Handler) http.Handler { return http.HandlerFunc(func(response http.ResponseWriter, request *http.Request) { - if ctlr.Config.HTTP.AccessControl == nil { + // Get configs safely + authConfig := ctlr.Config.CopyAuthConfig() + realm := ctlr.Config.GetRealm() + failDelay := authConfig.GetFailDelay() + + accessControlConfig := ctlr.Config.CopyAccessControlConfig() + + if accessControlConfig == nil { // allow access to authenticated user as anonymous policy does not exist next.ServeHTTP(response, request) return } - if len(ctlr.Config.HTTP.AccessControl.Metrics.Users) == 0 { + metricsConfig := accessControlConfig.GetMetrics() + if len(metricsConfig.Users) == 0 { log := ctlr.Log log.Warn().Msg("auth is enabled but no metrics users in accessControl: /metrics is unaccesible") - common.AuthzFail(response, request, "", ctlr.Config.HTTP.Realm, ctlr.Config.HTTP.Auth.FailDelay) + common.AuthzFail(response, request, "", realm, failDelay) return } @@ -370,14 +398,14 @@ func MetricsAuthzHandler(ctlr *Controller) mux.MiddlewareFunc { // get access control context made in authn.go userAc, err := reqCtx.UserAcFromContext(request.Context()) if err != nil { // should never happen - common.AuthzFail(response, request, "", ctlr.Config.HTTP.Realm, ctlr.Config.HTTP.Auth.FailDelay) + common.AuthzFail(response, request, "", realm, failDelay) return } username := userAc.GetUsername() - if !common.Contains(ctlr.Config.HTTP.AccessControl.Metrics.Users, username) { - common.AuthzFail(response, request, username, ctlr.Config.HTTP.Realm, ctlr.Config.HTTP.Auth.FailDelay) + if !common.Contains(metricsConfig.Users, username) { + common.AuthzFail(response, request, username, realm, failDelay) return } diff --git a/pkg/api/config/config.go b/pkg/api/config/config.go index c5c9f360..9bb9a469 100644 --- a/pkg/api/config/config.go +++ b/pkg/api/config/config.go @@ -3,9 +3,11 @@ package config import ( "encoding/json" "os" + "sync" "time" distspec "github.com/opencontainers/distribution-spec/specs-go" + "github.com/tiendc/go-deepcopy" "zotregistry.dev/zot/v2/pkg/compat" extconf "zotregistry.dev/zot/v2/pkg/extensions/config" @@ -80,6 +82,59 @@ type AuthConfig struct { SessionDriver map[string]any `mapstructure:",omitempty"` } +// IsLdapAuthEnabled checks if LDAP authentication is enabled in this auth config. +func (a *AuthConfig) IsLdapAuthEnabled() bool { + return a != nil && a.LDAP != nil +} + +// IsHtpasswdAuthEnabled checks if HTPasswd authentication is enabled in this auth config. +func (a *AuthConfig) IsHtpasswdAuthEnabled() bool { + return a != nil && a.HTPasswd.Path != "" +} + +// IsBearerAuthEnabled checks if Bearer authentication is enabled in this auth config. +func (a *AuthConfig) IsBearerAuthEnabled() bool { + return a != nil && a.Bearer != nil && a.Bearer.Cert != "" && a.Bearer.Realm != "" && a.Bearer.Service != "" +} + +// IsOpenIDAuthEnabled checks if OpenID authentication is enabled in this auth config. +func (a *AuthConfig) IsOpenIDAuthEnabled() bool { + if a == nil || a.OpenID == nil { + return false + } + + for provider := range a.OpenID.Providers { + if IsOpenIDSupported(provider) || IsOauth2Supported(provider) { + return true + } + } + + return false +} + +// IsAPIKeyEnabled checks if API Key authentication is enabled in this auth config. +func (a *AuthConfig) IsAPIKeyEnabled() bool { + return a != nil && a.APIKey +} + +// IsBasicAuthnEnabled checks if any basic authentication method is enabled in this auth config. +func (a *AuthConfig) IsBasicAuthnEnabled() bool { + if a == nil { + return false + } + + return a.IsHtpasswdAuthEnabled() || a.IsLdapAuthEnabled() || a.IsOpenIDAuthEnabled() || a.IsAPIKeyEnabled() +} + +// GetFailDelay returns the configured fail delay for authentication attempts. +func (a *AuthConfig) GetFailDelay() int { + if a == nil { + return 0 + } + + return a.FailDelay +} + type BearerConfig struct { Realm string Service string @@ -156,6 +211,11 @@ type ClusterConfig struct { Proxy *ClusterRequestProxyConfig `json:"-" mapstructure:"-"` } +// IsClustered returns true if the cluster configuration represents a multi-node cluster. +func (c *ClusterConfig) IsClustered() bool { + return c != nil && len(c.Members) > 1 +} + type ClusterRequestProxyConfig struct { // holds the cluster socket (IP:port) derived from the host's // interface configuration and the listening port of the HTTP server. @@ -187,20 +247,34 @@ type LDAPConfig struct { } func (ldapConf *LDAPConfig) BindDN() string { + if ldapConf == nil { + return "" + } + return ldapConf.bindDN } func (ldapConf *LDAPConfig) SetBindDN(bindDN string) *LDAPConfig { + if ldapConf == nil { + return nil + } ldapConf.bindDN = bindDN return ldapConf } func (ldapConf *LDAPConfig) BindPassword() string { + if ldapConf == nil { + return "" + } + return ldapConf.bindPassword } func (ldapConf *LDAPConfig) SetBindPassword(bindPassword string) *LDAPConfig { + if ldapConf == nil { + return nil + } ldapConf.bindPassword = bindPassword return ldapConf @@ -224,6 +298,11 @@ type AccessControlConfig struct { Metrics Metrics } +// IsAuthzEnabled checks if authorization is enabled (access control is configured). +func (config *AccessControlConfig) IsAuthzEnabled() bool { + return config != nil +} + func (config *AccessControlConfig) AnonymousPolicyExists() bool { if config == nil { return false @@ -238,6 +317,89 @@ func (config *AccessControlConfig) AnonymousPolicyExists() bool { return false } +// ContainsOnlyAnonymousPolicy checks if the access control configuration contains only anonymous policies. +func (config *AccessControlConfig) ContainsOnlyAnonymousPolicy() bool { + if config == nil { + return true + } + + // Check if admin policy has any actions or users + if len(config.AdminPolicy.Actions)+len(config.AdminPolicy.Users) > 0 { + return false + } + + anonymousPolicyPresent := false + + for _, repository := range config.Repositories { + // Check if repository has default policy + if len(repository.DefaultPolicy) > 0 { + return false + } + + // Check if repository has anonymous policy + if len(repository.AnonymousPolicy) > 0 { + anonymousPolicyPresent = true + } + + // Check if repository has any non-empty policies + for _, policy := range repository.Policies { + if len(policy.Actions)+len(policy.Users) > 0 { + return false + } + } + } + + return anonymousPolicyPresent +} + +// GetRepositories safely gets a copy of the repositories configuration. +func (config *AccessControlConfig) GetRepositories() Repositories { + if config == nil { + return nil + } + + // Return a copy to avoid race conditions + reposCopy := make(Repositories) + for k, v := range config.Repositories { + reposCopy[k] = v + } + + return reposCopy +} + +// GetAdminPolicy safely gets a copy of the admin policy. +func (config *AccessControlConfig) GetAdminPolicy() Policy { + if config == nil { + return Policy{} + } + + return config.AdminPolicy +} + +// GetMetrics safely gets a copy of the metrics configuration. +func (config *AccessControlConfig) GetMetrics() Metrics { + if config == nil { + return Metrics{} + } + + return config.Metrics +} + +// GetGroups safely gets a copy of the groups configuration. +func (config *AccessControlConfig) GetGroups() Groups { + if config == nil { + return nil + } + + // Return a copy to avoid race conditions + groupsCopy := make(Groups) + for k, v := range config.Groups { + groupsCopy[k] = v + } + + return groupsCopy +} + type ( Repositories map[string]PolicyGroup Groups map[string]Group @@ -275,6 +437,9 @@ type Config struct { Extensions *extconf.ExtensionConfig Scheduler *SchedulerConfig `json:"scheduler" mapstructure:",omitempty"` Cluster *ClusterConfig `json:"cluster" mapstructure:",omitempty"` + + // Mutex to protect concurrent access to config fields + mu sync.RWMutex } func New() *Config { @@ -303,35 +468,107 @@ func (expConfig StorageConfig) ParamsEqual(actConfig StorageConfig) bool { expConfig.GCDelay == actConfig.GCDelay && expConfig.GCInterval == actConfig.GCInterval } -// SameFile compare two files. -// This method will first do the stat of two file and compare using os.SameFile method. -func SameFile(str1, str2 string) (bool, error) { - sFile, err := os.Stat(str1) - if err != nil { - return false, err +// isRetentionEnabledInternal checks if retention is enabled without acquiring a lock (internal use only). +func (c *Config) isRetentionEnabledInternal() bool { + if c == nil { + return false } - tFile, err := os.Stat(str2) - if err != nil { - return false, err + var needsMetaDB bool + + for _, retentionPolicy := range c.Storage.Retention.Policies { + for _, tagRetentionPolicy := range retentionPolicy.KeepTags { + if c.isTagsRetentionEnabled(tagRetentionPolicy) { + needsMetaDB = true + } + } } - return os.SameFile(sFile, tFile), nil + for _, subpath := range c.Storage.SubPaths { + for _, retentionPolicy := range subpath.Retention.Policies { + for _, tagRetentionPolicy := range retentionPolicy.KeepTags { + if c.isTagsRetentionEnabled(tagRetentionPolicy) { + needsMetaDB = true + } + } + } + } + + return needsMetaDB } -func DeepCopy(src, dst interface{}) error { - bytes, err := json.Marshal(src) - if err != nil { - return err +// isTagsRetentionEnabled checks if tags retention is enabled for a specific policy (internal use only). +func (c *Config) isTagsRetentionEnabled(tagRetentionPolicy KeepTagsPolicy) bool { + if tagRetentionPolicy.MostRecentlyPulledCount != 0 || + tagRetentionPolicy.MostRecentlyPushedCount != 0 || + tagRetentionPolicy.PulledWithin != nil || + tagRetentionPolicy.PushedWithin != nil { + return true } - err = json.Unmarshal(bytes, dst) + return false +} - return err +// isBasicAuthnEnabled checks if any basic authentication method is enabled (internal, no locking). +func (c *Config) isBasicAuthnEnabled() bool { + if c == nil { + return false + } + + // Check HTPasswd + if c.HTTP.Auth != nil && c.HTTP.Auth.HTPasswd.Path != "" { + return true + } + + // Check LDAP + if c.HTTP.Auth != nil && c.HTTP.Auth.LDAP != nil { + return true + } + + // Check API Key + if c.HTTP.Auth != nil && c.HTTP.Auth.APIKey { + return true + } + + // Check OpenID + if c.HTTP.Auth != nil && c.HTTP.Auth.OpenID != nil { + for provider := range c.HTTP.Auth.OpenID.Providers { + if isOpenIDAuthProviderEnabled(c, provider) { + return true + } + } + } + + return false +} + +// isOpenIDAuthProviderEnabled checks if a specific OpenID provider is enabled (internal use only). +func isOpenIDAuthProviderEnabled(config *Config, provider string) bool { + if providerConfig, ok := config.HTTP.Auth.OpenID.Providers[provider]; ok { + if IsOpenIDSupported(provider) { + if providerConfig.ClientID != "" || providerConfig.Issuer != "" || + len(providerConfig.Scopes) > 0 { + return true + } + } else if IsOauth2Supported(provider) { + if providerConfig.ClientID != "" || len(providerConfig.Scopes) > 0 { + return true + } + } + } + + return false } // Sanitize makes a sanitized copy of the config removing any secrets. func (c *Config) Sanitize() *Config { + if c == nil { + return nil + } + + c.mu.RLock() + defer c.mu.RUnlock() + sanitizedConfig := &Config{} if err := DeepCopy(c, sanitizedConfig); err != nil { @@ -370,7 +607,7 @@ func (c *Config) Sanitize() *Config { } } - if c.IsEventRecorderEnabled() { + if c.Extensions.IsEventRecorderEnabled() { for i, sink := range c.Extensions.Events.Sinks { if sink.Credentials == nil { continue @@ -387,24 +624,367 @@ func (c *Config) Sanitize() *Config { return sanitizedConfig } -func (c *Config) IsLdapAuthEnabled() bool { - if c.HTTP.Auth != nil && c.HTTP.Auth.LDAP != nil { - return true +// UpdateReloadableConfig updates only the fields that can be reloaded at runtime. +func (c *Config) UpdateReloadableConfig(newConfig *Config) { + if c == nil { + return } - return false + c.mu.Lock() + defer c.mu.Unlock() + + // Update storage configuration + c.Storage.GC = newConfig.Storage.GC + c.Storage.Dedupe = newConfig.Storage.Dedupe + c.Storage.GCDelay = newConfig.Storage.GCDelay + c.Storage.GCInterval = newConfig.Storage.GCInterval + + // Only update retention if we have a metaDB already in place + if c.isRetentionEnabledInternal() { + c.Storage.Retention = newConfig.Storage.Retention + } + + // Update storage subpaths configuration + for subPath, storageConfig := range newConfig.Storage.SubPaths { + subPathConfig, ok := c.Storage.SubPaths[subPath] + if !ok { + continue + } + + subPathConfig.GC = storageConfig.GC + subPathConfig.Dedupe = storageConfig.Dedupe + subPathConfig.GCDelay = storageConfig.GCDelay + subPathConfig.GCInterval = storageConfig.GCInterval + + // Only update retention if we have a metaDB already in place + if c.isRetentionEnabledInternal() { + subPathConfig.Retention = storageConfig.Retention + } + + c.Storage.SubPaths[subPath] = subPathConfig + } + + // Update authentication configuration + if c.HTTP.Auth != nil && newConfig.HTTP.Auth != nil { + c.HTTP.Auth.HTPasswd = newConfig.HTTP.Auth.HTPasswd + c.HTTP.Auth.LDAP = newConfig.HTTP.Auth.LDAP + c.HTTP.Auth.APIKey = newConfig.HTTP.Auth.APIKey + c.HTTP.Auth.OpenID = newConfig.HTTP.Auth.OpenID + } + + // Initialize and update AccessControlConfig + if newConfig.HTTP.AccessControl != nil && c.HTTP.AccessControl == nil { + c.HTTP.AccessControl = &AccessControlConfig{} + } + + if newConfig.HTTP.AccessControl == nil { + c.HTTP.AccessControl = nil + } else { + // Update AccessControlConfig fields directly + c.HTTP.AccessControl.Repositories = newConfig.HTTP.AccessControl.Repositories + c.HTTP.AccessControl.AdminPolicy = newConfig.HTTP.AccessControl.AdminPolicy + c.HTTP.AccessControl.Metrics = newConfig.HTTP.AccessControl.Metrics + c.HTTP.AccessControl.Groups = newConfig.HTTP.AccessControl.Groups + } + + // Initialize and update ExtensionConfig + if newConfig.Extensions != nil && c.Extensions == nil { + c.Extensions = &extconf.ExtensionConfig{} + } + + if newConfig.Extensions == nil { + c.Extensions = nil + } else if c.Extensions != nil { + // Update sync extension + c.Extensions.Sync = newConfig.Extensions.Sync + + // Update search extension + if newConfig.Extensions.Search != nil && newConfig.Extensions.Search.CVE != nil { + // Only update if search is enabled + if c.Extensions.IsSearchEnabled() { + if c.Extensions.Search != nil { + c.Extensions.Search.CVE = newConfig.Extensions.Search.CVE + } + } + } else { + // Remove search CVE config if not present in new config + if c.Extensions.Search != nil { + c.Extensions.Search.CVE = nil + } + } + + // Update scrub extension + c.Extensions.Scrub = newConfig.Extensions.Scrub + } } -func (c *Config) IsAuthzEnabled() bool { - return c.HTTP.AccessControl != nil +// CopyAuthConfig returns a copy of the auth config if it exists. +func (c *Config) CopyAuthConfig() *AuthConfig { + if c == nil { + return nil + } + + c.mu.RLock() + defer c.mu.RUnlock() + + if c.HTTP.Auth == nil { + return nil + } + + // Return a deep copy using tiendc/go-deepcopy to avoid race conditions + authCopy := &AuthConfig{} + _ = deepcopy.Copy(authCopy, c.HTTP.Auth) + + return authCopy } +// CopyAccessControlConfig returns a copy of the access control config if it exists. +func (c *Config) CopyAccessControlConfig() *AccessControlConfig { + if c == nil { + return nil + } + + c.mu.RLock() + defer c.mu.RUnlock() + + if c.HTTP.AccessControl == nil { + return nil + } + + // Return a deep copy using tiendc/go-deepcopy to avoid race conditions + accessControlCopy := &AccessControlConfig{} + _ = deepcopy.Copy(accessControlCopy, c.HTTP.AccessControl) + + return accessControlCopy +} + +// CopyStorageConfig returns a copy of the storage config. +func (c *Config) CopyStorageConfig() GlobalStorageConfig { + if c == nil { + return GlobalStorageConfig{} + } + + c.mu.RLock() + defer c.mu.RUnlock() + + // Return a deep copy using tiendc/go-deepcopy to avoid race conditions + storageCopy := GlobalStorageConfig{} + _ = deepcopy.Copy(&storageCopy, &c.Storage) + + return storageCopy +} + +// CopyExtensionsConfig returns a copy of the extensions config if it exists. +func (c *Config) CopyExtensionsConfig() *extconf.ExtensionConfig { + if c == nil { + return nil + } + + c.mu.RLock() + defer c.mu.RUnlock() + + if c.Extensions == nil { + return nil + } + + // Return a deep copy using tiendc/go-deepcopy to avoid race conditions + extensionsCopy := &extconf.ExtensionConfig{} + _ = deepcopy.Copy(extensionsCopy, c.Extensions) + + return extensionsCopy +} + +// CopyLogConfig returns a copy of the log config if it exists. +func (c *Config) CopyLogConfig() *LogConfig { + if c == nil { + return nil + } + + c.mu.RLock() + defer c.mu.RUnlock() + + if c.Log == nil { + return nil + } + + // Return a copy to avoid race conditions + logCopy := *c.Log + + return &logCopy +} + +// CopyClusterConfig returns a copy of the cluster config if it exists. +func (c *Config) CopyClusterConfig() *ClusterConfig { + if c == nil { + return nil + } + + c.mu.RLock() + defer c.mu.RUnlock() + + if c.Cluster == nil { + return nil + } + + // Return a deep copy using tiendc/go-deepcopy to avoid race conditions + clusterCopy := &ClusterConfig{} + _ = deepcopy.Copy(clusterCopy, c.Cluster) + + return clusterCopy +} + +// CopySchedulerConfig returns a copy of the scheduler config if it exists. +func (c *Config) CopySchedulerConfig() *SchedulerConfig { + if c == nil { + return nil + } + + c.mu.RLock() + defer c.mu.RUnlock() + + if c.Scheduler == nil { + return nil + } + + // Return a copy to avoid race conditions + schedulerCopy := *c.Scheduler + + return &schedulerCopy +} + +// CopyTLSConfig returns a copy of the TLS config. +func (c *Config) CopyTLSConfig() *TLSConfig { + if c == nil { + return nil + } + + c.mu.RLock() + defer c.mu.RUnlock() + + if c.HTTP.TLS == nil { + return nil + } + + // Return a copy to avoid race conditions + tlsCopy := *c.HTTP.TLS + + return &tlsCopy +} + +// CopyRatelimit returns a copy of the rate limit config. +func (c *Config) CopyRatelimit() *RatelimitConfig { + if c == nil { + return nil + } + + c.mu.RLock() + defer c.mu.RUnlock() + + if c.HTTP.Ratelimit == nil { + return nil + } + + // Return a deep copy using tiendc/go-deepcopy to avoid race conditions + ratelimitCopy := &RatelimitConfig{} + _ = deepcopy.Copy(ratelimitCopy, c.HTTP.Ratelimit) + + return ratelimitCopy +} + +// GetVersionInfo returns version information (read-only, safe to access directly). +func (c *Config) GetVersionInfo() (string, string, string, string) { + if c == nil { + return "", "", "", "" + } + + c.mu.RLock() + defer c.mu.RUnlock() + + return c.Commit, c.BinaryType, c.GoVersion, c.DistSpecVersion +} + +// GetRealm returns the HTTP realm value. +func (c *Config) GetRealm() string { + if c == nil { + return "" + } + + c.mu.RLock() + defer c.mu.RUnlock() + + return c.HTTP.Realm +} + +// GetCompat returns a copy of the compatibility config. +func (c *Config) GetCompat() []compat.MediaCompatibility { + if c == nil { + return nil + } + + c.mu.RLock() + defer c.mu.RUnlock() + + if c.HTTP.Compat == nil { + return nil + } + + // Return a copy to avoid race conditions + compatCopy := make([]compat.MediaCompatibility, len(c.HTTP.Compat)) + copy(compatCopy, c.HTTP.Compat) + + return compatCopy +} + +// GetHTTPAddress returns the HTTP address. +func (c *Config) GetHTTPAddress() string { + if c == nil { + return "" + } + + c.mu.RLock() + defer c.mu.RUnlock() + + return c.HTTP.Address +} + +// GetHTTPPort returns the HTTP port. +func (c *Config) GetHTTPPort() string { + if c == nil { + return "" + } + + c.mu.RLock() + defer c.mu.RUnlock() + + return c.HTTP.Port +} + +// GetAllowOrigin returns the CORS allow origin configuration. +func (c *Config) GetAllowOrigin() string { + if c == nil { + return "" + } + + c.mu.RLock() + defer c.mu.RUnlock() + + return c.HTTP.AllowOrigin +} + +// IsMTLSAuthEnabled checks if mTLS authentication is enabled. func (c *Config) IsMTLSAuthEnabled() bool { + if c == nil { + return false + } + + c.mu.RLock() + defer c.mu.RUnlock() + if c.HTTP.TLS != nil && c.HTTP.TLS.Key != "" && c.HTTP.TLS.Cert != "" && c.HTTP.TLS.CACert != "" && - !c.IsBasicAuthnEnabled() && + !c.isBasicAuthnEnabled() && !c.HTTP.AccessControl.AnonymousPolicyExists() { return true } @@ -412,157 +992,31 @@ func (c *Config) IsMTLSAuthEnabled() bool { return false } -func (c *Config) IsHtpasswdAuthEnabled() bool { - if c.HTTP.Auth != nil && c.HTTP.Auth.HTPasswd.Path != "" { - return true - } - - return false -} - -func (c *Config) IsBearerAuthEnabled() bool { - if c.HTTP.Auth != nil && - c.HTTP.Auth.Bearer != nil && - c.HTTP.Auth.Bearer.Cert != "" && - c.HTTP.Auth.Bearer.Realm != "" && - c.HTTP.Auth.Bearer.Service != "" { - return true - } - - return false -} - -func (c *Config) IsOpenIDAuthEnabled() bool { - if c.HTTP.Auth != nil && - c.HTTP.Auth.OpenID != nil { - for provider := range c.HTTP.Auth.OpenID.Providers { - if isOpenIDAuthProviderEnabled(c, provider) { - return true - } - } - } - - return false -} - -func (c *Config) IsAPIKeyEnabled() bool { - if c.HTTP.Auth != nil && c.HTTP.Auth.APIKey { - return true - } - - return false -} - -func (c *Config) IsBasicAuthnEnabled() bool { - if c.IsHtpasswdAuthEnabled() || c.IsLdapAuthEnabled() || - c.IsOpenIDAuthEnabled() || c.IsAPIKeyEnabled() { - return true - } - - return false -} - -func isOpenIDAuthProviderEnabled(config *Config, provider string) bool { - if providerConfig, ok := config.HTTP.Auth.OpenID.Providers[provider]; ok { - if IsOpenIDSupported(provider) { - if providerConfig.ClientID != "" || providerConfig.Issuer != "" || - len(providerConfig.Scopes) > 0 { - return true - } - } else if IsOauth2Supported(provider) { - if providerConfig.ClientID != "" || len(providerConfig.Scopes) > 0 { - return true - } - } - } - - return false -} - -func (c *Config) IsMetricsEnabled() bool { - return c.Extensions != nil && c.Extensions.Metrics != nil && *c.Extensions.Metrics.Enable -} - -func (c *Config) IsSearchEnabled() bool { - return c.Extensions != nil && c.Extensions.Search != nil && *c.Extensions.Search.Enable -} - -func (c *Config) IsCveScanningEnabled() bool { - return c.IsSearchEnabled() && c.Extensions.Search.CVE != nil -} - -func (c *Config) IsUIEnabled() bool { - return c.Extensions != nil && c.Extensions.UI != nil && *c.Extensions.UI.Enable -} - -func (c *Config) AreUserPrefsEnabled() bool { - return c.IsSearchEnabled() && c.IsUIEnabled() -} - -func (c *Config) IsMgmtEnabled() bool { - return c.IsSearchEnabled() -} - -func (c *Config) IsImageTrustEnabled() bool { - return c.Extensions != nil && c.Extensions.Trust != nil && *c.Extensions.Trust.Enable -} - -// check if tags retention is enabled. +// IsRetentionEnabled checks if tags retention is enabled. func (c *Config) IsRetentionEnabled() bool { - var needsMetaDB bool - - for _, retentionPolicy := range c.Storage.Retention.Policies { - for _, tagRetentionPolicy := range retentionPolicy.KeepTags { - if c.isTagsRetentionEnabled(tagRetentionPolicy) { - needsMetaDB = true - } - } + if c == nil { + return false } - for _, subpath := range c.Storage.SubPaths { - for _, retentionPolicy := range subpath.Retention.Policies { - for _, tagRetentionPolicy := range retentionPolicy.KeepTags { - if c.isTagsRetentionEnabled(tagRetentionPolicy) { - needsMetaDB = true - } - } - } - } + c.mu.RLock() + defer c.mu.RUnlock() - return needsMetaDB -} - -func (c *Config) isTagsRetentionEnabled(tagRetentionPolicy KeepTagsPolicy) bool { - if tagRetentionPolicy.MostRecentlyPulledCount != 0 || - tagRetentionPolicy.MostRecentlyPushedCount != 0 || - tagRetentionPolicy.PulledWithin != nil || - tagRetentionPolicy.PushedWithin != nil { - return true - } - - return false -} - -func (c *Config) IsCosignEnabled() bool { - return c.IsImageTrustEnabled() && c.Extensions.Trust.Cosign -} - -func (c *Config) IsNotationEnabled() bool { - return c.IsImageTrustEnabled() && c.Extensions.Trust.Notation -} - -func (c *Config) IsSyncEnabled() bool { - return c.Extensions != nil && c.Extensions.Sync != nil && *c.Extensions.Sync.Enable + return c.isRetentionEnabledInternal() } +// IsCompatEnabled checks if compatibility mode is enabled. func (c *Config) IsCompatEnabled() bool { + if c == nil { + return false + } + + c.mu.RLock() + defer c.mu.RUnlock() + return len(c.HTTP.Compat) > 0 } -func (c *Config) IsEventRecorderEnabled() bool { - return c.Extensions != nil && c.Extensions.Events != nil && *c.Extensions.Events.Enable -} - +// IsOpenIDSupported checks if the provider supports OpenID. func IsOpenIDSupported(provider string) bool { for _, supportedProvider := range openIDSupportedProviders { if supportedProvider == provider { @@ -573,6 +1027,7 @@ func IsOpenIDSupported(provider string) bool { return false } +// IsOauth2Supported checks if the provider supports OAuth2. func IsOauth2Supported(provider string) bool { for _, supportedProvider := range oauth2SupportedProviders { if supportedProvider == provider { @@ -582,3 +1037,31 @@ func IsOauth2Supported(provider string) bool { return false } + +// SameFile compare two files. +// This method will first do the stat of two file and compare using os.SameFile method. +func SameFile(str1, str2 string) (bool, error) { + sFile, err := os.Stat(str1) + if err != nil { + return false, err + } + + tFile, err := os.Stat(str2) + if err != nil { + return false, err + } + + return os.SameFile(sFile, tFile), nil +} + +// DeepCopy performs a deep copy of src into dst using JSON marshaling/unmarshaling. +func DeepCopy(src, dst interface{}) error { + bytes, err := json.Marshal(src) + if err != nil { + return err + } + + err = json.Unmarshal(bytes, dst) + + return err +} diff --git a/pkg/api/config/config_test.go b/pkg/api/config/config_test.go index 8eaaabd6..b5127300 100644 --- a/pkg/api/config/config_test.go +++ b/pkg/api/config/config_test.go @@ -7,8 +7,10 @@ import ( . "github.com/smartystreets/goconvey/convey" "zotregistry.dev/zot/v2/pkg/api/config" + "zotregistry.dev/zot/v2/pkg/compat" extconf "zotregistry.dev/zot/v2/pkg/extensions/config" - "zotregistry.dev/zot/v2/pkg/extensions/config/events" + eventsconf "zotregistry.dev/zot/v2/pkg/extensions/config/events" + syncconf "zotregistry.dev/zot/v2/pkg/extensions/config/sync" ) func TestConfig(t *testing.T) { @@ -69,26 +71,333 @@ func TestConfig(t *testing.T) { }) Convey("Test DeepCopy() & Sanitize()", t, func() { - conf := config.New() - So(conf, ShouldNotBeNil) + Convey("Test DeepCopy negative cases", func() { + conf := config.New() + So(conf, ShouldNotBeNil) - authConfig := &config.AuthConfig{LDAP: (&config.LDAPConfig{}).SetBindPassword("oina")} - conf.HTTP.Auth = authConfig + // negative + obj := make(chan int) + err := config.DeepCopy(conf, obj) + So(err, ShouldNotBeNil) + err = config.DeepCopy(obj, conf) + So(err, ShouldNotBeNil) + }) - So(func() { conf.Sanitize() }, ShouldNotPanic) + Convey("Test Sanitize() with LDAP bind password", func() { + conf := config.New() + So(conf, ShouldNotBeNil) - conf = conf.Sanitize() - So(conf.HTTP.Auth.LDAP.BindPassword(), ShouldEqual, "******") + // Set LDAP bind password + authConfig := &config.AuthConfig{LDAP: (&config.LDAPConfig{}).SetBindPassword("secret-ldap-password")} + conf.HTTP.Auth = authConfig - // negative - obj := make(chan int) - err := config.DeepCopy(conf, obj) - So(err, ShouldNotBeNil) - err = config.DeepCopy(obj, conf) - So(err, ShouldNotBeNil) + So(func() { conf.Sanitize() }, ShouldNotPanic) + + sanitizedConf := conf.Sanitize() + So(sanitizedConf.HTTP.Auth.LDAP.BindPassword(), ShouldEqual, "******") + + // Verify original config is not modified + So(conf.HTTP.Auth.LDAP.BindPassword(), ShouldEqual, "secret-ldap-password") + }) + + Convey("Test Sanitize() with OpenID client secrets", func() { + conf := config.New() + So(conf, ShouldNotBeNil) + + // Set OpenID client secrets + authConfig := &config.AuthConfig{ + OpenID: &config.OpenIDConfig{ + Providers: map[string]config.OpenIDProviderConfig{ + "google": { + Name: "Google", + ClientID: "google-client-id", + ClientSecret: "google-client-secret", + Issuer: "https://accounts.google.com", + Scopes: []string{"openid", "email"}, + }, + "github": { + Name: "GitHub", + ClientID: "github-client-id", + ClientSecret: "github-client-secret", + Scopes: []string{"user:email"}, + }, + }, + }, + } + conf.HTTP.Auth = authConfig + + So(func() { conf.Sanitize() }, ShouldNotPanic) + + sanitizedConf := conf.Sanitize() + + // Verify OpenID client secrets are sanitized + So(sanitizedConf.HTTP.Auth.OpenID.Providers["google"].ClientSecret, ShouldEqual, "******") + So(sanitizedConf.HTTP.Auth.OpenID.Providers["github"].ClientSecret, ShouldEqual, "******") + + // Verify other fields are preserved + So(sanitizedConf.HTTP.Auth.OpenID.Providers["google"].ClientID, ShouldEqual, "google-client-id") + So(sanitizedConf.HTTP.Auth.OpenID.Providers["google"].Name, ShouldEqual, "Google") + So(sanitizedConf.HTTP.Auth.OpenID.Providers["google"].Issuer, ShouldEqual, "https://accounts.google.com") + So(sanitizedConf.HTTP.Auth.OpenID.Providers["google"].Scopes, ShouldResemble, []string{"openid", "email"}) + + // Verify original config is not modified + So(conf.HTTP.Auth.OpenID.Providers["google"].ClientSecret, ShouldEqual, "google-client-secret") + So(conf.HTTP.Auth.OpenID.Providers["github"].ClientSecret, ShouldEqual, "github-client-secret") + }) + + Convey("Test Sanitize() with Event sink credentials", func() { + conf := config.New() + So(conf, ShouldNotBeNil) + + // Enable events extension and set sink credentials + enabled := true + conf.Extensions = &extconf.ExtensionConfig{ + Events: &eventsconf.Config{ + Enable: &enabled, + Sinks: []eventsconf.SinkConfig{ + { + Type: eventsconf.HTTP, + Address: "https://example.com/webhook", + Credentials: &eventsconf.Credentials{ + Username: "webhook-user", + Password: "webhook-password", + }, + }, + { + Type: eventsconf.NATS, + Address: "nats://localhost:4222", + Credentials: &eventsconf.Credentials{ + Username: "nats-user", + Password: "nats-token", + }, + }, + }, + }, + } + + So(func() { conf.Sanitize() }, ShouldNotPanic) + + sanitizedConf := conf.Sanitize() + + // Verify event sink credentials passwords are sanitized + So(sanitizedConf.Extensions.Events.Sinks[0].Credentials.Password, ShouldEqual, "******") + So(sanitizedConf.Extensions.Events.Sinks[1].Credentials.Password, ShouldEqual, "******") + + // Verify other fields are preserved + So(sanitizedConf.Extensions.Events.Sinks[0].Credentials.Username, ShouldEqual, "webhook-user") + So(sanitizedConf.Extensions.Events.Sinks[1].Credentials.Username, ShouldEqual, "nats-user") + So(sanitizedConf.Extensions.Events.Sinks[0].Type, ShouldEqual, eventsconf.HTTP) + So(sanitizedConf.Extensions.Events.Sinks[1].Type, ShouldEqual, eventsconf.NATS) + + // Verify original config is not modified + So(conf.Extensions.Events.Sinks[0].Credentials.Password, ShouldEqual, "webhook-password") + So(conf.Extensions.Events.Sinks[1].Credentials.Password, ShouldEqual, "nats-token") + }) + + Convey("Test Sanitize() with Event sink credentials including nil credentials", func() { + conf := config.New() + So(conf, ShouldNotBeNil) + + // Enable events extension with mixed sink credentials (some nil, some not) + enabled := true + conf.Extensions = &extconf.ExtensionConfig{ + Events: &eventsconf.Config{ + Enable: &enabled, + Sinks: []eventsconf.SinkConfig{ + { + Type: eventsconf.HTTP, + Address: "https://example.com/webhook", + Credentials: &eventsconf.Credentials{ + Username: "webhook-user", + Password: "webhook-password", + }, + }, + { + Type: eventsconf.NATS, + Address: "nats://localhost:4222", + Credentials: nil, // This should trigger the continue statement + }, + { + Type: eventsconf.HTTP, + Address: "https://another.com/webhook", + Credentials: &eventsconf.Credentials{ + Username: "another-user", + Password: "another-password", + }, + }, + }, + }, + } + + So(func() { conf.Sanitize() }, ShouldNotPanic) + + sanitizedConf := conf.Sanitize() + + // Verify that sinks with credentials have their passwords sanitized + So(sanitizedConf.Extensions.Events.Sinks[0].Credentials.Password, ShouldEqual, "******") + So(sanitizedConf.Extensions.Events.Sinks[2].Credentials.Password, ShouldEqual, "******") + + // Verify that sink with nil credentials is preserved as-is (no panic, no modification) + So(sanitizedConf.Extensions.Events.Sinks[1].Credentials, ShouldBeNil) + So(sanitizedConf.Extensions.Events.Sinks[1].Type, ShouldEqual, eventsconf.NATS) + So(sanitizedConf.Extensions.Events.Sinks[1].Address, ShouldEqual, "nats://localhost:4222") + + // Verify other fields are preserved + So(sanitizedConf.Extensions.Events.Sinks[0].Credentials.Username, ShouldEqual, "webhook-user") + So(sanitizedConf.Extensions.Events.Sinks[2].Credentials.Username, ShouldEqual, "another-user") + So(sanitizedConf.Extensions.Events.Sinks[0].Type, ShouldEqual, eventsconf.HTTP) + So(sanitizedConf.Extensions.Events.Sinks[2].Type, ShouldEqual, eventsconf.HTTP) + + // Verify original config is not modified + So(conf.Extensions.Events.Sinks[0].Credentials.Password, ShouldEqual, "webhook-password") + So(conf.Extensions.Events.Sinks[2].Credentials.Password, ShouldEqual, "another-password") + So(conf.Extensions.Events.Sinks[1].Credentials, ShouldBeNil) + }) + + Convey("Test Sanitize() with all sensitive data types", func() { + conf := config.New() + So(conf, ShouldNotBeNil) + + // Set all types of sensitive data + authConfig := &config.AuthConfig{ + LDAP: (&config.LDAPConfig{}).SetBindPassword("ldap-secret"), + OpenID: &config.OpenIDConfig{ + Providers: map[string]config.OpenIDProviderConfig{ + "azure": { + Name: "Azure AD", + ClientID: "azure-client-id", + ClientSecret: "azure-client-secret", + Issuer: "https://login.microsoftonline.com/...", + }, + }, + }, + } + conf.HTTP.Auth = authConfig + + // Enable events extension + enabled := true + conf.Extensions = &extconf.ExtensionConfig{ + Events: &eventsconf.Config{ + Enable: &enabled, + Sinks: []eventsconf.SinkConfig{ + { + Type: eventsconf.HTTP, + Address: "https://smtp.example.com/webhook", + Credentials: &eventsconf.Credentials{ + Username: "smtp-user", + Password: "smtp-password", + }, + }, + }, + }, + } + + So(func() { conf.Sanitize() }, ShouldNotPanic) + + sanitizedConf := conf.Sanitize() + + // Verify all sensitive data is sanitized + So(sanitizedConf.HTTP.Auth.LDAP.BindPassword(), ShouldEqual, "******") + So(sanitizedConf.HTTP.Auth.OpenID.Providers["azure"].ClientSecret, ShouldEqual, "******") + So(sanitizedConf.Extensions.Events.Sinks[0].Credentials.Password, ShouldEqual, "******") + + // Verify non-sensitive data is preserved + So(sanitizedConf.HTTP.Auth.OpenID.Providers["azure"].ClientID, ShouldEqual, "azure-client-id") + So(sanitizedConf.HTTP.Auth.OpenID.Providers["azure"].Name, ShouldEqual, "Azure AD") + So(sanitizedConf.Extensions.Events.Sinks[0].Credentials.Username, ShouldEqual, "smtp-user") + So(sanitizedConf.Extensions.Events.Sinks[0].Type, ShouldEqual, eventsconf.HTTP) + }) + + Convey("Test Sanitize() with nil sensitive data", func() { + conf := config.New() + So(conf, ShouldNotBeNil) + + // Set config with nil sensitive data + authConfig := &config.AuthConfig{ + LDAP: nil, // No LDAP config + OpenID: nil, // No OpenID config + } + conf.HTTP.Auth = authConfig + + // No events extension + conf.Extensions = nil + + So(func() { conf.Sanitize() }, ShouldNotPanic) + + sanitizedConf := conf.Sanitize() + + // Verify nil configs are handled gracefully + So(sanitizedConf.HTTP.Auth.LDAP, ShouldBeNil) + So(sanitizedConf.HTTP.Auth.OpenID, ShouldBeNil) + So(sanitizedConf.Extensions, ShouldBeNil) + }) + + Convey("Test Sanitize() with empty sensitive data", func() { + conf := config.New() + So(conf, ShouldNotBeNil) + + // Set config with empty sensitive data + authConfig := &config.AuthConfig{ + LDAP: (&config.LDAPConfig{}).SetBindPassword(""), // Empty password + OpenID: &config.OpenIDConfig{ + Providers: map[string]config.OpenIDProviderConfig{ + "empty": { + Name: "Empty Provider", + ClientID: "empty-client-id", + ClientSecret: "", // Empty secret + }, + }, + }, + } + conf.HTTP.Auth = authConfig + + // Enable events extension with empty password + enabled := true + conf.Extensions = &extconf.ExtensionConfig{ + Events: &eventsconf.Config{ + Enable: &enabled, + Sinks: []eventsconf.SinkConfig{ + { + Type: eventsconf.HTTP, + Address: "https://example.com/webhook", + Credentials: &eventsconf.Credentials{ + Username: "user", + Password: "", // Empty password + }, + }, + }, + }, + } + + So(func() { conf.Sanitize() }, ShouldNotPanic) + + sanitizedConf := conf.Sanitize() + + // Verify empty passwords behavior + // LDAP empty password should remain empty + So(sanitizedConf.HTTP.Auth.LDAP.BindPassword(), ShouldEqual, "") + // OpenID empty secret is always sanitized + So(sanitizedConf.HTTP.Auth.OpenID.Providers["empty"].ClientSecret, ShouldEqual, "******") + // Event sink empty password is always sanitized + So(sanitizedConf.Extensions.Events.Sinks[0].Credentials.Password, ShouldEqual, "******") + }) + + Convey("Test Sanitize() with nil config", func() { + var conf *config.Config = nil + + So(func() { conf.Sanitize() }, ShouldNotPanic) + + sanitizedConf := conf.Sanitize() + So(sanitizedConf, ShouldBeNil) + }) }) Convey("Test IsRetentionEnabled()", t, func() { + // Test nil config + var nilConf *config.Config = nil + + So(nilConf.IsRetentionEnabled(), ShouldBeFalse) + conf := config.New() So(conf.IsRetentionEnabled(), ShouldBeFalse) @@ -130,31 +439,2960 @@ func TestConfig(t *testing.T) { conf.Storage.SubPaths = subPaths So(conf.IsRetentionEnabled(), ShouldBeTrue) + + // Test MostRecentlyPushedCount + conf = config.New() + conf.Storage.Retention.Policies = []config.RetentionPolicy{ + { + Repositories: []string{"repo"}, + KeepTags: []config.KeepTagsPolicy{ + { + Patterns: []string{"tag"}, + MostRecentlyPushedCount: 3, + }, + }, + }, + } + So(conf.IsRetentionEnabled(), ShouldBeTrue) + + // Test PulledWithin + conf = config.New() + duration := time.Hour * 24 + conf.Storage.Retention.Policies = []config.RetentionPolicy{ + { + Repositories: []string{"repo"}, + KeepTags: []config.KeepTagsPolicy{ + { + Patterns: []string{"tag"}, + PulledWithin: &duration, + }, + }, + }, + } + So(conf.IsRetentionEnabled(), ShouldBeTrue) + + // Test PushedWithin + conf = config.New() + conf.Storage.Retention.Policies = []config.RetentionPolicy{ + { + Repositories: []string{"repo"}, + KeepTags: []config.KeepTagsPolicy{ + { + Patterns: []string{"tag"}, + PushedWithin: &duration, + }, + }, + }, + } + So(conf.IsRetentionEnabled(), ShouldBeTrue) + + // Test SubPaths with retention policies + conf = config.New() + conf.Storage.SubPaths = map[string]config.StorageConfig{ + "subpath1": { + Retention: config.ImageRetention{ + Policies: []config.RetentionPolicy{ + { + Repositories: []string{"repo1"}, + KeepTags: []config.KeepTagsPolicy{ + { + Patterns: []string{"latest"}, + MostRecentlyPulledCount: 5, + }, + }, + }, + }, + }, + }, + } + So(conf.IsRetentionEnabled(), ShouldBeTrue) + + // Test empty policies with no retention criteria + conf = config.New() + conf.Storage.Retention.Policies = []config.RetentionPolicy{ + { + Repositories: []string{"repo"}, + KeepTags: []config.KeepTagsPolicy{ + { + Patterns: []string{"tag"}, + // No retention criteria set + }, + }, + }, + } + So(conf.IsRetentionEnabled(), ShouldBeFalse) }) Convey("Test IsEventRecorderEnabled()", t, func() { conf := config.New() - So(conf.IsEventRecorderEnabled(), ShouldBeFalse) + extensionsConfig := conf.CopyExtensionsConfig() + So(extensionsConfig.IsEventRecorderEnabled(), ShouldBeFalse) // Enable the event recorder enable := true conf.Extensions = &extconf.ExtensionConfig{} - conf.Extensions.Events = &events.Config{ + conf.Extensions.Events = &eventsconf.Config{ Enable: &enable, } - So(conf.IsEventRecorderEnabled(), ShouldBeTrue) + extensionsConfig = conf.CopyExtensionsConfig() + So(extensionsConfig.IsEventRecorderEnabled(), ShouldBeTrue) // Disabled scenario disable := false conf.Extensions.Events.Enable = &disable - So(conf.IsEventRecorderEnabled(), ShouldBeFalse) + extensionsConfig = conf.CopyExtensionsConfig() + So(extensionsConfig.IsEventRecorderEnabled(), ShouldBeFalse) // nil pointers conf.Extensions.Events = nil - So(conf.IsEventRecorderEnabled(), ShouldBeFalse) + extensionsConfig = conf.CopyExtensionsConfig() + So(extensionsConfig.IsEventRecorderEnabled(), ShouldBeFalse) conf.Extensions = nil - So(conf.IsEventRecorderEnabled(), ShouldBeFalse) + extensionsConfig = conf.CopyExtensionsConfig() + So(extensionsConfig.IsEventRecorderEnabled(), ShouldBeFalse) + }) + + Convey("Test AccessControlConfig.ContainsOnlyAnonymousPolicy()", t, func() { + Convey("When accessControlConfig is nil", func() { + var accessControlConfig *config.AccessControlConfig = nil + + result := accessControlConfig.ContainsOnlyAnonymousPolicy() + So(result, ShouldBeTrue) + }) + + Convey("When accessControlConfig has admin policies", func() { + accessControlConfig := &config.AccessControlConfig{} + accessControlConfig.AdminPolicy = config.Policy{ + Actions: []string{"read"}, + Users: []string{"admin"}, + } + + result := accessControlConfig.ContainsOnlyAnonymousPolicy() + So(result, ShouldBeFalse) + }) + + Convey("When accessControlConfig has only anonymous policies", func() { + accessControlConfig := &config.AccessControlConfig{} + accessControlConfig.Repositories = config.Repositories{ + "repo1": config.PolicyGroup{ + AnonymousPolicy: []string{"read"}, + }, + } + + result := accessControlConfig.ContainsOnlyAnonymousPolicy() + So(result, ShouldBeTrue) + }) + + Convey("When accessControlConfig has default policies", func() { + accessControlConfig := &config.AccessControlConfig{} + accessControlConfig.Repositories = config.Repositories{ + "repo1": config.PolicyGroup{ + DefaultPolicy: []string{"read"}, + }, + } + + result := accessControlConfig.ContainsOnlyAnonymousPolicy() + So(result, ShouldBeFalse) + }) + + Convey("When accessControlConfig has non-empty repository policies", func() { + accessControlConfig := &config.AccessControlConfig{} + accessControlConfig.Repositories = config.Repositories{ + "repo1": config.PolicyGroup{ + Policies: []config.Policy{ + { + Actions: []string{"read"}, + Users: []string{"user1"}, + }, + }, + }, + } + + result := accessControlConfig.ContainsOnlyAnonymousPolicy() + So(result, ShouldBeFalse) + }) + + Convey("When accessControlConfig has empty admin policy and no repositories", func() { + accessControlConfig := &config.AccessControlConfig{} + accessControlConfig.AdminPolicy = config.Policy{ + Actions: []string{}, + Users: []string{}, + } + accessControlConfig.Repositories = config.Repositories{} + + result := accessControlConfig.ContainsOnlyAnonymousPolicy() + So(result, ShouldBeFalse) + }) + + Convey("When accessControlConfig has empty policies in repository", func() { + accessControlConfig := &config.AccessControlConfig{} + accessControlConfig.Repositories = config.Repositories{ + "repo1": config.PolicyGroup{ + AnonymousPolicy: []string{"read"}, + Policies: []config.Policy{ + { + Actions: []string{}, + Users: []string{}, + }, + }, + }, + } + + result := accessControlConfig.ContainsOnlyAnonymousPolicy() + So(result, ShouldBeTrue) + }) + }) + + Convey("Test AuthConfig methods", t, func() { + Convey("Test IsLdapAuthEnabled()", func() { + // Test with nil AuthConfig + var authConfig *config.AuthConfig = nil + + So(authConfig.IsLdapAuthEnabled(), ShouldBeFalse) + + // Test with AuthConfig but nil LDAP + authConfig = &config.AuthConfig{} + So(authConfig.IsLdapAuthEnabled(), ShouldBeFalse) + + // Test with AuthConfig and LDAP configured + authConfig = &config.AuthConfig{ + LDAP: &config.LDAPConfig{}, + } + So(authConfig.IsLdapAuthEnabled(), ShouldBeTrue) + }) + + Convey("Test IsHtpasswdAuthEnabled()", func() { + // Test with nil AuthConfig + var authConfig *config.AuthConfig = nil + + So(authConfig.IsHtpasswdAuthEnabled(), ShouldBeFalse) + + // Test with AuthConfig but empty HTPasswd path + authConfig = &config.AuthConfig{ + HTPasswd: config.AuthHTPasswd{Path: ""}, + } + So(authConfig.IsHtpasswdAuthEnabled(), ShouldBeFalse) + + // Test with AuthConfig and HTPasswd configured + authConfig = &config.AuthConfig{ + HTPasswd: config.AuthHTPasswd{Path: "/path/to/htpasswd"}, + } + So(authConfig.IsHtpasswdAuthEnabled(), ShouldBeTrue) + }) + + Convey("Test IsBearerAuthEnabled()", func() { + // Test with nil AuthConfig + var authConfig *config.AuthConfig = nil + + So(authConfig.IsBearerAuthEnabled(), ShouldBeFalse) + + // Test with AuthConfig but nil Bearer + authConfig = &config.AuthConfig{} + So(authConfig.IsBearerAuthEnabled(), ShouldBeFalse) + + // Test with AuthConfig and Bearer configured with all required fields + authConfig = &config.AuthConfig{ + Bearer: &config.BearerConfig{ + Cert: "/path/to/cert.pem", + Realm: "test-realm", + Service: "test-service", + }, + } + So(authConfig.IsBearerAuthEnabled(), ShouldBeTrue) + }) + + Convey("Test IsOpenIDAuthEnabled()", func() { + // Test with nil AuthConfig + var authConfig *config.AuthConfig = nil + + So(authConfig.IsOpenIDAuthEnabled(), ShouldBeFalse) + + // Test with AuthConfig but nil OpenID + authConfig = &config.AuthConfig{} + So(authConfig.IsOpenIDAuthEnabled(), ShouldBeFalse) + + // Test with AuthConfig and OpenID configured with providers + authConfig = &config.AuthConfig{ + OpenID: &config.OpenIDConfig{ + Providers: map[string]config.OpenIDProviderConfig{ + "google": { + ClientID: "client-id", + }, + }, + }, + } + So(authConfig.IsOpenIDAuthEnabled(), ShouldBeTrue) + }) + + Convey("Test IsAPIKeyEnabled()", func() { + // Test with nil AuthConfig + var authConfig *config.AuthConfig = nil + + So(authConfig.IsAPIKeyEnabled(), ShouldBeFalse) + + // Test with AuthConfig but APIKey disabled + authConfig = &config.AuthConfig{ + APIKey: false, + } + So(authConfig.IsAPIKeyEnabled(), ShouldBeFalse) + + // Test with AuthConfig and APIKey enabled + authConfig = &config.AuthConfig{ + APIKey: true, + } + So(authConfig.IsAPIKeyEnabled(), ShouldBeTrue) + }) + + Convey("Test IsBasicAuthnEnabled()", func() { + // Test with nil AuthConfig + var authConfig *config.AuthConfig = nil + + So(authConfig.IsBasicAuthnEnabled(), ShouldBeFalse) + + // Test with AuthConfig but no basic auth methods + authConfig = &config.AuthConfig{} + So(authConfig.IsBasicAuthnEnabled(), ShouldBeFalse) + + // Test with HTPasswd enabled + authConfig = &config.AuthConfig{ + HTPasswd: config.AuthHTPasswd{Path: "/path/to/htpasswd"}, + } + So(authConfig.IsBasicAuthnEnabled(), ShouldBeTrue) + + // Test with LDAP enabled + authConfig = &config.AuthConfig{ + LDAP: &config.LDAPConfig{}, + } + So(authConfig.IsBasicAuthnEnabled(), ShouldBeTrue) + + // Test with OpenID enabled (with ClientID) + authConfig = &config.AuthConfig{ + OpenID: &config.OpenIDConfig{ + Providers: map[string]config.OpenIDProviderConfig{ + "google": { + ClientID: "client-id", + Scopes: []string{"openid", "email"}, + }, + }, + }, + } + So(authConfig.IsBasicAuthnEnabled(), ShouldBeTrue) + + // Test with OpenID enabled (with Issuer) + authConfig = &config.AuthConfig{ + OpenID: &config.OpenIDConfig{ + Providers: map[string]config.OpenIDProviderConfig{ + "google": { + ClientID: "", + Issuer: "https://accounts.google.com", + Scopes: []string{}, + }, + }, + }, + } + So(authConfig.IsBasicAuthnEnabled(), ShouldBeTrue) + + // Test with OpenID enabled (with Scopes only) + authConfig = &config.AuthConfig{ + OpenID: &config.OpenIDConfig{ + Providers: map[string]config.OpenIDProviderConfig{ + "google": { + ClientID: "", + Issuer: "", + Scopes: []string{"openid", "email"}, + }, + }, + }, + } + So(authConfig.IsBasicAuthnEnabled(), ShouldBeTrue) + + // Test with OAuth2 provider (github) + authConfig = &config.AuthConfig{ + OpenID: &config.OpenIDConfig{ + Providers: map[string]config.OpenIDProviderConfig{ + "github": { + ClientID: "github-client-id", + Scopes: []string{"user:email"}, + }, + }, + }, + } + So(authConfig.IsBasicAuthnEnabled(), ShouldBeTrue) + + // Test with OpenID but no valid providers (empty config) + // Note: AuthConfig.IsOpenIDAuthEnabled() only checks if provider is supported, + // not if the configuration is valid, so this returns true + authConfig = &config.AuthConfig{ + OpenID: &config.OpenIDConfig{ + Providers: map[string]config.OpenIDProviderConfig{ + "google": { + ClientID: "", + Issuer: "", + Scopes: []string{}, + }, + }, + }, + } + So(authConfig.IsBasicAuthnEnabled(), ShouldBeTrue) + + // Test with OpenID but unsupported provider + authConfig = &config.AuthConfig{ + OpenID: &config.OpenIDConfig{ + Providers: map[string]config.OpenIDProviderConfig{ + "unsupported": { + ClientID: "client-id", + Scopes: []string{"scope"}, + }, + }, + }, + } + So(authConfig.IsBasicAuthnEnabled(), ShouldBeFalse) + + // Test with APIKey enabled + authConfig = &config.AuthConfig{ + APIKey: true, + } + So(authConfig.IsBasicAuthnEnabled(), ShouldBeTrue) + }) + + Convey("Test GetFailDelay()", func() { + // Test with nil AuthConfig + var authConfig *config.AuthConfig = nil + + So(authConfig.GetFailDelay(), ShouldEqual, 0) + + // Test with AuthConfig and custom FailDelay + authConfig = &config.AuthConfig{ + FailDelay: 5, + } + So(authConfig.GetFailDelay(), ShouldEqual, 5) + }) + }) + + Convey("Test LDAPConfig methods", t, func() { + Convey("Test BindDN()", func() { + ldapConfig := &config.LDAPConfig{} + So(ldapConfig.BindDN(), ShouldEqual, "") + + ldapConfig.SetBindDN("cn=admin,dc=example,dc=com") + So(ldapConfig.BindDN(), ShouldEqual, "cn=admin,dc=example,dc=com") + }) + + Convey("Test BindPassword()", func() { + ldapConfig := &config.LDAPConfig{} + So(ldapConfig.BindPassword(), ShouldEqual, "") + + ldapConfig.SetBindPassword("secretpassword") + So(ldapConfig.BindPassword(), ShouldEqual, "secretpassword") + }) + }) + + Convey("Test AccessControlConfig methods", t, func() { + Convey("Test IsAuthzEnabled()", func() { + // Test with nil AccessControlConfig + var accessControlConfig *config.AccessControlConfig = nil + + So(accessControlConfig.IsAuthzEnabled(), ShouldBeFalse) + + // Test with AccessControlConfig + accessControlConfig = &config.AccessControlConfig{} + So(accessControlConfig.IsAuthzEnabled(), ShouldBeTrue) + }) + + Convey("Test AnonymousPolicyExists()", func() { + // Test with nil AccessControlConfig + var accessControlConfig *config.AccessControlConfig = nil + + So(accessControlConfig.AnonymousPolicyExists(), ShouldBeFalse) + + // Test with AccessControlConfig but no repositories + accessControlConfig = &config.AccessControlConfig{} + So(accessControlConfig.AnonymousPolicyExists(), ShouldBeFalse) + + // Test with AccessControlConfig and repository with anonymous policy + accessControlConfig = &config.AccessControlConfig{} + accessControlConfig.Repositories = config.Repositories{ + "repo1": config.PolicyGroup{ + AnonymousPolicy: []string{"read"}, + }, + } + So(accessControlConfig.AnonymousPolicyExists(), ShouldBeTrue) + + // Test with AccessControlConfig and repository without anonymous policy + accessControlConfig = &config.AccessControlConfig{} + accessControlConfig.Repositories = config.Repositories{ + "repo1": config.PolicyGroup{ + DefaultPolicy: []string{"read"}, + }, + } + So(accessControlConfig.AnonymousPolicyExists(), ShouldBeFalse) + }) + + Convey("Test GetRepositories()", func() { + repositories := config.Repositories{ + "repo1": config.PolicyGroup{ + AnonymousPolicy: []string{"read"}, + }, + } + accessControlConfig := &config.AccessControlConfig{} + accessControlConfig.Repositories = repositories + So(accessControlConfig.GetRepositories(), ShouldResemble, repositories) + }) + + Convey("Test GetAdminPolicy()", func() { + adminPolicy := config.Policy{ + Actions: []string{"read", "write"}, + Users: []string{"admin"}, + } + accessControlConfig := &config.AccessControlConfig{} + accessControlConfig.AdminPolicy = adminPolicy + So(accessControlConfig.GetAdminPolicy(), ShouldResemble, adminPolicy) + }) + + Convey("Test GetMetrics()", func() { + metrics := config.Metrics{ + Users: []string{"metrics-user"}, + } + accessControlConfig := &config.AccessControlConfig{} + accessControlConfig.Metrics = metrics + So(accessControlConfig.GetMetrics(), ShouldResemble, metrics) + }) + + Convey("Test GetGroups()", func() { + groups := config.Groups{ + "developers": config.Group{ + Users: []string{"dev1", "dev2"}, + }, + } + accessControlConfig := &config.AccessControlConfig{} + accessControlConfig.Groups = groups + So(accessControlConfig.GetGroups(), ShouldResemble, groups) + }) + }) + + Convey("Test Config getter methods", t, func() { + Convey("Test CopyAuthConfig()", func() { + Convey("Test with non-nil Auth", func() { + cfg := &config.Config{ + HTTP: config.HTTPConfig{ + Auth: &config.AuthConfig{ + FailDelay: 5, + }, + }, + } + authConfig := cfg.CopyAuthConfig() + So(authConfig, ShouldNotBeNil) + So(authConfig.GetFailDelay(), ShouldEqual, 5) + }) + + Convey("Test with nil Auth", func() { + cfg := &config.Config{ + HTTP: config.HTTPConfig{ + Auth: nil, + }, + } + authConfig := cfg.CopyAuthConfig() + So(authConfig, ShouldBeNil) + }) + + Convey("Test that returned AuthConfig is isolated from config mutations", func() { + // Create initial config with AuthConfig containing nested structures + cfg := &config.Config{ + HTTP: config.HTTPConfig{ + Auth: &config.AuthConfig{ + FailDelay: 5, + HTPasswd: config.AuthHTPasswd{ + Path: "/etc/htpasswd", + }, + LDAP: &config.LDAPConfig{ + Address: "ldap.example.com", + Port: 389, + }, + Bearer: &config.BearerConfig{ + Realm: "test-realm", + Service: "test-service", + Cert: "/path/to/cert", + }, + OpenID: &config.OpenIDConfig{ + Providers: map[string]config.OpenIDProviderConfig{ + "google": { + Name: "Google", + ClientID: "google-client-id", + Scopes: []string{"openid", "email"}, + }, + }, + }, + APIKey: false, + SessionKeysFile: "/etc/session-keys", + SessionHashKey: []byte("hash-key"), + SessionEncryptKey: []byte("encrypt-key"), + SessionDriver: map[string]any{ + "type": "redis", + "host": "localhost", + }, + }, + }, + } + + // Get the AuthConfig reference + authConfig := cfg.CopyAuthConfig() + So(authConfig, ShouldNotBeNil) + So(authConfig.GetFailDelay(), ShouldEqual, 5) + So(authConfig.IsHtpasswdAuthEnabled(), ShouldBeTrue) + So(authConfig.IsLdapAuthEnabled(), ShouldBeTrue) + So(authConfig.IsBearerAuthEnabled(), ShouldBeTrue) + So(authConfig.IsOpenIDAuthEnabled(), ShouldBeTrue) + So(authConfig.IsAPIKeyEnabled(), ShouldBeFalse) + + // Test deep copy isolation by modifying nested structures + authConfig.LDAP.Address = "modified-ldap.example.com" + authConfig.Bearer.Realm = "modified-realm" + authConfig.OpenID.Providers["google"].Scopes[0] = "modified-scope" + authConfig.SessionHashKey[0] = 'M' + authConfig.SessionDriver["type"] = "modified-driver" + + // Verify original is unchanged + So(cfg.HTTP.Auth.LDAP.Address, ShouldEqual, "ldap.example.com") + So(cfg.HTTP.Auth.Bearer.Realm, ShouldEqual, "test-realm") + So(cfg.HTTP.Auth.OpenID.Providers["google"].Scopes[0], ShouldEqual, "openid") + So(cfg.HTTP.Auth.SessionHashKey[0], ShouldEqual, byte('h')) + So(cfg.HTTP.Auth.SessionDriver["type"], ShouldEqual, "redis") + }) + + Convey("Test that returned AuthConfig is isolated when config is updated via UpdateReloadableConfig", func() { + // Create initial config with AuthConfig + cfg := &config.Config{ + HTTP: config.HTTPConfig{ + Auth: &config.AuthConfig{ + FailDelay: 5, + HTPasswd: config.AuthHTPasswd{ + Path: "/etc/htpasswd", + }, + APIKey: false, + }, + }, + } + + // Get the AuthConfig reference + authConfig := cfg.CopyAuthConfig() + So(authConfig, ShouldNotBeNil) + So(authConfig.GetFailDelay(), ShouldEqual, 5) + So(authConfig.IsHtpasswdAuthEnabled(), ShouldBeTrue) + So(authConfig.IsAPIKeyEnabled(), ShouldBeFalse) + + // Create new config with updated AuthConfig + // Note: UpdateReloadableConfig updates HTPasswd, LDAP, APIKey, and OpenID fields + newConfig := &config.Config{ + HTTP: config.HTTPConfig{ + Auth: &config.AuthConfig{ + FailDelay: 15, // This field is NOT updated by UpdateReloadableConfig + HTPasswd: config.AuthHTPasswd{ + Path: "/etc/updated-htpasswd", // This field IS updated by UpdateReloadableConfig + }, + APIKey: true, // This field IS updated by UpdateReloadableConfig + }, + }, + } + + // Update the config using UpdateReloadableConfig + cfg.UpdateReloadableConfig(newConfig) + + // Verify that the returned AuthConfig is not affected by the update + // CopyAuthConfig() returns a copy, so the returned object should be isolated + So(authConfig.GetFailDelay(), ShouldEqual, 5) // Should remain unchanged + So(authConfig.IsHtpasswdAuthEnabled(), ShouldBeTrue) // Should remain unchanged (old path) + So(authConfig.IsAPIKeyEnabled(), ShouldBeFalse) // Should remain unchanged + + // Verify that a new CopyAuthConfig() call returns the updated values + newAuthConfig := cfg.CopyAuthConfig() + So(newAuthConfig, ShouldNotBeNil) + // Should remain unchanged (not updated by UpdateReloadableConfig) + So(newAuthConfig.GetFailDelay(), ShouldEqual, 5) + So(newAuthConfig.IsHtpasswdAuthEnabled(), ShouldBeTrue) // Should be updated (new path) + // Should be updated by UpdateReloadableConfig + So(newAuthConfig.IsAPIKeyEnabled(), ShouldBeTrue) + }) + + Convey("Test that returned AuthConfig is isolated when config is set to nil", func() { + // Create initial config with AuthConfig + cfg := &config.Config{ + HTTP: config.HTTPConfig{ + Auth: &config.AuthConfig{ + FailDelay: 5, + HTPasswd: config.AuthHTPasswd{ + Path: "/etc/htpasswd", + }, + APIKey: false, + }, + }, + } + + // Get the AuthConfig reference + authConfig := cfg.CopyAuthConfig() + So(authConfig, ShouldNotBeNil) + So(authConfig.GetFailDelay(), ShouldEqual, 5) + So(authConfig.IsHtpasswdAuthEnabled(), ShouldBeTrue) + So(authConfig.IsAPIKeyEnabled(), ShouldBeFalse) + + // Set the AuthConfig to nil + cfg.HTTP.Auth = nil + + // Verify that the returned AuthConfig is not affected by setting to nil + So(authConfig, ShouldNotBeNil) // Should remain unchanged + So(authConfig.GetFailDelay(), ShouldEqual, 5) // Should remain unchanged + So(authConfig.IsHtpasswdAuthEnabled(), ShouldBeTrue) // Should remain unchanged + So(authConfig.IsAPIKeyEnabled(), ShouldBeFalse) // Should remain unchanged + + // Verify that a new CopyAuthConfig() call returns nil + newAuthConfig := cfg.CopyAuthConfig() + So(newAuthConfig, ShouldBeNil) // Should be nil + }) + }) + + Convey("Test CopyAccessControlConfig()", func() { + Convey("Test with non-nil AccessControl", func() { + testAccessControlConfig := &config.AccessControlConfig{ + Repositories: config.Repositories{ + "repo1": config.PolicyGroup{ + Policies: []config.Policy{ + { + Users: []string{"user1", "user2"}, + Actions: []string{"read", "write"}, + Groups: []string{"group1"}, + }, + }, + DefaultPolicy: []string{"read"}, + AnonymousPolicy: []string{"read"}, + }, + }, + AdminPolicy: config.Policy{ + Users: []string{"admin1"}, + Actions: []string{"read", "write", "delete"}, + Groups: []string{"admin-group"}, + }, + Groups: config.Groups{ + "group1": config.Group{ + Users: []string{"user1", "user2"}, + }, + }, + Metrics: config.Metrics{ + Users: []string{"metrics-user"}, + }, + } + cfg := &config.Config{ + HTTP: config.HTTPConfig{ + AccessControl: testAccessControlConfig, + }, + } + accessControlConfig := cfg.CopyAccessControlConfig() + So(accessControlConfig, ShouldNotBeNil) + So(accessControlConfig.IsAuthzEnabled(), ShouldBeTrue) + + // Test deep copy isolation + accessControlConfig.Repositories["repo1"].Policies[0].Users[0] = "modified-user" + accessControlConfig.Repositories["repo1"].DefaultPolicy[0] = "modified-policy" + accessControlConfig.AdminPolicy.Users[0] = "modified-admin" + accessControlConfig.Groups["group1"].Users[0] = "modified-group-user" + accessControlConfig.Metrics.Users[0] = "modified-metrics-user" + + // Verify original is unchanged + So(cfg.HTTP.AccessControl.Repositories["repo1"].Policies[0].Users[0], ShouldEqual, "user1") + So(cfg.HTTP.AccessControl.Repositories["repo1"].DefaultPolicy[0], ShouldEqual, "read") + So(cfg.HTTP.AccessControl.AdminPolicy.Users[0], ShouldEqual, "admin1") + So(cfg.HTTP.AccessControl.Groups["group1"].Users[0], ShouldEqual, "user1") + So(cfg.HTTP.AccessControl.Metrics.Users[0], ShouldEqual, "metrics-user") + }) + + Convey("Test with nil AccessControl", func() { + cfg := &config.Config{ + HTTP: config.HTTPConfig{ + AccessControl: nil, + }, + } + accessControlConfig := cfg.CopyAccessControlConfig() + So(accessControlConfig, ShouldBeNil) + }) + }) + + Convey("Test CopyStorageConfig()", func() { + Convey("Test with non-nil Storage", func() { + cfg := &config.Config{ + Storage: config.GlobalStorageConfig{ + StorageConfig: config.StorageConfig{ + RootDirectory: "/tmp/storage", + GC: true, + }, + }, + } + storageConfig := cfg.CopyStorageConfig() + So(storageConfig, ShouldNotBeNil) + So(storageConfig.RootDirectory, ShouldEqual, "/tmp/storage") + So(storageConfig.GC, ShouldBeTrue) + }) + + Convey("Test with nil Storage", func() { + cfg := &config.Config{ + Storage: config.GlobalStorageConfig{}, + } + storageConfig := cfg.CopyStorageConfig() + So(storageConfig, ShouldNotBeNil) // GlobalStorageConfig is a struct, not a pointer, so it's never nil + So(storageConfig.RootDirectory, ShouldEqual, "") + So(storageConfig.GC, ShouldBeFalse) + }) + + Convey("Test StorageConfig deep copy isolation", func() { + cfg := &config.Config{ + Storage: config.GlobalStorageConfig{ + StorageConfig: config.StorageConfig{ + RootDirectory: "/tmp/storage", + GC: true, + Retention: config.ImageRetention{ + DryRun: true, + Policies: []config.RetentionPolicy{ + { + Repositories: []string{"repo1", "repo2"}, + KeepTags: []config.KeepTagsPolicy{ + { + Patterns: []string{"pattern1", "pattern2"}, + }, + }, + }, + }, + }, + StorageDriver: map[string]interface{}{ + "type": "filesystem", + }, + CacheDriver: map[string]interface{}{ + "type": "redis", + }, + }, + SubPaths: map[string]config.StorageConfig{ + "/subpath1": { + RootDirectory: "/tmp/subpath1", + Retention: config.ImageRetention{ + Policies: []config.RetentionPolicy{ + { + Repositories: []string{"subrepo1"}, + }, + }, + }, + StorageDriver: map[string]interface{}{ + "type": "s3", + }, + }, + }, + }, + } + + // Get a copy of the storage config + storageConfig := cfg.CopyStorageConfig() + So(storageConfig, ShouldNotBeNil) + + // Mutate the copy's fields + storageConfig.RootDirectory = "/modified/storage" + storageConfig.GC = false + storageConfig.Retention.Policies[0].Repositories[0] = "modified-repo" + storageConfig.Retention.Policies[0].KeepTags[0].Patterns[0] = "modified-pattern" + storageConfig.StorageDriver["type"] = "modified-driver" + storageConfig.CacheDriver["type"] = "modified-cache" + + // Mutate SubPaths by getting a copy, modifying it, and putting it back + subPathConfig := storageConfig.SubPaths["/subpath1"] + subPathConfig.RootDirectory = "/modified/subpath1" + subPathConfig.Retention.Policies[0].Repositories[0] = "modified-subrepo" + subPathConfig.StorageDriver["type"] = "modified-s3" + storageConfig.SubPaths["/subpath1"] = subPathConfig + + // Verify original config is unchanged + So(cfg.Storage.RootDirectory, ShouldEqual, "/tmp/storage") + So(cfg.Storage.GC, ShouldBeTrue) + So(cfg.Storage.Retention.Policies[0].Repositories[0], ShouldEqual, "repo1") + So(cfg.Storage.Retention.Policies[0].KeepTags[0].Patterns[0], ShouldEqual, "pattern1") + So(cfg.Storage.StorageDriver["type"], ShouldEqual, "filesystem") + So(cfg.Storage.CacheDriver["type"], ShouldEqual, "redis") + So(cfg.Storage.SubPaths["/subpath1"].RootDirectory, ShouldEqual, "/tmp/subpath1") + So(cfg.Storage.SubPaths["/subpath1"].Retention.Policies[0].Repositories[0], ShouldEqual, "subrepo1") + So(cfg.Storage.SubPaths["/subpath1"].StorageDriver["type"], ShouldEqual, "s3") + + // Verify copy has the mutations + So(storageConfig.RootDirectory, ShouldEqual, "/modified/storage") + So(storageConfig.GC, ShouldBeFalse) + So(storageConfig.Retention.Policies[0].Repositories[0], ShouldEqual, "modified-repo") + So(storageConfig.Retention.Policies[0].KeepTags[0].Patterns[0], ShouldEqual, "modified-pattern") + So(storageConfig.StorageDriver["type"], ShouldEqual, "modified-driver") + So(storageConfig.CacheDriver["type"], ShouldEqual, "modified-cache") + So(storageConfig.SubPaths["/subpath1"].RootDirectory, ShouldEqual, "/modified/subpath1") + So(storageConfig.SubPaths["/subpath1"].Retention.Policies[0].Repositories[0], ShouldEqual, "modified-subrepo") + So(storageConfig.SubPaths["/subpath1"].StorageDriver["type"], ShouldEqual, "modified-s3") + }) + }) + + Convey("Test CopyLogConfig()", func() { + Convey("Test with non-nil Log", func() { + cfg := &config.Config{ + Log: &config.LogConfig{ + Level: "info", + Output: "/tmp/logs", + }, + } + logConfig := cfg.CopyLogConfig() + So(logConfig, ShouldNotBeNil) + So(logConfig.Level, ShouldEqual, "info") + So(logConfig.Output, ShouldEqual, "/tmp/logs") + }) + + Convey("Test with nil Log", func() { + cfg := &config.Config{ + Log: nil, + } + logConfig := cfg.CopyLogConfig() + So(logConfig, ShouldBeNil) + }) + }) + + Convey("Test CopyClusterConfig()", func() { + Convey("Test with non-nil Cluster", func() { + cfg := &config.Config{ + Cluster: &config.ClusterConfig{ + Members: []string{"node1", "node2"}, + }, + } + clusterConfig := cfg.CopyClusterConfig() + So(clusterConfig, ShouldNotBeNil) + So(len(clusterConfig.Members), ShouldEqual, 2) + }) + + Convey("Test with nil Cluster", func() { + cfg := &config.Config{ + Cluster: nil, + } + clusterConfig := cfg.CopyClusterConfig() + So(clusterConfig, ShouldBeNil) + }) + + Convey("Test ClusterConfig deep copy isolation", func() { + cfg := &config.Config{ + Cluster: &config.ClusterConfig{ + Members: []string{"node1", "node2"}, + HashKey: "test-key", + TLS: &config.TLSConfig{ + Cert: "test-cert", + Key: "test-key", + CACert: "test-ca", + }, + Proxy: &config.ClusterRequestProxyConfig{ + LocalMemberClusterSocket: "127.0.0.1:8080", + LocalMemberClusterSocketIndex: 1, + }, + }, + } + + // Get a copy of the cluster config + clusterConfig := cfg.CopyClusterConfig() + So(clusterConfig, ShouldNotBeNil) + + // Mutate the copy + clusterConfig.Members[0] = "modified-node" + clusterConfig.HashKey = "modified-key" + clusterConfig.TLS.Cert = "modified-cert" + clusterConfig.Proxy.LocalMemberClusterSocket = "modified-socket" + + // Verify original config is unchanged + So(cfg.Cluster.Members[0], ShouldEqual, "node1") + So(cfg.Cluster.HashKey, ShouldEqual, "test-key") + So(cfg.Cluster.TLS.Cert, ShouldEqual, "test-cert") + So(cfg.Cluster.Proxy.LocalMemberClusterSocket, ShouldEqual, "127.0.0.1:8080") + + // Verify copy has the mutations + So(clusterConfig.Members[0], ShouldEqual, "modified-node") + So(clusterConfig.HashKey, ShouldEqual, "modified-key") + So(clusterConfig.TLS.Cert, ShouldEqual, "modified-cert") + So(clusterConfig.Proxy.LocalMemberClusterSocket, ShouldEqual, "modified-socket") + }) + }) + + Convey("Test CopySchedulerConfig()", func() { + Convey("Test with non-nil Scheduler", func() { + cfg := &config.Config{ + Scheduler: &config.SchedulerConfig{ + NumWorkers: 4, + }, + } + schedulerConfig := cfg.CopySchedulerConfig() + So(schedulerConfig, ShouldNotBeNil) + So(schedulerConfig.NumWorkers, ShouldEqual, 4) + }) + + Convey("Test with nil Scheduler", func() { + cfg := &config.Config{ + Scheduler: nil, + } + schedulerConfig := cfg.CopySchedulerConfig() + So(schedulerConfig, ShouldBeNil) + }) + }) + + Convey("Test GetVersionInfo()", func() { + Convey("Test with non-nil version info", func() { + cfg := &config.Config{ + Commit: "abc123", + BinaryType: "server", + GoVersion: "go1.21", + DistSpecVersion: "1.1.1", + } + commit, binaryType, goVersion, distSpecVersion := cfg.GetVersionInfo() + So(commit, ShouldEqual, "abc123") + So(binaryType, ShouldEqual, "server") + So(goVersion, ShouldEqual, "go1.21") + So(distSpecVersion, ShouldEqual, "1.1.1") + }) + + Convey("Test with empty version info", func() { + cfg := &config.Config{ + Commit: "", + BinaryType: "", + GoVersion: "", + DistSpecVersion: "", + } + commit, binaryType, goVersion, distSpecVersion := cfg.GetVersionInfo() + So(commit, ShouldEqual, "") + So(binaryType, ShouldEqual, "") + So(goVersion, ShouldEqual, "") + So(distSpecVersion, ShouldEqual, "") + }) + }) + + Convey("Test GetRealm()", func() { + Convey("Test with non-empty Realm", func() { + cfg := &config.Config{ + HTTP: config.HTTPConfig{ + Realm: "my-realm", + }, + } + realm := cfg.GetRealm() + So(realm, ShouldEqual, "my-realm") + }) + + Convey("Test with empty Realm", func() { + cfg := &config.Config{ + HTTP: config.HTTPConfig{ + Realm: "", + }, + } + realm := cfg.GetRealm() + So(realm, ShouldEqual, "") + }) + + Convey("Test with nil Config", func() { + var cfg *config.Config = nil + realm := cfg.GetRealm() + So(realm, ShouldEqual, "") + }) + }) + + Convey("Test CopyTLSConfig()", func() { + Convey("Test with non-empty TLS config", func() { + cfg := &config.Config{ + HTTP: config.HTTPConfig{ + TLS: &config.TLSConfig{ + Cert: "/path/to/cert.pem", + Key: "/path/to/key.pem", + CACert: "/path/to/ca.pem", + }, + }, + } + tlsConfig := cfg.CopyTLSConfig() + So(tlsConfig, ShouldNotBeNil) + So(tlsConfig.Cert, ShouldEqual, "/path/to/cert.pem") + So(tlsConfig.Key, ShouldEqual, "/path/to/key.pem") + So(tlsConfig.CACert, ShouldEqual, "/path/to/ca.pem") + + // Test copy isolation + tlsConfig.Cert = "/modified/cert.pem" + + So(cfg.HTTP.TLS.Cert, ShouldEqual, "/path/to/cert.pem") + }) + + Convey("Test with nil TLS config", func() { + cfg := &config.Config{ + HTTP: config.HTTPConfig{ + TLS: nil, + }, + } + tlsConfig := cfg.CopyTLSConfig() + So(tlsConfig, ShouldBeNil) + }) + + Convey("Test with nil Config", func() { + var cfg *config.Config = nil + tlsConfig := cfg.CopyTLSConfig() + So(tlsConfig, ShouldBeNil) + }) + }) + + Convey("Test GetCompat()", func() { + Convey("Test with non-empty compat config", func() { + cfg := &config.Config{ + HTTP: config.HTTPConfig{ + Compat: []compat.MediaCompatibility{ + "docker2s2", + "oci1", + }, + }, + } + compatConfig := cfg.GetCompat() + So(compatConfig, ShouldNotBeNil) + So(len(compatConfig), ShouldEqual, 2) + So(string(compatConfig[0]), ShouldEqual, "docker2s2") + So(string(compatConfig[1]), ShouldEqual, "oci1") + + // Test copy isolation + compatConfig[0] = "modified-compat" + + So(string(cfg.HTTP.Compat[0]), ShouldEqual, "docker2s2") + }) + + Convey("Test with nil compat config", func() { + cfg := &config.Config{ + HTTP: config.HTTPConfig{ + Compat: nil, + }, + } + compatConfig := cfg.GetCompat() + So(compatConfig, ShouldBeNil) + }) + + Convey("Test with empty compat config", func() { + cfg := &config.Config{ + HTTP: config.HTTPConfig{ + Compat: []compat.MediaCompatibility{}, + }, + } + compatConfig := cfg.GetCompat() + So(compatConfig, ShouldNotBeNil) + So(len(compatConfig), ShouldEqual, 0) + }) + + Convey("Test with nil Config", func() { + var cfg *config.Config = nil + compatConfig := cfg.GetCompat() + So(compatConfig, ShouldBeNil) + }) + }) + + Convey("Test GetHTTPAddress()", func() { + Convey("Test with non-empty address", func() { + cfg := &config.Config{ + HTTP: config.HTTPConfig{ + Address: "192.168.1.100", + }, + } + address := cfg.GetHTTPAddress() + So(address, ShouldEqual, "192.168.1.100") + }) + + Convey("Test with empty address", func() { + cfg := &config.Config{ + HTTP: config.HTTPConfig{ + Address: "", + }, + } + address := cfg.GetHTTPAddress() + So(address, ShouldEqual, "") + }) + + Convey("Test with nil Config", func() { + var cfg *config.Config = nil + address := cfg.GetHTTPAddress() + So(address, ShouldEqual, "") + }) + }) + + Convey("Test GetHTTPPort()", func() { + Convey("Test with non-empty port", func() { + cfg := &config.Config{ + HTTP: config.HTTPConfig{ + Port: "8080", + }, + } + port := cfg.GetHTTPPort() + So(port, ShouldEqual, "8080") + }) + + Convey("Test with empty port", func() { + cfg := &config.Config{ + HTTP: config.HTTPConfig{ + Port: "", + }, + } + port := cfg.GetHTTPPort() + So(port, ShouldEqual, "") + }) + + Convey("Test with nil Config", func() { + var cfg *config.Config = nil + port := cfg.GetHTTPPort() + So(port, ShouldEqual, "") + }) + }) + + Convey("Test GetAllowOrigin()", func() { + Convey("Test with non-empty allow origin", func() { + cfg := &config.Config{ + HTTP: config.HTTPConfig{ + AllowOrigin: "http://localhost:3000,https://example.com", + }, + } + allowOrigin := cfg.GetAllowOrigin() + So(allowOrigin, ShouldEqual, "http://localhost:3000,https://example.com") + }) + + Convey("Test with empty allow origin", func() { + cfg := &config.Config{ + HTTP: config.HTTPConfig{ + AllowOrigin: "", + }, + } + allowOrigin := cfg.GetAllowOrigin() + So(allowOrigin, ShouldEqual, "") + }) + + Convey("Test with nil Config", func() { + var cfg *config.Config = nil + allowOrigin := cfg.GetAllowOrigin() + So(allowOrigin, ShouldEqual, "") + }) + }) + + Convey("Test CopyRatelimit()", func() { + Convey("Test with non-empty ratelimit config", func() { + rate := 100 + cfg := &config.Config{ + HTTP: config.HTTPConfig{ + Ratelimit: &config.RatelimitConfig{ + Rate: &rate, + Methods: []config.MethodRatelimitConfig{ + { + Method: "GET", + Rate: 50, + }, + { + Method: "POST", + Rate: 25, + }, + }, + }, + }, + } + ratelimitConfig := cfg.CopyRatelimit() + So(ratelimitConfig, ShouldNotBeNil) + So(*ratelimitConfig.Rate, ShouldEqual, 100) + So(len(ratelimitConfig.Methods), ShouldEqual, 2) + So(ratelimitConfig.Methods[0].Method, ShouldEqual, "GET") + So(ratelimitConfig.Methods[0].Rate, ShouldEqual, 50) + So(ratelimitConfig.Methods[1].Method, ShouldEqual, "POST") + So(ratelimitConfig.Methods[1].Rate, ShouldEqual, 25) + + // Test deep copy isolation + *ratelimitConfig.Rate = 200 + ratelimitConfig.Methods[0].Rate = 75 + ratelimitConfig.Methods[0].Method = "PUT" + + So(*cfg.HTTP.Ratelimit.Rate, ShouldEqual, 100) + So(cfg.HTTP.Ratelimit.Methods[0].Rate, ShouldEqual, 50) + So(cfg.HTTP.Ratelimit.Methods[0].Method, ShouldEqual, "GET") + }) + + Convey("Test with nil ratelimit config", func() { + cfg := &config.Config{ + HTTP: config.HTTPConfig{ + Ratelimit: nil, + }, + } + ratelimitConfig := cfg.CopyRatelimit() + So(ratelimitConfig, ShouldBeNil) + }) + + Convey("Test with nil Config", func() { + var cfg *config.Config = nil + ratelimitConfig := cfg.CopyRatelimit() + So(ratelimitConfig, ShouldBeNil) + }) + }) + }) + + Convey("Test Config utility methods", t, func() { + Convey("Test IsMTLSAuthEnabled()", func() { + // Test with nil Config + var cfg *config.Config = nil + + So(cfg.IsMTLSAuthEnabled(), ShouldBeFalse) + + // Test with Config but no TLS + cfg = &config.Config{} + So(cfg.IsMTLSAuthEnabled(), ShouldBeFalse) + + // Test with Config and TLS but no client cert + cfg = &config.Config{ + HTTP: config.HTTPConfig{ + TLS: &config.TLSConfig{ + Cert: "/path/to/cert.pem", + Key: "/path/to/key.pem", + }, + }, + } + So(cfg.IsMTLSAuthEnabled(), ShouldBeFalse) + + // Test with Config and TLS with CA cert (mTLS) + cfg = &config.Config{ + HTTP: config.HTTPConfig{ + TLS: &config.TLSConfig{ + Cert: "/path/to/cert.pem", + Key: "/path/to/key.pem", + CACert: "/path/to/ca-cert.pem", + }, + }, + } + So(cfg.IsMTLSAuthEnabled(), ShouldBeTrue) + + // Test with HTPasswd enabled (should disable mTLS) + cfg = &config.Config{ + HTTP: config.HTTPConfig{ + Auth: &config.AuthConfig{ + HTPasswd: config.AuthHTPasswd{ + Path: "/path/to/htpasswd", + }, + }, + TLS: &config.TLSConfig{ + Cert: "/path/to/cert.pem", + Key: "/path/to/key.pem", + CACert: "/path/to/ca-cert.pem", + }, + }, + } + So(cfg.IsMTLSAuthEnabled(), ShouldBeFalse) // Basic auth enabled, so mTLS disabled + + // Test with LDAP enabled (should disable mTLS) + cfg = &config.Config{ + HTTP: config.HTTPConfig{ + Auth: &config.AuthConfig{ + LDAP: &config.LDAPConfig{}, + }, + TLS: &config.TLSConfig{ + Cert: "/path/to/cert.pem", + Key: "/path/to/key.pem", + CACert: "/path/to/ca-cert.pem", + }, + }, + } + So(cfg.IsMTLSAuthEnabled(), ShouldBeFalse) // Basic auth enabled, so mTLS disabled + + // Test with API Key enabled (should disable mTLS) + cfg = &config.Config{ + HTTP: config.HTTPConfig{ + Auth: &config.AuthConfig{ + APIKey: true, + }, + TLS: &config.TLSConfig{ + Cert: "/path/to/cert.pem", + Key: "/path/to/key.pem", + CACert: "/path/to/ca-cert.pem", + }, + }, + } + So(cfg.IsMTLSAuthEnabled(), ShouldBeFalse) // Basic auth enabled, so mTLS disabled + + // Test with OpenID enabled (valid config - should disable mTLS) + cfg = &config.Config{ + HTTP: config.HTTPConfig{ + Auth: &config.AuthConfig{ + OpenID: &config.OpenIDConfig{ + Providers: map[string]config.OpenIDProviderConfig{ + "google": { + ClientID: "client-id", + Issuer: "", + Scopes: []string{}, + }, + }, + }, + }, + TLS: &config.TLSConfig{ + Cert: "/path/to/cert.pem", + Key: "/path/to/key.pem", + CACert: "/path/to/ca-cert.pem", + }, + }, + } + So(cfg.IsMTLSAuthEnabled(), ShouldBeFalse) // Basic auth enabled, so mTLS disabled + + // Test with OpenID enabled (with Issuer - should disable mTLS) + cfg = &config.Config{ + HTTP: config.HTTPConfig{ + Auth: &config.AuthConfig{ + OpenID: &config.OpenIDConfig{ + Providers: map[string]config.OpenIDProviderConfig{ + "google": { + ClientID: "", + Issuer: "https://accounts.google.com", + Scopes: []string{}, + }, + }, + }, + }, + TLS: &config.TLSConfig{ + Cert: "/path/to/cert.pem", + Key: "/path/to/key.pem", + CACert: "/path/to/ca-cert.pem", + }, + }, + } + So(cfg.IsMTLSAuthEnabled(), ShouldBeFalse) // Basic auth enabled, so mTLS disabled + + // Test with OpenID enabled (with Scopes - should disable mTLS) + cfg = &config.Config{ + HTTP: config.HTTPConfig{ + Auth: &config.AuthConfig{ + OpenID: &config.OpenIDConfig{ + Providers: map[string]config.OpenIDProviderConfig{ + "google": { + ClientID: "", + Issuer: "", + Scopes: []string{"openid", "email"}, + }, + }, + }, + }, + TLS: &config.TLSConfig{ + Cert: "/path/to/cert.pem", + Key: "/path/to/key.pem", + CACert: "/path/to/ca-cert.pem", + }, + }, + } + So(cfg.IsMTLSAuthEnabled(), ShouldBeFalse) // Basic auth enabled, so mTLS disabled + + // Test with OAuth2 provider (github) with ClientID (should disable mTLS) + cfg = &config.Config{ + HTTP: config.HTTPConfig{ + Auth: &config.AuthConfig{ + OpenID: &config.OpenIDConfig{ + Providers: map[string]config.OpenIDProviderConfig{ + "github": { + ClientID: "github-client-id", + Scopes: []string{}, + }, + }, + }, + }, + TLS: &config.TLSConfig{ + Cert: "/path/to/cert.pem", + Key: "/path/to/key.pem", + CACert: "/path/to/ca-cert.pem", + }, + }, + } + So(cfg.IsMTLSAuthEnabled(), ShouldBeFalse) // Basic auth enabled, so mTLS disabled + + // Test with OAuth2 provider (github) with Scopes (should disable mTLS) + cfg = &config.Config{ + HTTP: config.HTTPConfig{ + Auth: &config.AuthConfig{ + OpenID: &config.OpenIDConfig{ + Providers: map[string]config.OpenIDProviderConfig{ + "github": { + ClientID: "", + Scopes: []string{"user:email"}, + }, + }, + }, + }, + TLS: &config.TLSConfig{ + Cert: "/path/to/cert.pem", + Key: "/path/to/key.pem", + CACert: "/path/to/ca-cert.pem", + }, + }, + } + So(cfg.IsMTLSAuthEnabled(), ShouldBeFalse) // Basic auth enabled, so mTLS disabled + + // Test with OpenID but empty config (should enable mTLS) + cfg = &config.Config{ + HTTP: config.HTTPConfig{ + Auth: &config.AuthConfig{ + OpenID: &config.OpenIDConfig{ + Providers: map[string]config.OpenIDProviderConfig{ + "google": { + ClientID: "", + Issuer: "", + Scopes: []string{}, + }, + }, + }, + }, + TLS: &config.TLSConfig{ + Cert: "/path/to/cert.pem", + Key: "/path/to/key.pem", + CACert: "/path/to/ca-cert.pem", + }, + }, + } + So(cfg.IsMTLSAuthEnabled(), ShouldBeTrue) // No basic auth, so mTLS enabled + + // Test with OpenID but unsupported provider (should enable mTLS) + cfg = &config.Config{ + HTTP: config.HTTPConfig{ + Auth: &config.AuthConfig{ + OpenID: &config.OpenIDConfig{ + Providers: map[string]config.OpenIDProviderConfig{ + "unsupported": { + ClientID: "client-id", + Scopes: []string{"scope"}, + }, + }, + }, + }, + TLS: &config.TLSConfig{ + Cert: "/path/to/cert.pem", + Key: "/path/to/key.pem", + CACert: "/path/to/ca-cert.pem", + }, + }, + } + So(cfg.IsMTLSAuthEnabled(), ShouldBeTrue) // No basic auth, so mTLS enabled + + // Test with no authentication methods (should enable mTLS) + cfg = &config.Config{ + HTTP: config.HTTPConfig{ + Auth: &config.AuthConfig{}, + TLS: &config.TLSConfig{ + Cert: "/path/to/cert.pem", + Key: "/path/to/key.pem", + CACert: "/path/to/ca-cert.pem", + }, + }, + } + So(cfg.IsMTLSAuthEnabled(), ShouldBeTrue) // No basic auth, so mTLS enabled + + // Test with nil Auth (should enable mTLS) + cfg = &config.Config{ + HTTP: config.HTTPConfig{ + Auth: nil, + TLS: &config.TLSConfig{ + Cert: "/path/to/cert.pem", + Key: "/path/to/key.pem", + CACert: "/path/to/ca-cert.pem", + }, + }, + } + So(cfg.IsMTLSAuthEnabled(), ShouldBeTrue) // No basic auth, so mTLS enabled + }) + + Convey("Test IsCompatEnabled()", func() { + // Test with nil Config + var cfg *config.Config = nil + + So(cfg.IsCompatEnabled(), ShouldBeFalse) + + // Test with Config but no Compat + cfg = &config.Config{} + So(cfg.IsCompatEnabled(), ShouldBeFalse) + + // Test with Config and Compat enabled + cfg = &config.Config{ + HTTP: config.HTTPConfig{ + Compat: []compat.MediaCompatibility{compat.DockerManifestV2SchemaV2}, + }, + } + So(cfg.IsCompatEnabled(), ShouldBeTrue) + }) + + Convey("Test IsOpenIDSupported()", func() { + // Test with unsupported provider + So(config.IsOpenIDSupported("unsupported"), ShouldBeFalse) + + // Test with supported provider + So(config.IsOpenIDSupported("google"), ShouldBeTrue) + }) + + Convey("Test IsOauth2Supported()", func() { + // Test with unsupported provider + So(config.IsOauth2Supported("unsupported"), ShouldBeFalse) + + // Test with supported provider + So(config.IsOauth2Supported("github"), ShouldBeTrue) + }) + + Convey("Test IsClustered() with nil ClusterConfig", func() { + var clusterConfig *config.ClusterConfig = nil + + So(clusterConfig.IsClustered(), ShouldBeFalse) + }) + + Convey("Test IsClustered() with empty members", func() { + clusterConfig := &config.ClusterConfig{ + Members: []string{}, + } + So(clusterConfig.IsClustered(), ShouldBeFalse) + }) + + Convey("Test IsClustered() with single member", func() { + clusterConfig := &config.ClusterConfig{ + Members: []string{"node1:8080"}, + } + So(clusterConfig.IsClustered(), ShouldBeFalse) + }) + + Convey("Test IsClustered() with multiple members", func() { + clusterConfig := &config.ClusterConfig{ + Members: []string{"node1:8080", "node2:8080"}, + } + So(clusterConfig.IsClustered(), ShouldBeTrue) + }) + }) + + Convey("Test CopyExtensionsConfig methods", t, func() { + Convey("Test IsSearchEnabled()", func() { + // Test with nil Config + var cfg *config.Config = nil + + So(cfg.CopyExtensionsConfig().IsSearchEnabled(), ShouldBeFalse) + + // Test with Config but nil Extensions + cfg = &config.Config{} + So(cfg.CopyExtensionsConfig().IsSearchEnabled(), ShouldBeFalse) + + // Test with Config and Extensions but nil Search + cfg = &config.Config{ + Extensions: &extconf.ExtensionConfig{}, + } + So(cfg.CopyExtensionsConfig().IsSearchEnabled(), ShouldBeFalse) + + // Test with Config and Extensions and Search but disabled + disabled := false + cfg = &config.Config{ + Extensions: &extconf.ExtensionConfig{ + Search: &extconf.SearchConfig{ + BaseConfig: extconf.BaseConfig{ + Enable: &disabled, + }, + }, + }, + } + So(cfg.CopyExtensionsConfig().IsSearchEnabled(), ShouldBeFalse) + + // Test with Config and Extensions and Search enabled + enabled := true + cfg = &config.Config{ + Extensions: &extconf.ExtensionConfig{ + Search: &extconf.SearchConfig{ + BaseConfig: extconf.BaseConfig{ + Enable: &enabled, + }, + }, + }, + } + So(cfg.CopyExtensionsConfig().IsSearchEnabled(), ShouldBeTrue) + }) + }) + + Convey("Test UpdateReloadableConfig()", t, func() { + Convey("Test with nil Config", func() { + var cfg *config.Config = nil + newConfig := &config.Config{} + + So(func() { cfg.UpdateReloadableConfig(newConfig) }, ShouldNotPanic) + }) + + Convey("Test with nil newConfig.HTTP.Auth", func() { + // Create initial config with Auth + cfg := &config.Config{ + HTTP: config.HTTPConfig{ + Auth: &config.AuthConfig{ + FailDelay: 5, + HTPasswd: config.AuthHTPasswd{ + Path: "/etc/htpasswd", + }, + APIKey: false, + }, + }, + } + + // Create new config with nil Auth + newConfig := &config.Config{ + HTTP: config.HTTPConfig{ + Auth: nil, // This should not cause a panic + }, + } + + // This should not panic even though newConfig.HTTP.Auth is nil + So(func() { cfg.UpdateReloadableConfig(newConfig) }, ShouldNotPanic) + + // Verify that the original Auth config remains unchanged + So(cfg.HTTP.Auth, ShouldNotBeNil) + So(cfg.HTTP.Auth.FailDelay, ShouldEqual, 5) + So(cfg.HTTP.Auth.HTPasswd.Path, ShouldEqual, "/etc/htpasswd") + So(cfg.HTTP.Auth.APIKey, ShouldBeFalse) + }) + + Convey("Test with AccessControl update", func() { + cfgAccessControl := &config.AccessControlConfig{} + cfgAccessControl.AdminPolicy = config.Policy{ + Actions: []string{"read"}, + } + cfg := &config.Config{ + HTTP: config.HTTPConfig{ + AccessControl: cfgAccessControl, + }, + } + newConfigAccessControl := &config.AccessControlConfig{} + newConfigAccessControl.AdminPolicy = config.Policy{ + Actions: []string{"read", "write"}, + } + newConfig := &config.Config{ + HTTP: config.HTTPConfig{ + AccessControl: newConfigAccessControl, + }, + } + cfg.UpdateReloadableConfig(newConfig) + So(cfg.CopyAccessControlConfig().GetAdminPolicy().Actions, ShouldResemble, []string{"read", "write"}) + }) + + Convey("Test with Extensions update", func() { + // First set up a config with search enabled + enabled := true + cfg := &config.Config{ + Extensions: &extconf.ExtensionConfig{ + Search: &extconf.SearchConfig{ + BaseConfig: extconf.BaseConfig{ + Enable: &enabled, + }, + }, + }, + } + + // Create new config with CVE config + newConfig := &config.Config{ + Extensions: &extconf.ExtensionConfig{ + Search: &extconf.SearchConfig{ + BaseConfig: extconf.BaseConfig{ + Enable: &enabled, + }, + CVE: &extconf.CVEConfig{ + UpdateInterval: time.Hour * 2, + }, + }, + }, + } + cfg.UpdateReloadableConfig(newConfig) + // The search should still be enabled and CVE config should be updated + So(cfg.CopyExtensionsConfig().IsSearchEnabled(), ShouldBeTrue) + }) + + Convey("Test search CVE config removal when new config has nil Search.CVE", func() { + // First set up a config with search enabled and CVE config + enabled := true + cfg := &config.Config{ + Extensions: &extconf.ExtensionConfig{ + Search: &extconf.SearchConfig{ + BaseConfig: extconf.BaseConfig{ + Enable: &enabled, + }, + CVE: &extconf.CVEConfig{ + UpdateInterval: time.Hour, + }, + }, + }, + } + So(cfg.CopyExtensionsConfig().IsSearchEnabled(), ShouldBeTrue) + So(cfg.Extensions.Search.CVE, ShouldNotBeNil) + + // Create new config with Search but nil CVE + newConfig := &config.Config{ + Extensions: &extconf.ExtensionConfig{ + Search: &extconf.SearchConfig{ + BaseConfig: extconf.BaseConfig{ + Enable: &enabled, + }, + CVE: nil, // This should trigger the removal + }, + }, + } + cfg.UpdateReloadableConfig(newConfig) + + // Verify that the CVE config was removed + So(cfg.Extensions.Search.CVE, ShouldBeNil) + So(cfg.Extensions.Search.Enable, ShouldNotBeNil) + So(*cfg.Extensions.Search.Enable, ShouldBeTrue) + }) + + Convey("Test search CVE config removal when new config has nil Search", func() { + // First set up a config with search enabled and CVE config + enabled := true + cfg := &config.Config{ + Extensions: &extconf.ExtensionConfig{ + Search: &extconf.SearchConfig{ + BaseConfig: extconf.BaseConfig{ + Enable: &enabled, + }, + CVE: &extconf.CVEConfig{ + UpdateInterval: time.Hour, + }, + }, + }, + } + So(cfg.CopyExtensionsConfig().IsSearchEnabled(), ShouldBeTrue) + So(cfg.Extensions.Search.CVE, ShouldNotBeNil) + + // Create new config with Extensions but nil Search + newConfig := &config.Config{ + Extensions: &extconf.ExtensionConfig{ + Search: nil, // This should trigger the removal + }, + } + cfg.UpdateReloadableConfig(newConfig) + + // Verify that the CVE config was removed + So(cfg.Extensions.Search.CVE, ShouldBeNil) + So(cfg.Extensions.Search.Enable, ShouldNotBeNil) + So(*cfg.Extensions.Search.Enable, ShouldBeTrue) + }) + }) + + Convey("Test isOpenIDAuthProviderEnabled indirectly via IsMTLSAuthEnabled", t, func() { + Convey("Test with OpenID provider with empty config", func() { + cfg := &config.Config{ + HTTP: config.HTTPConfig{ + TLS: &config.TLSConfig{ + Key: "key", + Cert: "cert", + CACert: "cacert", + }, + Auth: &config.AuthConfig{ + OpenID: &config.OpenIDConfig{ + Providers: map[string]config.OpenIDProviderConfig{ + "google": { + ClientID: "", + Issuer: "", + Scopes: []string{}, + }, + }, + }, + }, + AccessControl: &config.AccessControlConfig{}, + }, + } + // This should return true because isOpenIDAuthProviderEnabled returns false for empty config, + // so isBasicAuthnEnabled returns false, making IsMTLSAuthEnabled return true + So(cfg.IsMTLSAuthEnabled(), ShouldBeTrue) + }) + + Convey("Test with OpenID provider with valid config", func() { + cfg := &config.Config{ + HTTP: config.HTTPConfig{ + TLS: &config.TLSConfig{ + Key: "key", + Cert: "cert", + CACert: "cacert", + }, + Auth: &config.AuthConfig{ + OpenID: &config.OpenIDConfig{ + Providers: map[string]config.OpenIDProviderConfig{ + "google": { + ClientID: "client-id", + Issuer: "", + Scopes: []string{}, + }, + }, + }, + }, + AccessControl: &config.AccessControlConfig{}, + }, + } + // This should return false because isOpenIDAuthProviderEnabled returns true for valid config, + // so isBasicAuthnEnabled returns true, making IsMTLSAuthEnabled return false + So(cfg.IsMTLSAuthEnabled(), ShouldBeFalse) + }) + + Convey("Test with unsupported OpenID provider", func() { + cfg := &config.Config{ + HTTP: config.HTTPConfig{ + TLS: &config.TLSConfig{ + Key: "key", + Cert: "cert", + CACert: "cacert", + }, + Auth: &config.AuthConfig{ + OpenID: &config.OpenIDConfig{ + Providers: map[string]config.OpenIDProviderConfig{ + "unsupported": { + ClientID: "client-id", + Scopes: []string{"scope"}, + }, + }, + }, + }, + AccessControl: &config.AccessControlConfig{}, + }, + } + // This should return true because isOpenIDAuthProviderEnabled returns false for unsupported provider, + // so isBasicAuthnEnabled returns false, making IsMTLSAuthEnabled return true + So(cfg.IsMTLSAuthEnabled(), ShouldBeTrue) + }) + }) + + Convey("Test nil receiver coverage for all methods", t, func() { + Convey("Test AuthConfig methods with nil receiver", func() { + var authConfig *config.AuthConfig = nil + + So(authConfig.IsLdapAuthEnabled(), ShouldBeFalse) + So(authConfig.IsHtpasswdAuthEnabled(), ShouldBeFalse) + So(authConfig.IsBearerAuthEnabled(), ShouldBeFalse) + So(authConfig.IsOpenIDAuthEnabled(), ShouldBeFalse) + So(authConfig.IsAPIKeyEnabled(), ShouldBeFalse) + So(authConfig.IsBasicAuthnEnabled(), ShouldBeFalse) + So(authConfig.GetFailDelay(), ShouldEqual, 0) + }) + + Convey("Test LDAPConfig methods with nil receiver", func() { + var ldapConfig *config.LDAPConfig = nil + + So(ldapConfig.BindDN(), ShouldEqual, "") + So(ldapConfig.BindPassword(), ShouldEqual, "") + So(ldapConfig.SetBindDN("test"), ShouldBeNil) + So(ldapConfig.SetBindPassword("test"), ShouldBeNil) + }) + + Convey("Test AccessControlConfig methods with nil receiver", func() { + var accessControlConfig *config.AccessControlConfig = nil + + So(accessControlConfig.IsAuthzEnabled(), ShouldBeFalse) + So(accessControlConfig.AnonymousPolicyExists(), ShouldBeFalse) + So(accessControlConfig.ContainsOnlyAnonymousPolicy(), ShouldBeTrue) + + // Test getter methods + So(accessControlConfig.GetRepositories(), ShouldBeNil) + So(accessControlConfig.GetAdminPolicy(), ShouldResemble, config.Policy{}) + So(accessControlConfig.GetMetrics(), ShouldResemble, config.Metrics{}) + So(accessControlConfig.GetGroups(), ShouldBeNil) + }) + + Convey("Test Config methods with nil receiver", func() { + var cfg *config.Config = nil + + // Test getter methods + So(cfg.CopyAuthConfig(), ShouldBeNil) + So(cfg.CopyAccessControlConfig(), ShouldBeNil) + So(cfg.GetHTTPAddress(), ShouldEqual, "") + So(cfg.GetHTTPPort(), ShouldEqual, "") + So(cfg.GetAllowOrigin(), ShouldEqual, "") + So(cfg.CopyTLSConfig(), ShouldBeNil) + So(cfg.CopyRatelimit(), ShouldBeNil) + So(cfg.GetCompat(), ShouldBeNil) + So(cfg.CopyStorageConfig(), ShouldResemble, config.GlobalStorageConfig{}) + So(cfg.CopyExtensionsConfig(), ShouldBeNil) + So(cfg.CopyLogConfig(), ShouldBeNil) + So(cfg.CopyClusterConfig(), ShouldBeNil) + So(cfg.CopySchedulerConfig(), ShouldBeNil) + + // Test GetVersionInfo + commit, binaryType, goVersion, distSpecVersion := cfg.GetVersionInfo() + So(commit, ShouldEqual, "") + So(binaryType, ShouldEqual, "") + So(goVersion, ShouldEqual, "") + So(distSpecVersion, ShouldEqual, "") + + // Test boolean methods + So(cfg.IsMTLSAuthEnabled(), ShouldBeFalse) + So(cfg.IsRetentionEnabled(), ShouldBeFalse) + So(cfg.IsCompatEnabled(), ShouldBeFalse) + + // Test Sanitize + So(cfg.Sanitize(), ShouldBeNil) + + // Test UpdateReloadableConfig (should not panic) + newConfig := &config.Config{} + + So(func() { cfg.UpdateReloadableConfig(newConfig) }, ShouldNotPanic) + }) + }) + + Convey("Test AccessControlConfig copy isolation through CopyAccessControlConfig()", t, func() { + Convey("Test that mutations to retrieved AccessControlConfig copy do not affect original config", func() { + // Create a config with initial AccessControlConfig + cfg := &config.Config{ + HTTP: config.HTTPConfig{ + AccessControl: &config.AccessControlConfig{ + AdminPolicy: config.Policy{ + Actions: []string{"read"}, + Users: []string{"admin"}, + }, + Repositories: config.Repositories{ + "repo1": config.PolicyGroup{ + DefaultPolicy: []string{"read"}, + Policies: []config.Policy{ + { + Actions: []string{"read"}, + }, + }, + }, + }, + }, + }, + } + + // Retrieve the AccessControlConfig (should be a copy) + accessControlConfig := cfg.CopyAccessControlConfig() + So(accessControlConfig, ShouldNotBeNil) + + // Mutate the retrieved AccessControlConfig copy + accessControlConfig.AdminPolicy = config.Policy{ + Actions: []string{"read", "write", "delete"}, + Users: []string{"admin", "superadmin"}, + } + + // Add a new repository to the copy + newRepositories := config.Repositories{ + "repo1": config.PolicyGroup{ + DefaultPolicy: []string{"read"}, + Policies: []config.Policy{ + { + Actions: []string{"read"}, + }, + }, + }, + "repo2": config.PolicyGroup{ + DefaultPolicy: []string{"read", "write"}, + Policies: []config.Policy{ + { + Actions: []string{"read", "write"}, + Users: []string{"user1"}, + }, + }, + }, + } + accessControlConfig.Repositories = newRepositories + + // Verify that the original config is unchanged + originalAccessControlConfig := cfg.CopyAccessControlConfig() + So(originalAccessControlConfig, ShouldNotBeNil) + + // Check that admin policy remains unchanged in original + adminPolicy := originalAccessControlConfig.GetAdminPolicy() + So(adminPolicy.Actions, ShouldResemble, []string{"read"}) + So(adminPolicy.Users, ShouldResemble, []string{"admin"}) + + // Check that repositories remain unchanged in original + repositories := originalAccessControlConfig.GetRepositories() + So(len(repositories), ShouldEqual, 1) + So(repositories["repo1"], ShouldNotBeNil) + So(repositories["repo1"].DefaultPolicy, ShouldResemble, []string{"read"}) + }) + + Convey("Test that mutations to retrieved AccessControlConfig copy work with nil initial config", func() { + // Create a config with nil AccessControlConfig + cfg := &config.Config{ + HTTP: config.HTTPConfig{ + AccessControl: nil, + }, + } + + // Retrieve the AccessControlConfig (should return nil) + accessControlConfig := cfg.CopyAccessControlConfig() + So(accessControlConfig, ShouldBeNil) + + // Create a new AccessControlConfig and set it + newAccessControlConfig := &config.AccessControlConfig{} + newAccessControlConfig.AdminPolicy = config.Policy{ + Actions: []string{"read"}, + Users: []string{"admin"}, + } + + // Manually set the AccessControlConfig on the original config + cfg.HTTP.AccessControl = newAccessControlConfig + + // Now retrieve it again and verify it works + retrievedConfig := cfg.CopyAccessControlConfig() + So(retrievedConfig, ShouldNotBeNil) + + // Mutate the retrieved config copy + retrievedConfig.AdminPolicy = config.Policy{ + Actions: []string{"read", "write"}, + Users: []string{"admin", "user"}, + } + + // Verify the original config is unchanged + finalConfig := cfg.CopyAccessControlConfig() + adminPolicy := finalConfig.GetAdminPolicy() + So(adminPolicy.Actions, ShouldResemble, []string{"read"}) + So(adminPolicy.Users, ShouldResemble, []string{"admin"}) + }) + }) + + Convey("Test AccessControlConfig copy isolation through UpdateReloadableConfig()", t, func() { + Convey("Test that AccessControlConfig copies are isolated from UpdateReloadableConfig changes", func() { + // Create initial config with AccessControlConfig + cfg := &config.Config{ + HTTP: config.HTTPConfig{ + AccessControl: &config.AccessControlConfig{ + AdminPolicy: config.Policy{ + Actions: []string{"read"}, + Users: []string{"admin"}, + }, + Repositories: config.Repositories{ + "repo1": config.PolicyGroup{ + DefaultPolicy: []string{"read"}, + Policies: []config.Policy{ + { + Actions: []string{"read"}, + }, + }, + }, + }, + }, + }, + } + + // Get initial reference to AccessControlConfig + initialAccessControlConfig := cfg.CopyAccessControlConfig() + So(initialAccessControlConfig, ShouldNotBeNil) + + // Verify initial state + initialAdminPolicy := initialAccessControlConfig.GetAdminPolicy() + So(initialAdminPolicy.Actions, ShouldResemble, []string{"read"}) + So(initialAdminPolicy.Users, ShouldResemble, []string{"admin"}) + + initialRepositories := initialAccessControlConfig.GetRepositories() + So(len(initialRepositories), ShouldEqual, 1) + So(initialRepositories["repo1"], ShouldNotBeNil) + + // Create new config with updated AccessControlConfig + newConfig := &config.Config{ + HTTP: config.HTTPConfig{ + AccessControl: &config.AccessControlConfig{ + AdminPolicy: config.Policy{ + Actions: []string{"read", "write", "delete"}, + Users: []string{"admin", "superadmin", "user"}, + }, + Repositories: config.Repositories{ + "repo1": config.PolicyGroup{ + DefaultPolicy: []string{"read", "write"}, + Policies: []config.Policy{ + { + Actions: []string{"read", "write"}, + }, + }, + }, + "repo2": config.PolicyGroup{ + DefaultPolicy: []string{"read"}, + Policies: []config.Policy{ + { + Actions: []string{"read"}, + Users: []string{"user1", "user2"}, + }, + }, + }, + }, + }, + }, + } + + // Update the config using UpdateReloadableConfig + cfg.UpdateReloadableConfig(newConfig) + + // Verify that the old copy remains unchanged (copy isolation) + updatedAdminPolicy := initialAccessControlConfig.GetAdminPolicy() + So(updatedAdminPolicy.Actions, ShouldResemble, []string{"read"}) + So(updatedAdminPolicy.Users, ShouldResemble, []string{"admin"}) + + updatedRepositories := initialAccessControlConfig.GetRepositories() + So(len(updatedRepositories), ShouldEqual, 1) + So(updatedRepositories["repo1"], ShouldNotBeNil) + So(updatedRepositories["repo1"].DefaultPolicy, ShouldResemble, []string{"read"}) + + // Verify that a new copy gets the updated data + newAccessControlConfig := cfg.CopyAccessControlConfig() + So(newAccessControlConfig, ShouldNotBeNil) + So(newAccessControlConfig, ShouldNotEqual, initialAccessControlConfig) // Different copy + + newAdminPolicy := newAccessControlConfig.GetAdminPolicy() + So(newAdminPolicy.Actions, ShouldResemble, []string{"read", "write", "delete"}) + So(newAdminPolicy.Users, ShouldResemble, []string{"admin", "superadmin", "user"}) + }) + + Convey("Test that old AccessControlConfig reference works with nil initial config", func() { + // Create config with nil AccessControlConfig + cfg := &config.Config{ + HTTP: config.HTTPConfig{ + AccessControl: nil, + }, + } + + // Get initial reference (should be nil) + initialAccessControlConfig := cfg.CopyAccessControlConfig() + So(initialAccessControlConfig, ShouldBeNil) + + // Create new config with AccessControlConfig + newConfig := &config.Config{ + HTTP: config.HTTPConfig{ + AccessControl: &config.AccessControlConfig{ + AdminPolicy: config.Policy{ + Actions: []string{"read", "write"}, + Users: []string{"admin"}, + }, + }, + }, + } + + // Update the config using UpdateReloadableConfig + cfg.UpdateReloadableConfig(newConfig) + + // Verify that a new reference now gets the data + newAccessControlConfig := cfg.CopyAccessControlConfig() + So(newAccessControlConfig, ShouldNotBeNil) + + adminPolicy := newAccessControlConfig.GetAdminPolicy() + So(adminPolicy.Actions, ShouldResemble, []string{"read", "write"}) + So(adminPolicy.Users, ShouldResemble, []string{"admin"}) + }) + + Convey("Test that old AccessControlConfig reference works when new config has nil AccessControlConfig", func() { + // Create initial config with AccessControlConfig + testAccessControlConfig := &config.AccessControlConfig{} + testAccessControlConfig.AdminPolicy = config.Policy{ + Actions: []string{"read"}, + Users: []string{"admin"}, + } + cfg := &config.Config{ + HTTP: config.HTTPConfig{ + AccessControl: testAccessControlConfig, + }, + } + + // Get initial reference + initialAccessControlConfig := cfg.CopyAccessControlConfig() + So(initialAccessControlConfig, ShouldNotBeNil) + + // Create new config with nil AccessControlConfig + newConfig := &config.Config{ + HTTP: config.HTTPConfig{ + AccessControl: nil, + }, + } + + // Update the config using UpdateReloadableConfig + cfg.UpdateReloadableConfig(newConfig) + + // Verify that a new reference now returns nil + newAccessControlConfig := cfg.CopyAccessControlConfig() + So(newAccessControlConfig, ShouldBeNil) + }) + }) + + Convey("Test ExtensionConfig copy isolation through CopyExtensionsConfig()", t, func() { + Convey("Test that mutations to retrieved ExtensionConfig copy do not affect original config", func() { + // Create a config with initial ExtensionConfig + enabled := true + cfg := &config.Config{ + Extensions: &extconf.ExtensionConfig{ + Search: &extconf.SearchConfig{ + BaseConfig: extconf.BaseConfig{ + Enable: &enabled, + }, + CVE: &extconf.CVEConfig{ + UpdateInterval: time.Hour, + Trivy: &extconf.TrivyConfig{ + DBRepository: "original/trivy-db", + }, + }, + }, + Sync: &syncconf.Config{ + Enable: &enabled, + Registries: []syncconf.RegistryConfig{ + { + URLs: []string{"http://original:5000"}, + }, + }, + }, + Metrics: &extconf.MetricsConfig{ + BaseConfig: extconf.BaseConfig{ + Enable: &enabled, + }, + Prometheus: &extconf.PrometheusConfig{ + Path: "/metrics", + }, + }, + Scrub: &extconf.ScrubConfig{ + BaseConfig: extconf.BaseConfig{ + Enable: &enabled, + }, + Interval: 24 * time.Hour, + }, + UI: &extconf.UIConfig{ + BaseConfig: extconf.BaseConfig{ + Enable: &enabled, + }, + }, + }, + } + + // Retrieve the ExtensionConfig + extensionConfig := cfg.CopyExtensionsConfig() + So(extensionConfig, ShouldNotBeNil) + + // Mutate the retrieved ExtensionConfig copy + disabled := false + extensionConfig.Search.Enable = &disabled + extensionConfig.Search.CVE.UpdateInterval = 2 * time.Hour + extensionConfig.Search.CVE.Trivy.DBRepository = "modified/trivy-db" + extensionConfig.Sync.Registries[0].URLs[0] = "http://modified:5000" + extensionConfig.Metrics.Prometheus.Path = "/custom/metrics" + extensionConfig.Scrub.Interval = 48 * time.Hour + extensionConfig.UI.Enable = &disabled + + // Verify that the original config is unchanged + So(*cfg.Extensions.Search.Enable, ShouldBeTrue) + So(cfg.Extensions.Search.CVE.UpdateInterval, ShouldEqual, time.Hour) + So(cfg.Extensions.Search.CVE.Trivy.DBRepository, ShouldEqual, "original/trivy-db") + So(cfg.Extensions.Sync.Registries[0].URLs[0], ShouldEqual, "http://original:5000") + So(cfg.Extensions.Metrics.Prometheus.Path, ShouldEqual, "/metrics") + So(cfg.Extensions.Scrub.Interval, ShouldEqual, 24*time.Hour) + So(*cfg.Extensions.UI.Enable, ShouldBeTrue) + + // Verify that the retrieved config has the mutations + So(*extensionConfig.Search.Enable, ShouldBeFalse) + So(extensionConfig.Search.CVE.UpdateInterval, ShouldEqual, 2*time.Hour) + So(extensionConfig.Search.CVE.Trivy.DBRepository, ShouldEqual, "modified/trivy-db") + So(extensionConfig.Sync.Registries[0].URLs[0], ShouldEqual, "http://modified:5000") + So(extensionConfig.Metrics.Prometheus.Path, ShouldEqual, "/custom/metrics") + So(extensionConfig.Scrub.Interval, ShouldEqual, 48*time.Hour) + So(*extensionConfig.UI.Enable, ShouldBeFalse) + }) + + Convey("Test that mutations to retrieved ExtensionConfig work with nil initial config", func() { + // Create a config with nil ExtensionConfig + cfg := &config.Config{ + Extensions: nil, + } + + // Retrieve the ExtensionConfig (should return nil) + extensionConfig := cfg.CopyExtensionsConfig() + So(extensionConfig, ShouldBeNil) + + // Create a new ExtensionConfig and set it + enabled := true + newExtensionConfig := &extconf.ExtensionConfig{ + Search: &extconf.SearchConfig{ + BaseConfig: extconf.BaseConfig{ + Enable: &enabled, + }, + }, + Metrics: &extconf.MetricsConfig{ + BaseConfig: extconf.BaseConfig{ + Enable: &enabled, + }, + Prometheus: &extconf.PrometheusConfig{ + Path: "/metrics", + }, + }, + } + + // Manually set the ExtensionConfig on the original config + cfg.Extensions = newExtensionConfig + + // Now retrieve it again and verify it works + retrievedConfig := cfg.CopyExtensionsConfig() + So(retrievedConfig, ShouldNotBeNil) + + // Mutate the retrieved config + retrievedConfig.Metrics.Prometheus.Path = "/new/metrics" + + // Verify the changes are NOT reflected in original config + finalConfig := cfg.CopyExtensionsConfig() + So(finalConfig.Metrics.Prometheus.Path, ShouldEqual, "/metrics") + }) + }) + + Convey("Test ExtensionConfig copy isolation through UpdateReloadableConfig()", t, func() { + Convey("Test that ExtensionConfig copies are isolated from UpdateReloadableConfig changes", func() { + // Create initial config with ExtensionConfig + enabled := true + cfg := &config.Config{ + Extensions: &extconf.ExtensionConfig{ + Search: &extconf.SearchConfig{ + BaseConfig: extconf.BaseConfig{ + Enable: &enabled, + }, + }, + Metrics: &extconf.MetricsConfig{ + BaseConfig: extconf.BaseConfig{ + Enable: &enabled, + }, + Prometheus: &extconf.PrometheusConfig{ + Path: "/metrics", + }, + }, + }, + } + + // Get initial reference to ExtensionConfig + initialExtensionConfig := cfg.CopyExtensionsConfig() + So(initialExtensionConfig, ShouldNotBeNil) + + // Verify initial state + So(initialExtensionConfig.Metrics.Prometheus.Path, ShouldEqual, "/metrics") + So(initialExtensionConfig.Sync, ShouldBeNil) + So(initialExtensionConfig.Search.CVE, ShouldBeNil) + So(initialExtensionConfig.Scrub, ShouldBeNil) + + // Create new config with updated ExtensionConfig + newConfig := &config.Config{ + Extensions: &extconf.ExtensionConfig{ + Search: &extconf.SearchConfig{ + BaseConfig: extconf.BaseConfig{ + Enable: &enabled, + }, + CVE: &extconf.CVEConfig{ + UpdateInterval: time.Hour * 2, + Trivy: &extconf.TrivyConfig{ + DBRepository: "updated/trivy-db", + }, + }, + }, + Metrics: &extconf.MetricsConfig{ + BaseConfig: extconf.BaseConfig{ + Enable: &enabled, + }, + Prometheus: &extconf.PrometheusConfig{ + Path: "/custom/metrics", + }, + }, + Sync: &syncconf.Config{ + Enable: &enabled, + Registries: []syncconf.RegistryConfig{ + { + URLs: []string{"http://registry1:5000", "http://registry2:5000"}, + }, + }, + }, + Scrub: &extconf.ScrubConfig{ + BaseConfig: extconf.BaseConfig{ + Enable: &enabled, + }, + Interval: time.Hour * 12, + }, + }, + } + + // Update the config using UpdateReloadableConfig + cfg.UpdateReloadableConfig(newConfig) + + // Verify that the old reference remains unchanged (copy isolation) + So(initialExtensionConfig.Metrics.Prometheus.Path, ShouldEqual, "/metrics") + So(initialExtensionConfig.Sync, ShouldBeNil) + So(initialExtensionConfig.Search.CVE, ShouldBeNil) + So(initialExtensionConfig.Scrub, ShouldBeNil) + + // Verify that a new reference gets the updated data + newExtensionConfig := cfg.CopyExtensionsConfig() + So(newExtensionConfig, ShouldNotBeNil) + So(newExtensionConfig, ShouldNotEqual, initialExtensionConfig) // Different references + + So(newExtensionConfig.Metrics.Prometheus.Path, ShouldEqual, "/metrics") + So(newExtensionConfig.Sync, ShouldNotBeNil) + So(newExtensionConfig.Search.CVE, ShouldNotBeNil) + So(newExtensionConfig.Scrub, ShouldNotBeNil) + }) + + Convey("Test that old ExtensionConfig reference works with nil initial config", func() { + // Create config with nil ExtensionConfig + cfg := &config.Config{ + Extensions: nil, + } + + // Get initial reference (should be nil) + initialExtensionConfig := cfg.CopyExtensionsConfig() + So(initialExtensionConfig, ShouldBeNil) + + // Create new config with ExtensionConfig + enabled := true + newConfig := &config.Config{ + Extensions: &extconf.ExtensionConfig{ + Search: &extconf.SearchConfig{ + BaseConfig: extconf.BaseConfig{ + Enable: &enabled, + }, + }, + Metrics: &extconf.MetricsConfig{ + BaseConfig: extconf.BaseConfig{ + Enable: &enabled, + }, + Prometheus: &extconf.PrometheusConfig{ + Path: "/new/metrics", + }, + }, + }, + } + + // Update the config using UpdateReloadableConfig + cfg.UpdateReloadableConfig(newConfig) + + // Verify that a new reference now gets the data + newExtensionConfig := cfg.CopyExtensionsConfig() + So(newExtensionConfig, ShouldNotBeNil) + + // Note: UpdateReloadableConfig creates an empty ExtensionConfig when going from nil to non-nil, + // but doesn't copy the fields from newConfig.Extensions. It only updates specific parts. + // So the Search and Metrics fields will be nil in the new ExtensionConfig. + So(newExtensionConfig.Search, ShouldBeNil) + So(newExtensionConfig.Metrics, ShouldBeNil) + }) + + Convey("Test that old ExtensionConfig reference works when new config has nil ExtensionConfig", func() { + // Create initial config with ExtensionConfig + enabled := true + cfg := &config.Config{ + Extensions: &extconf.ExtensionConfig{ + Search: &extconf.SearchConfig{ + BaseConfig: extconf.BaseConfig{ + Enable: &enabled, + }, + }, + }, + } + + // Get initial reference + initialExtensionConfig := cfg.CopyExtensionsConfig() + So(initialExtensionConfig, ShouldNotBeNil) + + // Create new config with nil ExtensionConfig + newConfig := &config.Config{ + Extensions: nil, + } + + // Update the config using UpdateReloadableConfig + cfg.UpdateReloadableConfig(newConfig) + + // Verify that the old reference remains unchanged (copy isolation) + So(initialExtensionConfig, ShouldNotBeNil) + So(initialExtensionConfig.Search, ShouldNotBeNil) + + // Verify that a new reference now returns nil + newExtensionConfig := cfg.CopyExtensionsConfig() + So(newExtensionConfig, ShouldBeNil) + }) + }) + + Convey("Test UpdateReloadableConfig LDAP config updates", t, func() { + Convey("Test LDAP config is updated in UpdateReloadableConfig", func() { + // Create initial config with LDAP + cfg := &config.Config{ + HTTP: config.HTTPConfig{ + Auth: &config.AuthConfig{ + LDAP: &config.LDAPConfig{ + Address: "ldap://old-server:389", + Port: 389, + Insecure: true, + }, + }, + }, + } + + // Create new config with updated LDAP + newConfig := &config.Config{ + HTTP: config.HTTPConfig{ + Auth: &config.AuthConfig{ + LDAP: &config.LDAPConfig{ + Address: "ldap://new-server:636", + Port: 636, + Insecure: false, + StartTLS: true, + }, + }, + }, + } + + // Update the config + cfg.UpdateReloadableConfig(newConfig) + + // Verify LDAP config was updated + So(cfg.HTTP.Auth.LDAP, ShouldNotBeNil) + So(cfg.HTTP.Auth.LDAP.Address, ShouldEqual, "ldap://new-server:636") + So(cfg.HTTP.Auth.LDAP.Port, ShouldEqual, 636) + So(cfg.HTTP.Auth.LDAP.Insecure, ShouldBeFalse) + So(cfg.HTTP.Auth.LDAP.StartTLS, ShouldBeTrue) + }) + + Convey("Test LDAP config is set to nil when new config has nil LDAP", func() { + // Create initial config with LDAP + cfg := &config.Config{ + HTTP: config.HTTPConfig{ + Auth: &config.AuthConfig{ + LDAP: &config.LDAPConfig{ + Address: "ldap://old-server:389", + }, + }, + }, + } + + // Create new config with nil LDAP + newConfig := &config.Config{ + HTTP: config.HTTPConfig{ + Auth: &config.AuthConfig{ + LDAP: nil, + }, + }, + } + + // Update the config + cfg.UpdateReloadableConfig(newConfig) + + // Verify LDAP config was set to nil + So(cfg.HTTP.Auth.LDAP, ShouldBeNil) + }) + + Convey("Test LDAP config is created when going from nil to non-nil", func() { + // Create initial config with nil LDAP + cfg := &config.Config{ + HTTP: config.HTTPConfig{ + Auth: &config.AuthConfig{ + LDAP: nil, + }, + }, + } + + // Create new config with LDAP + newConfig := &config.Config{ + HTTP: config.HTTPConfig{ + Auth: &config.AuthConfig{ + LDAP: &config.LDAPConfig{ + Address: "ldap://new-server:389", + Port: 389, + }, + }, + }, + } + + // Update the config + cfg.UpdateReloadableConfig(newConfig) + + // Verify LDAP config was created + So(cfg.HTTP.Auth.LDAP, ShouldNotBeNil) + So(cfg.HTTP.Auth.LDAP.Address, ShouldEqual, "ldap://new-server:389") + So(cfg.HTTP.Auth.LDAP.Port, ShouldEqual, 389) + }) + }) + + Convey("Test UpdateReloadableConfig Storage.SubPaths logic", t, func() { + Convey("Test existing SubPaths are updated", func() { + // Create initial config with SubPaths + cfg := &config.Config{ + Storage: config.GlobalStorageConfig{ + StorageConfig: config.StorageConfig{ + GC: true, + Dedupe: false, + }, + SubPaths: map[string]config.StorageConfig{ + "/path1": { + GC: true, + Dedupe: false, + GCDelay: time.Hour, + GCInterval: time.Hour * 24, + }, + "/path2": { + GC: false, + Dedupe: true, + GCDelay: time.Hour * 2, + GCInterval: time.Hour * 48, + }, + }, + }, + } + + // Create new config with updated SubPaths + newConfig := &config.Config{ + Storage: config.GlobalStorageConfig{ + StorageConfig: config.StorageConfig{ + GC: true, + Dedupe: false, + }, + SubPaths: map[string]config.StorageConfig{ + "/path1": { + GC: false, // Changed + Dedupe: true, // Changed + GCDelay: time.Hour * 2, // Changed + GCInterval: time.Hour * 12, // Changed + }, + "/path2": { + GC: true, // Changed + Dedupe: false, // Changed + GCDelay: time.Hour * 3, // Changed + GCInterval: time.Hour * 36, // Changed + }, + }, + }, + } + + // Update the config + cfg.UpdateReloadableConfig(newConfig) + + // Verify SubPaths were updated + So(len(cfg.Storage.SubPaths), ShouldEqual, 2) + + // Check /path1 + path1Config := cfg.Storage.SubPaths["/path1"] + So(path1Config.GC, ShouldBeFalse) + So(path1Config.Dedupe, ShouldBeTrue) + So(path1Config.GCDelay, ShouldEqual, time.Hour*2) + So(path1Config.GCInterval, ShouldEqual, time.Hour*12) + + // Check /path2 + path2Config := cfg.Storage.SubPaths["/path2"] + So(path2Config.GC, ShouldBeTrue) + So(path2Config.Dedupe, ShouldBeFalse) + So(path2Config.GCDelay, ShouldEqual, time.Hour*3) + So(path2Config.GCInterval, ShouldEqual, time.Hour*36) + }) + + Convey("Test new SubPaths are not added (only existing ones are updated)", func() { + // Create initial config with one SubPath + cfg := &config.Config{ + Storage: config.GlobalStorageConfig{ + StorageConfig: config.StorageConfig{ + GC: true, + Dedupe: false, + }, + SubPaths: map[string]config.StorageConfig{ + "/path1": { + GC: true, + Dedupe: false, + }, + }, + }, + } + + // Create new config with additional SubPath + newConfig := &config.Config{ + Storage: config.GlobalStorageConfig{ + StorageConfig: config.StorageConfig{ + GC: true, + Dedupe: false, + }, + SubPaths: map[string]config.StorageConfig{ + "/path1": { + GC: false, // Update existing + Dedupe: true, // Update existing + }, + "/path2": { // New path - should not be added + GC: true, + Dedupe: true, + }, + }, + }, + } + + // Update the config + cfg.UpdateReloadableConfig(newConfig) + + // Verify only existing SubPath was updated, new one was not added + So(len(cfg.Storage.SubPaths), ShouldEqual, 1) + _, exists := cfg.Storage.SubPaths["/path2"] + So(exists, ShouldBeFalse) // New path not added + + // Verify existing path was updated + path1Config := cfg.Storage.SubPaths["/path1"] + So(path1Config.GC, ShouldBeFalse) + So(path1Config.Dedupe, ShouldBeTrue) + }) + + Convey("Test SubPaths Retention is updated only when retention is enabled", func() { + // Create initial config with retention enabled and SubPaths + // Retention is enabled when there are policies with tag retention + cfg := &config.Config{ + Storage: config.GlobalStorageConfig{ + StorageConfig: config.StorageConfig{ + GC: true, + Dedupe: false, + Retention: config.ImageRetention{ + Policies: []config.RetentionPolicy{ + { + Repositories: []string{"repo1"}, + KeepTags: []config.KeepTagsPolicy{ + { + MostRecentlyPulledCount: 10, // This enables retention + }, + }, + }, + }, + }, + }, + SubPaths: map[string]config.StorageConfig{ + "/path1": { + GC: true, + Dedupe: false, + Retention: config.ImageRetention{ + Policies: []config.RetentionPolicy{ + { + Repositories: []string{"old-repo"}, + }, + }, + }, + }, + }, + }, + } + + // Create new config with updated SubPath retention + newConfig := &config.Config{ + Storage: config.GlobalStorageConfig{ + StorageConfig: config.StorageConfig{ + GC: true, + Dedupe: false, + Retention: config.ImageRetention{ + Policies: []config.RetentionPolicy{ + { + Repositories: []string{"repo1"}, + KeepTags: []config.KeepTagsPolicy{ + { + MostRecentlyPulledCount: 10, // This enables retention + }, + }, + }, + }, + }, + }, + SubPaths: map[string]config.StorageConfig{ + "/path1": { + GC: false, + Dedupe: true, + Retention: config.ImageRetention{ + Policies: []config.RetentionPolicy{ + { + Repositories: []string{"new-repo"}, + }, + }, + }, + }, + }, + }, + } + + // Update the config + cfg.UpdateReloadableConfig(newConfig) + + // Verify SubPath was updated including Retention + path1Config := cfg.Storage.SubPaths["/path1"] + So(path1Config.GC, ShouldBeFalse) + So(path1Config.Dedupe, ShouldBeTrue) + So(len(path1Config.Retention.Policies), ShouldEqual, 1) + So(path1Config.Retention.Policies[0].Repositories[0], ShouldEqual, "new-repo") + }) + + Convey("Test SubPaths Retention is not updated when retention is disabled", func() { + // Create initial config with retention disabled and SubPaths + cfg := &config.Config{ + Storage: config.GlobalStorageConfig{ + StorageConfig: config.StorageConfig{ + GC: true, + Dedupe: false, + // No Retention config - retention disabled + }, + SubPaths: map[string]config.StorageConfig{ + "/path1": { + GC: true, + Dedupe: false, + Retention: config.ImageRetention{ + Policies: []config.RetentionPolicy{ + { + Repositories: []string{"old-repo"}, + }, + }, + }, + }, + }, + }, + } + + // Create new config with updated SubPath retention + newConfig := &config.Config{ + Storage: config.GlobalStorageConfig{ + StorageConfig: config.StorageConfig{ + GC: true, + Dedupe: false, + // No Retention config - retention disabled + }, + SubPaths: map[string]config.StorageConfig{ + "/path1": { + GC: false, + Dedupe: true, + Retention: config.ImageRetention{ + Policies: []config.RetentionPolicy{ + { + Repositories: []string{"new-repo"}, + }, + }, + }, + }, + }, + }, + } + + // Update the config + cfg.UpdateReloadableConfig(newConfig) + + // Verify SubPath was updated but Retention was not + path1Config := cfg.Storage.SubPaths["/path1"] + So(path1Config.GC, ShouldBeFalse) + So(path1Config.Dedupe, ShouldBeTrue) + // Retention should remain unchanged (old value) + So(len(path1Config.Retention.Policies), ShouldEqual, 1) + So(path1Config.Retention.Policies[0].Repositories[0], ShouldEqual, "old-repo") + }) + + Convey("Test SubPaths with empty new config", func() { + // Create initial config with SubPaths + cfg := &config.Config{ + Storage: config.GlobalStorageConfig{ + StorageConfig: config.StorageConfig{ + GC: true, + Dedupe: false, + }, + SubPaths: map[string]config.StorageConfig{ + "/path1": { + GC: true, + Dedupe: false, + }, + "/path2": { + GC: false, + Dedupe: true, + }, + }, + }, + } + + // Create new config with empty SubPaths + newConfig := &config.Config{ + Storage: config.GlobalStorageConfig{ + StorageConfig: config.StorageConfig{ + GC: true, + Dedupe: false, + }, + SubPaths: map[string]config.StorageConfig{}, + }, + } + + // Update the config + cfg.UpdateReloadableConfig(newConfig) + + // Verify existing SubPaths remain unchanged (no updates applied) + So(len(cfg.Storage.SubPaths), ShouldEqual, 2) + path1Config := cfg.Storage.SubPaths["/path1"] + So(path1Config.GC, ShouldBeTrue) // Unchanged + So(path1Config.Dedupe, ShouldBeFalse) // Unchanged + path2Config := cfg.Storage.SubPaths["/path2"] + So(path2Config.GC, ShouldBeFalse) // Unchanged + So(path2Config.Dedupe, ShouldBeTrue) // Unchanged + }) }) } diff --git a/pkg/api/controller.go b/pkg/api/controller.go index 93df422e..30666a7f 100644 --- a/pkg/api/controller.go +++ b/pkg/api/controller.go @@ -22,15 +22,14 @@ import ( "zotregistry.dev/zot/v2/pkg/api/config" "zotregistry.dev/zot/v2/pkg/common" ext "zotregistry.dev/zot/v2/pkg/extensions" - extconf "zotregistry.dev/zot/v2/pkg/extensions/config" - "zotregistry.dev/zot/v2/pkg/extensions/events" - "zotregistry.dev/zot/v2/pkg/extensions/monitoring" - "zotregistry.dev/zot/v2/pkg/log" - "zotregistry.dev/zot/v2/pkg/meta" + events "zotregistry.dev/zot/v2/pkg/extensions/events" + monitoring "zotregistry.dev/zot/v2/pkg/extensions/monitoring" + log "zotregistry.dev/zot/v2/pkg/log" + meta "zotregistry.dev/zot/v2/pkg/meta" mTypes "zotregistry.dev/zot/v2/pkg/meta/types" - "zotregistry.dev/zot/v2/pkg/scheduler" - "zotregistry.dev/zot/v2/pkg/storage" - "zotregistry.dev/zot/v2/pkg/storage/gc" + scheduler "zotregistry.dev/zot/v2/pkg/scheduler" + storage "zotregistry.dev/zot/v2/pkg/storage" + gc "zotregistry.dev/zot/v2/pkg/storage/gc" ) const ( @@ -139,12 +138,13 @@ func (c *Controller) Run() error { engine := mux.NewRouter() // rate-limit HTTP requests if enabled - if c.Config.HTTP.Ratelimit != nil { - if c.Config.HTTP.Ratelimit.Rate != nil { - engine.Use(RateLimiter(c, *c.Config.HTTP.Ratelimit.Rate)) + ratelimitConfig := c.Config.CopyRatelimit() + if ratelimitConfig != nil { + if ratelimitConfig.Rate != nil { + engine.Use(RateLimiter(c, *ratelimitConfig.Rate)) } - for _, mrlim := range c.Config.HTTP.Ratelimit.Methods { + for _, mrlim := range ratelimitConfig.Methods { engine.Use(MethodRateLimiter(c, mrlim.Method, mrlim.Rate)) } } @@ -161,13 +161,14 @@ func (c *Controller) Run() error { c.Router = engine c.Router.UseEncodedPath() - monitoring.SetServerInfo(c.Metrics, c.Config.Commit, c.Config.BinaryType, c.Config.GoVersion, - c.Config.DistSpecVersion) + commit, binaryType, goVersion, distSpecVersion := c.Config.GetVersionInfo() + monitoring.SetServerInfo(c.Metrics, commit, binaryType, goVersion, distSpecVersion) //nolint: contextcheck _ = NewRouteHandler(c) - addr := fmt.Sprintf("%s:%s", c.Config.HTTP.Address, c.Config.HTTP.Port) + port := c.Config.GetHTTPPort() + addr := fmt.Sprintf("%s:%s", c.Config.GetHTTPAddress(), port) server := &http.Server{ Addr: addr, Handler: c.Router, @@ -182,10 +183,10 @@ func (c *Controller) Run() error { return err } - if c.Config.HTTP.Port == "0" || c.Config.HTTP.Port == "" { + if port == "0" || port == "" { chosenAddr, ok := listener.Addr().(*net.TCPAddr) if !ok { - c.Log.Error().Str("port", c.Config.HTTP.Port).Msg("invalid addr type") + c.Log.Error().Str("port", port).Msg("invalid addr type") return errors.ErrBadType } @@ -196,12 +197,13 @@ func (c *Controller) Run() error { "port is unspecified, listening on kernel chosen port", ) } else { - chosenPort, _ := strconv.ParseInt(c.Config.HTTP.Port, 10, 64) + chosenPort, _ := strconv.ParseInt(port, 10, 32) c.chosenPort = int(chosenPort) } - if c.Config.HTTP.TLS != nil && c.Config.HTTP.TLS.Key != "" && c.Config.HTTP.TLS.Cert != "" { + tlsConfig := c.Config.CopyTLSConfig() + if tlsConfig != nil && tlsConfig.Key != "" && tlsConfig.Cert != "" { server.TLSConfig = &tls.Config{ CipherSuites: []uint16{ tls.TLS_ECDHE_ECDSA_WITH_AES_256_GCM_SHA384, @@ -219,15 +221,15 @@ func (c *Controller) Run() error { MinVersion: tls.VersionTLS12, } - if c.Config.HTTP.TLS.CACert != "" { + if tlsConfig.CACert != "" { clientAuth := tls.VerifyClientCertIfGiven if c.Config.IsMTLSAuthEnabled() { clientAuth = tls.RequireAndVerifyClientCert } - caCert, err := os.ReadFile(c.Config.HTTP.TLS.CACert) + caCert, err := os.ReadFile(tlsConfig.CACert) if err != nil { - c.Log.Error().Err(err).Str("caCert", c.Config.HTTP.TLS.CACert).Msg("failed to read file") + c.Log.Error().Err(err).Str("caCert", tlsConfig.CACert).Msg("failed to read file") return err } @@ -246,7 +248,7 @@ func (c *Controller) Run() error { c.Healthz.Ready() - return server.ServeTLS(listener, c.Config.HTTP.TLS.Cert, c.Config.HTTP.TLS.Key) + return server.ServeTLS(listener, tlsConfig.Cert, tlsConfig.Key) } c.Healthz.Ready() @@ -267,10 +269,9 @@ func (c *Controller) Init() error { DumpRuntimeParams(c.Log) var enabled bool - if c.Config != nil && - c.Config.Extensions != nil && - c.Config.Extensions.Metrics != nil && - *c.Config.Extensions.Metrics.Enable { + extensionsConfig := c.Config.CopyExtensionsConfig() + + if extensionsConfig.IsMetricsEnabled() { enabled = true } @@ -291,8 +292,10 @@ func (c *Controller) Init() error { c.InitCVEInfo() c.Healthz.Started() - if c.Config.IsHtpasswdAuthEnabled() { - err := c.HTPasswdWatcher.ChangeFile(c.Config.HTTP.Auth.HTPasswd.Path) + // Get auth config safely + authConfig := c.Config.CopyAuthConfig() + if authConfig.IsHtpasswdAuthEnabled() { + err := c.HTPasswdWatcher.ChangeFile(authConfig.HTPasswd.Path) if err != nil { return err } @@ -303,9 +306,7 @@ func (c *Controller) Init() error { func (c *Controller) InitCVEInfo() { // Enable CVE extension if extension config is provided - if c.Config != nil && c.Config.Extensions != nil { - c.CveScanner = ext.GetCveScanner(c.Config, c.StoreController, c.MetaDB, c.Log) - } + c.CveScanner = ext.GetCveScanner(c.Config, c.StoreController, c.MetaDB, c.Log) } func (c *Controller) InitImageStore() error { @@ -323,11 +324,12 @@ func (c *Controller) InitImageStore() error { func (c *Controller) initCookieStore() error { // setup sessions cookie store used to preserve logged in user in web sessions - if c.Config.IsBasicAuthnEnabled() { + if c.Config.HTTP.Auth.IsBasicAuthnEnabled() { if c.Config.HTTP.Auth.SessionHashKey == nil { c.Log.Warn().Msg("hashKey is not set in config, generating a random one") - c.Config.HTTP.Auth.SessionHashKey = securecookie.GenerateRandomKey(64) //nolint: gomnd + key := securecookie.GenerateRandomKey(64) //nolint: gomnd + c.Config.HTTP.Auth.SessionHashKey = key } cookieStore, err := NewCookieStore(c.Config.HTTP.Auth, c.StoreController, c.Log) @@ -343,9 +345,16 @@ func (c *Controller) initCookieStore() error { func (c *Controller) InitMetaDB() error { // init metaDB if search is enabled or we need to store user profiles, api keys or signatures - if c.Config.IsSearchEnabled() || c.Config.IsBasicAuthnEnabled() || c.Config.IsImageTrustEnabled() || + // Get auth config safely + authConfig := c.Config.CopyAuthConfig() + extensionsConfig := c.Config.CopyExtensionsConfig() + + if extensionsConfig.IsSearchEnabled() || authConfig.IsBasicAuthnEnabled() || extensionsConfig.IsImageTrustEnabled() || c.Config.IsRetentionEnabled() { - driver, err := meta.New(c.Config.Storage.StorageConfig, c.Log) //nolint:contextcheck + // Get storage config safely + storageConfig := c.Config.CopyStorageConfig() + + driver, err := meta.New(storageConfig.StorageConfig, c.Log) //nolint:contextcheck if err != nil { return err } @@ -383,74 +392,25 @@ func (c *Controller) InitEventRecorder() error { } func (c *Controller) LoadNewConfig(newConfig *config.Config) { - // reload access control config - c.Config.HTTP.AccessControl = newConfig.HTTP.AccessControl + // Update only reloadable config fields atomically + c.Config.UpdateReloadableConfig(newConfig) - if c.Config.HTTP.Auth != nil { - c.Config.HTTP.Auth.HTPasswd = newConfig.HTTP.Auth.HTPasswd - c.Config.HTTP.Auth.LDAP = newConfig.HTTP.Auth.LDAP - - err := c.HTPasswdWatcher.ChangeFile(c.Config.HTTP.Auth.HTPasswd.Path) + // Operations that need to happen after config update + authConfig := c.Config.CopyAuthConfig() + if authConfig.IsHtpasswdAuthEnabled() { + err := c.HTPasswdWatcher.ChangeFile(authConfig.HTPasswd.Path) if err != nil { c.Log.Error().Err(err).Msg("failed to change watched htpasswd file") } - - if c.LDAPClient != nil { - c.LDAPClient.lock.Lock() - c.LDAPClient.BindDN = newConfig.HTTP.Auth.LDAP.BindDN() - c.LDAPClient.BindPassword = newConfig.HTTP.Auth.LDAP.BindPassword() - c.LDAPClient.lock.Unlock() - } } else { _ = c.HTPasswdWatcher.ChangeFile("") } - // reload periodical gc config - c.Config.Storage.GC = newConfig.Storage.GC - c.Config.Storage.Dedupe = newConfig.Storage.Dedupe - c.Config.Storage.GCDelay = newConfig.Storage.GCDelay - c.Config.Storage.GCInterval = newConfig.Storage.GCInterval - // only if we have a metaDB already in place - if c.Config.IsRetentionEnabled() { - c.Config.Storage.Retention = newConfig.Storage.Retention - } - - for subPath, storageConfig := range newConfig.Storage.SubPaths { - subPathConfig, ok := c.Config.Storage.SubPaths[subPath] - if ok { - subPathConfig.GC = storageConfig.GC - subPathConfig.Dedupe = storageConfig.Dedupe - subPathConfig.GCDelay = storageConfig.GCDelay - subPathConfig.GCInterval = storageConfig.GCInterval - // only if we have a metaDB already in place - if c.Config.IsRetentionEnabled() { - subPathConfig.Retention = storageConfig.Retention - } - - c.Config.Storage.SubPaths[subPath] = subPathConfig - } - } - - // reload background tasks - if newConfig.Extensions != nil { - if c.Config.Extensions == nil { - c.Config.Extensions = &extconf.ExtensionConfig{} - } - - // reload sync extension - c.Config.Extensions.Sync = newConfig.Extensions.Sync - - // reload only if search is enabled and reloaded config has search extension (can't setup routes at this stage) - if c.Config.Extensions.Search != nil && *c.Config.Extensions.Search.Enable { - if newConfig.Extensions.Search != nil { - c.Config.Extensions.Search.CVE = newConfig.Extensions.Search.CVE - } - } - - // reload scrub extension - c.Config.Extensions.Scrub = newConfig.Extensions.Scrub - } else { - c.Config.Extensions = nil + if c.LDAPClient != nil && authConfig.IsLdapAuthEnabled() { + c.LDAPClient.lock.Lock() + c.LDAPClient.BindDN = authConfig.LDAP.BindDN() + c.LDAPClient.BindPassword = authConfig.LDAP.BindPassword() + c.LDAPClient.lock.Unlock() } c.InitCVEInfo() @@ -463,6 +423,11 @@ func (c *Controller) Shutdown() { // stop all background tasks c.StopBackgroundTasks() + // Stop metrics server to prevent resource leaks (only during full shutdown) + if c.Metrics != nil { + c.Metrics.Stop() + } + if c.Server != nil { ctx := context.Background() _ = c.Server.Shutdown(ctx) @@ -479,73 +444,90 @@ func (c *Controller) StopBackgroundTasks() { if c.taskScheduler != nil { c.taskScheduler.Shutdown() } + + // Close HTPasswdWatcher to prevent resource leaks + if c.HTPasswdWatcher != nil { + _ = c.HTPasswdWatcher.Close() + } } func (c *Controller) StartBackgroundTasks() { c.taskScheduler = scheduler.NewScheduler(c.Config, c.Metrics, c.Log) c.taskScheduler.RunScheduler() + // Start HTPasswdWatcher goroutine + if c.HTPasswdWatcher != nil { + c.HTPasswdWatcher.Run() + } + // Enable running garbage-collect periodically for DefaultStore - if c.Config.Storage.GC { + storageConfig := c.Config.CopyStorageConfig() + if storageConfig.GC { gc := gc.NewGarbageCollect(c.StoreController.DefaultStore, c.MetaDB, gc.Options{ - Delay: c.Config.Storage.GCDelay, - ImageRetention: c.Config.Storage.Retention, + Delay: storageConfig.GCDelay, + ImageRetention: storageConfig.Retention, }, c.Audit, c.Log) - gc.CleanImageStorePeriodically(c.Config.Storage.GCInterval, c.taskScheduler) + gc.CleanImageStorePeriodically(storageConfig.GCInterval, c.taskScheduler) } // Enable running dedupe blobs both ways (dedupe or restore deduped blobs) c.StoreController.DefaultStore.RunDedupeBlobs(time.Duration(0), c.taskScheduler) // Enable extensions if extension config is provided for DefaultStore - if c.Config != nil && c.Config.Extensions != nil { - ext.EnableMetricsExtension(c.Config, c.Log, c.Config.Storage.RootDirectory) - ext.EnableSearchExtension(c.Config, c.StoreController, c.MetaDB, c.taskScheduler, c.CveScanner, c.Log) - } + extensionsConfig := c.Config.CopyExtensionsConfig() + + // Always call EnableSearchExtension to ensure proper logging, even when search is disabled + ext.EnableSearchExtension(c.Config, c.StoreController, c.MetaDB, c.taskScheduler, c.CveScanner, c.Log) + + // Always call EnableMetricsExtension to ensure proper logging, even when metrics is disabled + ext.EnableMetricsExtension(c.Config, c.Log, storageConfig.RootDirectory) + // runs once if metrics are enabled & imagestore is local - if c.Config.IsMetricsEnabled() && c.Config.Storage.StorageDriver == nil { + if extensionsConfig.IsMetricsEnabled() && storageConfig.StorageDriver == nil { c.StoreController.DefaultStore.PopulateStorageMetrics(time.Duration(0), c.taskScheduler) } - if c.Config.Storage.SubPaths != nil { - for route, storageConfig := range c.Config.Storage.SubPaths { + if storageConfig.SubPaths != nil { + for route, subStorageConfig := range storageConfig.SubPaths { // Enable running garbage-collect periodically for subImageStore - if storageConfig.GC { + if subStorageConfig.GC { gc := gc.NewGarbageCollect(c.StoreController.SubStore[route], c.MetaDB, gc.Options{ - Delay: storageConfig.GCDelay, - ImageRetention: storageConfig.Retention, + Delay: subStorageConfig.GCDelay, + ImageRetention: subStorageConfig.Retention, }, c.Audit, c.Log) - gc.CleanImageStorePeriodically(storageConfig.GCInterval, c.taskScheduler) + gc.CleanImageStorePeriodically(subStorageConfig.GCInterval, c.taskScheduler) } // Enable extensions if extension config is provided for subImageStore - if c.Config != nil && c.Config.Extensions != nil { - ext.EnableMetricsExtension(c.Config, c.Log, storageConfig.RootDirectory) - } + ext.EnableMetricsExtension(c.Config, c.Log, subStorageConfig.RootDirectory) // Enable running dedupe blobs both ways (dedupe or restore deduped blobs) for subpaths substore := c.StoreController.SubStore[route] if substore != nil { substore.RunDedupeBlobs(time.Duration(0), c.taskScheduler) - if c.Config.IsMetricsEnabled() && c.Config.Storage.StorageDriver == nil { + if extensionsConfig.IsMetricsEnabled() && storageConfig.StorageDriver == nil { substore.PopulateStorageMetrics(time.Duration(0), c.taskScheduler) } } } } - if c.Config.Extensions != nil { - ext.EnableScrubExtension(c.Config, c.Log, c.StoreController, c.taskScheduler) - //nolint: contextcheck - syncOnDemand, err := ext.EnableSyncExtension(c.Config, c.MetaDB, c.StoreController, c.taskScheduler, c.Log) - if err != nil { - c.Log.Error().Err(err).Msg("failed to start sync extension") - } + // Always call EnableScrubExtension to ensure proper logging, even when scrub is disabled + ext.EnableScrubExtension(c.Config, c.Log, c.StoreController, c.taskScheduler) + // Always call EnableSyncExtension to ensure proper logging, even when sync is disabled + //nolint: contextcheck + syncOnDemand, err := ext.EnableSyncExtension(c.Config, c.MetaDB, c.StoreController, c.taskScheduler, c.Log) + if err != nil { + c.Log.Error().Err(err).Msg("failed to start sync extension") + } + + // Only set SyncOnDemand if sync is actually enabled + if extensionsConfig.IsSyncEnabled() { c.SyncOnDemand = syncOnDemand } diff --git a/pkg/api/controller_test.go b/pkg/api/controller_test.go index cef57cc5..03c40dd6 100644 --- a/pkg/api/controller_test.go +++ b/pkg/api/controller_test.go @@ -1296,7 +1296,9 @@ func TestScaleOutRequestProxy(t *testing.T) { cm := test.NewControllerManager(ctrlr) cm.StartAndWait(port) - defer cm.StopServer() + defer func(cm test.ControllerManager) { + cm.StopServer() + }(cm) } Convey("All 3 controllers should start up and respond without error", func() { @@ -1393,7 +1395,9 @@ func TestScaleOutRequestProxy(t *testing.T) { cm := test.NewControllerManager(ctrlr) cm.StartAndWait(port) - defer cm.StopServer() + defer func(cm test.ControllerManager) { + cm.StopServer() + }(cm) } Convey("All 3 controllers should start up and respond without error", func() { @@ -1470,7 +1474,9 @@ func TestScaleOutRequestProxy(t *testing.T) { cm := test.NewControllerManager(ctrlr) cm.StartAndWait(port) - defer cm.StopServer() + defer func(cm test.ControllerManager) { + cm.StopServer() + }(cm) } caCert, err := os.ReadFile(CACert) @@ -1534,7 +1540,9 @@ func TestScaleOutRequestProxy(t *testing.T) { cm := test.NewControllerManager(ctrlr) cm.StartAndWait(port) - defer cm.StopServer() + defer func(cm test.ControllerManager) { + cm.StopServer() + }(cm) } caCert, err := os.ReadFile(CACert) @@ -1599,7 +1607,9 @@ func TestScaleOutRequestProxy(t *testing.T) { cm := test.NewControllerManager(ctrlr) cm.StartAndWait(port) - defer cm.StopServer() + defer func(cm test.ControllerManager) { + cm.StopServer() + }(cm) } caCert, err := os.ReadFile(CACert) @@ -7171,6 +7181,8 @@ func TestCrossRepoMount(t *testing.T) { ctlr.Config.Storage.Dedupe = true ctlr.Config.Storage.GC = false ctlr.Config.Storage.RootDirectory = newDir + + ctlr = api.NewController(ctlr.Config) cm = test.NewControllerManager(ctlr) //nolint: varnamelen cm.StartAndWait(port) diff --git a/pkg/api/htpasswd.go b/pkg/api/htpasswd.go index 6a6a6236..5bdbcbdb 100644 --- a/pkg/api/htpasswd.go +++ b/pkg/api/htpasswd.go @@ -100,76 +100,129 @@ func (s *HTPasswd) Authenticate(username, passphrase string) (ok, present bool) // HTPasswdWatcher helper which triggers htpasswd reload on file change event. // -// Cannot be restarted. +// Can be restarted by calling Run() again after Close(). type HTPasswdWatcher struct { htp *HTPasswd filePath string watcher *fsnotify.Watcher + ctx context.Context //nolint:containedctx // Context is needed for watcher lifecycle management cancel context.CancelFunc log log.Logger + mu sync.Mutex } -// NewHTPasswdWatcher create and start watcher. +// NewHTPasswdWatcher creates a new watcher instance. func NewHTPasswdWatcher(htp *HTPasswd, filePath string) (*HTPasswdWatcher, error) { - watcher, err := fsnotify.NewWatcher() - if err != nil { - return nil, err - } - - if filePath != "" { - err = watcher.Add(filePath) - if err != nil { - return nil, errors.Join(err, watcher.Close()) - } - } - - // background event processor job context - ctx, cancel := signal.NotifyContext(context.Background(), syscall.SIGTERM, syscall.SIGINT) - ret := &HTPasswdWatcher{ htp: htp, filePath: filePath, - watcher: watcher, - cancel: cancel, log: htp.log, } - go func() { - defer ret.watcher.Close() //nolint: errcheck - - for { - select { - case ev := <-ret.watcher.Events: - if ev.Op != fsnotify.Write { - continue - } - - ret.log.Info().Str("htpasswd-file", ret.filePath).Msg("htpasswd file changed, trying to reload config") - - err := ret.htp.Reload(ret.filePath) - if err != nil { - ret.log.Warn().Err(err).Str("htpasswd-file", ret.filePath).Msg("failed to reload file") - } - - case err := <-ret.watcher.Errors: - ret.log.Error().Err(err).Str("htpasswd-file", ret.filePath).Msg("failed to fsnotfy, got error while watching file") - - case <-ctx.Done(): - ret.log.Debug().Msg("htpasswd watcher terminating...") - - return - } - } - }() - return ret, nil } +// Run starts the watcher goroutine. +func (s *HTPasswdWatcher) Run() { + s.mu.Lock() + defer s.mu.Unlock() + + if s.ctx != nil { + return // Already running + } + + // Create fresh fsnotify watcher for this run + watcher, err := fsnotify.NewWatcher() + if err != nil { + s.log.Error().Err(err).Msg("failed to create fsnotify watcher") + + return + } + + // Only add file to watcher if we have a file to watch + if s.filePath != "" { + err = watcher.Add(s.filePath) + if err != nil { + s.log.Error().Err(err).Str("htpasswd-file", s.filePath).Msg("failed to add file to watcher") + watcher.Close() //nolint: errcheck + + return + } + } + + // Create context and start goroutine + ctx, cancel := signal.NotifyContext(context.Background(), syscall.SIGTERM, syscall.SIGINT) + s.ctx = ctx + s.cancel = cancel + s.watcher = watcher + + go func() { + defer func() { + s.mu.Lock() + defer s.mu.Unlock() + + // Clean up watcher + if s.watcher != nil { + s.watcher.Close() //nolint: errcheck + s.watcher = nil + } + + // Clear context to indicate not running + s.ctx = nil + s.cancel = nil + }() + + for { + select { + case <-ctx.Done(): + s.log.Debug().Msg("htpasswd watcher terminating...") + + return + + case ev := <-watcher.Events: + if ev.Op != fsnotify.Write { + continue + } + + s.log.Info().Str("htpasswd-file", s.filePath).Msg("htpasswd file changed, trying to reload config") + + err := s.htp.Reload(s.filePath) + if err != nil { + s.log.Warn().Err(err).Str("htpasswd-file", s.filePath).Msg("failed to reload file") + } + + case err := <-watcher.Errors: + // Only log errors if we're actually watching a file + if s.filePath != "" { + s.log.Error().Err(err).Str("htpasswd-file", s.filePath).Msg("failed to fsnotfy, got error while watching file") + } + } + } + }() +} + // ChangeFile changes monitored file. Empty string clears store. func (s *HTPasswdWatcher) ChangeFile(filePath string) error { + s.mu.Lock() + defer s.mu.Unlock() + + // If watcher is not running, just update the filePath for when Run() is called + if s.watcher == nil { + s.filePath = filePath + if filePath == "" { + s.htp.Clear() + } else { + return s.htp.Reload(filePath) + } + + return nil + } + + // Remove old file if it exists if s.filePath != "" { - err := s.watcher.Remove(s.filePath) - if err != nil { + if err := s.watcher.Remove(s.filePath); err != nil && !errors.Is(err, fsnotify.ErrNonExistentWatch) { + // Ignore "can't remove non-existent watch" errors as they can happen + // due to race conditions or files being removed externally return err } } @@ -192,7 +245,21 @@ func (s *HTPasswdWatcher) ChangeFile(filePath string) error { } func (s *HTPasswdWatcher) Close() error { - s.cancel() + s.mu.Lock() + defer s.mu.Unlock() + + if s.ctx == nil { + return nil // Already closed/not running + } + + // Cancel context to signal goroutine to stop + if s.cancel != nil { + s.cancel() + } + + // The goroutine will clean up s.ctx, s.cancel, and s.watcher in its defer + // We just need to wait for it to finish by checking if s.ctx becomes nil + // This is safe because the goroutine sets s.ctx = nil in its defer return nil } diff --git a/pkg/api/htpasswd_test.go b/pkg/api/htpasswd_test.go index f37bd2ff..fdd7f84d 100644 --- a/pkg/api/htpasswd_test.go +++ b/pkg/api/htpasswd_test.go @@ -12,7 +12,7 @@ import ( test "zotregistry.dev/zot/v2/pkg/test/common" ) -func TestHTPasswdWatcher(t *testing.T) { +func TestHTPasswdWatcherOriginal(t *testing.T) { logger := log.NewLogger("DEBUG", "") Convey("reload htpasswd", t, func(c C) { @@ -28,6 +28,9 @@ func TestHTPasswdWatcher(t *testing.T) { htw, err := api.NewHTPasswdWatcher(htp, "") So(err, ShouldBeNil) + // Start the watcher goroutine + htw.Run() + defer htw.Close() //nolint: errcheck _, present := htp.Get(username) @@ -62,3 +65,491 @@ func TestHTPasswdWatcher(t *testing.T) { So(present, ShouldBeTrue) }) } + +func TestHTPasswdWatcher(t *testing.T) { + logger := log.NewLogger("DEBUG", "") + + Convey("Test HTPasswdWatcher comprehensive functionality", t, func() { + Convey("Test basic operations and lifecycle", func() { + // Create a buffer to capture log output + logBuffer, multiWriter := test.CreateLogCapturingWriter(os.Stdout) + capturingLogger := log.NewLoggerWithWriter("debug", multiWriter) + + htp := api.NewHTPasswd(capturingLogger) + htw, err := api.NewHTPasswdWatcher(htp, "") + So(err, ShouldBeNil) + + // Test Run() and Close() operations + So(func() { htw.Run() }, ShouldNotPanic) + time.Sleep(10 * time.Millisecond) + So(func() { htw.Run() }, ShouldNotPanic) // Idempotent + time.Sleep(10 * time.Millisecond) + So(func() { htw.Close() }, ShouldNotPanic) + time.Sleep(10 * time.Millisecond) + So(htw.Close(), ShouldBeNil) // Idempotent + + // Verify goroutine termination + So(test.WaitForLogMessages(logBuffer, "htpasswd watcher terminating...", 1, 5*time.Second), ShouldBeTrue) + }) + + Convey("Test ChangeFile() operations and file watching", func() { + username1, _ := test.GenerateRandomString() + password1, _ := test.GenerateRandomString() + username2, _ := test.GenerateRandomString() + password2, _ := test.GenerateRandomString() + + htpasswdPath1 := test.MakeHtpasswdFileFromString(test.GetCredString(username1, password1)) + htpasswdPath2 := test.MakeHtpasswdFileFromString(test.GetCredString(username2, password2)) + + defer os.Remove(htpasswdPath1) + defer os.Remove(htpasswdPath2) + + htp := api.NewHTPasswd(logger) + htw, err := api.NewHTPasswdWatcher(htp, "") + So(err, ShouldBeNil) + + // Test ChangeFile() when not running + err = htw.ChangeFile(htpasswdPath1) + So(err, ShouldBeNil) + ok, present := htp.Authenticate(username1, password1) + So(ok, ShouldBeTrue) + So(present, ShouldBeTrue) + + // Start watcher and test ChangeFile() when running + htw.Run() + defer htw.Close() + time.Sleep(10 * time.Millisecond) + + // Change to second file + err = htw.ChangeFile(htpasswdPath2) + So(err, ShouldBeNil) + time.Sleep(10 * time.Millisecond) + ok, present = htp.Authenticate(username2, password2) + So(ok, ShouldBeTrue) + So(present, ShouldBeTrue) + _, present = htp.Authenticate(username1, password1) + So(present, ShouldBeFalse) + + // Test ChangeFile() to empty string (clear store) + err = htw.ChangeFile("") + So(err, ShouldBeNil) + time.Sleep(10 * time.Millisecond) + _, present = htp.Authenticate(username2, password2) + So(present, ShouldBeFalse) + + // Test ChangeFile() with non-existent file + err = htw.ChangeFile("/non/existent/path") + So(err, ShouldNotBeNil) + + // Test file change detection and reload + err = htw.ChangeFile(htpasswdPath1) + So(err, ShouldBeNil) + time.Sleep(10 * time.Millisecond) + ok, present = htp.Authenticate(username1, password1) + So(ok, ShouldBeTrue) + So(present, ShouldBeTrue) + + // Change file content and verify automatic reload + err = os.WriteFile(htpasswdPath1, []byte(test.GetCredString(username1, password2)), 0o600) + So(err, ShouldBeNil) + time.Sleep(100 * time.Millisecond) + ok, present = htp.Authenticate(username1, password2) + So(ok, ShouldBeTrue) + So(present, ShouldBeTrue) + + // Test multiple users + multiUserContent := test.GetCredString(username1, password1) + "\n" + test.GetCredString(username2, password2) + err = os.WriteFile(htpasswdPath1, []byte(multiUserContent), 0o600) + So(err, ShouldBeNil) + time.Sleep(100 * time.Millisecond) + ok, present = htp.Authenticate(username1, password1) + So(ok, ShouldBeTrue) + So(present, ShouldBeTrue) + ok, present = htp.Authenticate(username2, password2) + So(ok, ShouldBeTrue) + So(present, ShouldBeTrue) + + // Test invalid content (clears store) + err = os.WriteFile(htpasswdPath1, []byte("invalid-content"), 0o600) + So(err, ShouldBeNil) + time.Sleep(100 * time.Millisecond) + _, present = htp.Authenticate(username1, password1) + So(present, ShouldBeFalse) + + // Test empty file (clears store) + err = os.WriteFile(htpasswdPath1, []byte(""), 0o600) + So(err, ShouldBeNil) + time.Sleep(100 * time.Millisecond) + _, present = htp.Authenticate(username2, password2) + So(present, ShouldBeFalse) + }) + + Convey("Test restart capability, edge cases, and file operations", func() { + // Create a buffer to capture log output + logBuffer, multiWriter := test.CreateLogCapturingWriter(os.Stdout) + capturingLogger := log.NewLoggerWithWriter("debug", multiWriter) + + username1, _ := test.GenerateRandomString() + password1, _ := test.GenerateRandomString() + username2, _ := test.GenerateRandomString() + password2, _ := test.GenerateRandomString() + + htpasswdPath1 := test.MakeHtpasswdFileFromString(test.GetCredString(username1, password1)) + htpasswdPath2 := test.MakeHtpasswdFileFromString(test.GetCredString(username2, password2)) + + defer os.Remove(htpasswdPath1) + defer os.Remove(htpasswdPath2) + + htp := api.NewHTPasswd(capturingLogger) + htw, err := api.NewHTPasswdWatcher(htp, htpasswdPath1) + So(err, ShouldBeNil) + + // Test restart capability + htw.Run() + time.Sleep(10 * time.Millisecond) + err = htw.ChangeFile(htpasswdPath1) + So(err, ShouldBeNil) + time.Sleep(10 * time.Millisecond) + ok, present := htp.Authenticate(username1, password1) + So(ok, ShouldBeTrue) + So(present, ShouldBeTrue) + + // Close and restart + So(htw.Close(), ShouldBeNil) + So(test.WaitForLogMessages(logBuffer, "htpasswd watcher terminating...", 1, 5*time.Second), ShouldBeTrue) + htw.Run() + time.Sleep(10 * time.Millisecond) + + // Change file after restart + err = htw.ChangeFile(htpasswdPath2) + So(err, ShouldBeNil) + time.Sleep(10 * time.Millisecond) + ok, present = htp.Authenticate(username2, password2) + So(ok, ShouldBeTrue) + So(present, ShouldBeTrue) + + // Test file becomes inaccessible + os.Remove(htpasswdPath2) + time.Sleep(100 * time.Millisecond) + ok, present = htp.Authenticate(username2, password2) + So(ok, ShouldBeTrue) // User should still be present + So(present, ShouldBeTrue) + + // Test file rename (should not trigger reload) + htpasswdPath3 := test.MakeHtpasswdFileFromString(test.GetCredString(username1, password1)) + defer os.Remove(htpasswdPath3) + err = htw.ChangeFile(htpasswdPath3) + So(err, ShouldBeNil) + time.Sleep(10 * time.Millisecond) + ok, present = htp.Authenticate(username1, password1) + So(ok, ShouldBeTrue) + So(present, ShouldBeTrue) + + newPath := htpasswdPath3 + ".new" + err = os.Rename(htpasswdPath3, newPath) + So(err, ShouldBeNil) + + defer os.Remove(newPath) + time.Sleep(100 * time.Millisecond) + ok, _ = htp.Authenticate(username1, password1) + So(ok, ShouldBeTrue) // User should still be present + + // Test file permission change (should not trigger reload) + err = os.Chmod(newPath, 0o000) + So(err, ShouldBeNil) + + defer func() { _ = os.Chmod(newPath, 0o644) }() + time.Sleep(100 * time.Millisecond) + ok, _ = htp.Authenticate(username1, password1) + So(ok, ShouldBeTrue) // User should still be present + + // Test with non-existent directory + htw2, err := api.NewHTPasswdWatcher(htp, "/non/existent/dir/htpasswd") + So(err, ShouldBeNil) + So(func() { htw2.Run() }, ShouldNotPanic) + time.Sleep(10 * time.Millisecond) + So(htw2.Close(), ShouldBeNil) + // 1 termination message + So(test.WaitForLogMessages(logBuffer, "htpasswd watcher terminating...", 1, 5*time.Second), ShouldBeTrue) + + // Test with very long file path + longPath := "/tmp/" + for i := 0; i < 100; i++ { + longPath += "verylongdirname" + } + longPath += "/htpasswd" + htw3, err := api.NewHTPasswdWatcher(htp, longPath) + So(err, ShouldBeNil) + So(func() { htw3.Run() }, ShouldNotPanic) + time.Sleep(10 * time.Millisecond) + So(htw3.Close(), ShouldBeNil) + // 1 termination message + So(test.WaitForLogMessages(logBuffer, "htpasswd watcher terminating...", 1, 5*time.Second), ShouldBeTrue) + + // Clean up + So(htw.Close(), ShouldBeNil) + // 1 termination message + So(test.WaitForLogMessages(logBuffer, "htpasswd watcher terminating...", 1, 5*time.Second), ShouldBeTrue) + }) + + Convey("Test concurrent operations and goroutine cleanup", func() { + // Create a buffer to capture log output + logBuffer, multiWriter := test.CreateLogCapturingWriter(os.Stdout) + capturingLogger := log.NewLoggerWithWriter("debug", multiWriter) + + username1, _ := test.GenerateRandomString() + password1, _ := test.GenerateRandomString() + username2, _ := test.GenerateRandomString() + password2, _ := test.GenerateRandomString() + + htpasswdPath1 := test.MakeHtpasswdFileFromString(test.GetCredString(username1, password1)) + htpasswdPath2 := test.MakeHtpasswdFileFromString(test.GetCredString(username2, password2)) + + defer os.Remove(htpasswdPath1) + defer os.Remove(htpasswdPath2) + + htp := api.NewHTPasswd(capturingLogger) + htw, err := api.NewHTPasswdWatcher(htp, "") + So(err, ShouldBeNil) + + // Test concurrent Run() and Close() + go func() { + for i := 0; i < 5; i++ { + htw.Run() + time.Sleep(1 * time.Millisecond) + } + }() + + go func() { + for i := 0; i < 5; i++ { + htw.Close() + time.Sleep(1 * time.Millisecond) + } + }() + + time.Sleep(50 * time.Millisecond) + So(func() { htw.Close() }, ShouldNotPanic) + So(test.WaitForLogMessages(logBuffer, "htpasswd watcher terminating...", 1, 5*time.Second), ShouldBeTrue) + + // Test concurrent ChangeFile() operations + htw.Run() + defer htw.Close() + + go func() { + for i := 0; i < 3; i++ { + _ = htw.ChangeFile(htpasswdPath1) + + time.Sleep(1 * time.Millisecond) + } + }() + + go func() { + for i := 0; i < 3; i++ { + _ = htw.ChangeFile(htpasswdPath2) + + time.Sleep(1 * time.Millisecond) + } + }() + + time.Sleep(50 * time.Millisecond) + + // At least one user should be present + ok1, present1 := htp.Authenticate(username1, password1) + ok2, present2 := htp.Authenticate(username2, password2) + So(present1 || present2, ShouldBeTrue) + So(ok1 || ok2, ShouldBeTrue) + + // Test goroutine cleanup with multiple verification methods + htw2, err := api.NewHTPasswdWatcher(htp, "") + So(err, ShouldBeNil) + + // Start watcher + htw2.Run() + time.Sleep(10 * time.Millisecond) + + // Close watcher + So(htw2.Close(), ShouldBeNil) + + // Wait for goroutine to terminate (check log messages) + // 1 termination message + So(test.WaitForLogMessages(logBuffer, "htpasswd watcher terminating...", 1, 5*time.Second), ShouldBeTrue) + + // Verify we can restart the watcher (indicates proper cleanup) + htw2.Run() + time.Sleep(10 * time.Millisecond) + So(htw2.Close(), ShouldBeNil) + // 1 termination message + So(test.WaitForLogMessages(logBuffer, "htpasswd watcher terminating...", 1, 5*time.Second), ShouldBeTrue) + + // Test multiple Run/Close cycles + for i := 0; i < 3; i++ { + htw2.Run() + time.Sleep(10 * time.Millisecond) + So(htw2.Close(), ShouldBeNil) + time.Sleep(50 * time.Millisecond) // Give time for termination + } + }) + + Convey("Test goroutine termination with comprehensive log verification", func() { + // Create a buffer to capture log output + logBuffer, multiWriter := test.CreateLogCapturingWriter(os.Stdout) + capturingLogger := log.NewLoggerWithWriter("debug", multiWriter) + + // Test 1: Basic termination verification (no file watching) + htp1 := api.NewHTPasswd(capturingLogger) + htw1, err := api.NewHTPasswdWatcher(htp1, "") + So(err, ShouldBeNil) + + htw1.Run() + time.Sleep(10 * time.Millisecond) + So(htw1.Close(), ShouldBeNil) + So(test.WaitForLogMessages(logBuffer, "htpasswd watcher terminating...", 1, 5*time.Second), ShouldBeTrue) + + // Test 2: File watching with fsnotify resources cleanup + username, _ := test.GenerateRandomString() + password, _ := test.GenerateRandomString() + htpasswdPath := test.MakeHtpasswdFileFromString(test.GetCredString(username, password)) + + defer os.Remove(htpasswdPath) + + htp2 := api.NewHTPasswd(capturingLogger) + htw2, err := api.NewHTPasswdWatcher(htp2, htpasswdPath) + So(err, ShouldBeNil) + + // Start watcher with file + htw2.Run() + time.Sleep(10 * time.Millisecond) + + // Load file to ensure watcher is active + err = htw2.ChangeFile(htpasswdPath) + So(err, ShouldBeNil) + time.Sleep(10 * time.Millisecond) + + // Close watcher and verify termination + So(htw2.Close(), ShouldBeNil) + // 1 + 1 = 2 + So(test.WaitForLogMessages(logBuffer, "htpasswd watcher terminating...", 2, 5*time.Second), ShouldBeTrue) + + // Test 3: Multiple termination cycles with file watching + for i := 0; i < 3; i++ { + htw2.Run() + time.Sleep(10 * time.Millisecond) + So(htw2.Close(), ShouldBeNil) + time.Sleep(50 * time.Millisecond) // Give time for termination + } + + // Verify we have at least 3 termination messages so far (2 previous + 1 cycle = 3) + So(test.WaitForLogMessages(logBuffer, "htpasswd watcher terminating...", 3, 5*time.Second), ShouldBeTrue) + + // Test 4: Stress test with rapid cycles + for i := 0; i < 5; i++ { + htw2.Run() + time.Sleep(5 * time.Millisecond) + So(htw2.Close(), ShouldBeNil) + time.Sleep(20 * time.Millisecond) // Give time for termination + } + + // Verify we have at least 8 termination messages so far (3+5 = 8) + So(test.WaitForLogMessages(logBuffer, "htpasswd watcher terminating...", 8, 5*time.Second), ShouldBeTrue) + + // Final verification: watcher should still work after all cycles + htw2.Run() + time.Sleep(10 * time.Millisecond) + So(htw2.Close(), ShouldBeNil) + + // Final verification of all termination messages with timeout + So(test.WaitForLogMessages(logBuffer, "htpasswd watcher terminating...", 9, 5*time.Second), ShouldBeTrue) // 8+1 = 9 + }) + + Convey("Test malformed htpasswd files", func() { + // Create a buffer to capture log output + logBuffer, multiWriter := test.CreateLogCapturingWriter(os.Stdout) + capturingLogger := log.NewLoggerWithWriter("debug", multiWriter) + + username, _ := test.GenerateRandomString() + password, _ := test.GenerateRandomString() + + htp := api.NewHTPasswd(capturingLogger) + + // Test file with only colons (malformed) + colonPath := test.MakeHtpasswdFileFromString(":::") + defer os.Remove(colonPath) + htw1, err := api.NewHTPasswdWatcher(htp, colonPath) + So(err, ShouldBeNil) + htw1.Run() + time.Sleep(10 * time.Millisecond) + _ = htw1.ChangeFile(colonPath) + + time.Sleep(10 * time.Millisecond) + // The malformed file creates an entry with empty username, so test that + _, present := htp.Authenticate("", "anything") + So(present, ShouldBeTrue) // Empty username entry exists but auth fails + ok, _ := htp.Authenticate("", "anything") + So(ok, ShouldBeFalse) // But authentication should fail + So(htw1.Close(), ShouldBeNil) + So(test.WaitForLogMessages(logBuffer, "htpasswd watcher terminating...", 1, 5*time.Second), ShouldBeTrue) + + // Test file with empty lines and comments + content := "\n\n" + test.GetCredString(username, password) + "\n# comment\n" + commentedPath := test.MakeHtpasswdFileFromString(content) + + defer os.Remove(commentedPath) + htw2, err := api.NewHTPasswdWatcher(htp, commentedPath) + So(err, ShouldBeNil) + htw2.Run() + time.Sleep(10 * time.Millisecond) + _ = htw2.ChangeFile(commentedPath) + + time.Sleep(10 * time.Millisecond) + ok, _ = htp.Authenticate(username, password) + So(ok, ShouldBeTrue) // User should be loaded (comments/empty lines ignored) + So(htw2.Close(), ShouldBeNil) + // 1 termination message + So(test.WaitForLogMessages(logBuffer, "htpasswd watcher terminating...", 1, 5*time.Second), ShouldBeTrue) + }) + + Convey("Test ChangeFile with nil watcher and empty filepath", func() { + // Create a logger (no need for log capture since we're not testing goroutine termination) + capturingLogger := log.NewLogger("debug", "") + + username, _ := test.GenerateRandomString() + password, _ := test.GenerateRandomString() + + htp := api.NewHTPasswd(capturingLogger) + htw, err := api.NewHTPasswdWatcher(htp, "") + So(err, ShouldBeNil) + + // Load some initial data + htpasswdPath := test.MakeHtpasswdFileFromString(test.GetCredString(username, password)) + defer os.Remove(htpasswdPath) + + // Load initial file (this will populate the store) + err = htw.ChangeFile(htpasswdPath) + So(err, ShouldBeNil) + + // Verify user is loaded + ok, present := htp.Authenticate(username, password) + So(ok, ShouldBeTrue) + So(present, ShouldBeTrue) + + // Now test the edge case: ChangeFile with empty string when watcher is nil + // (watcher is nil because we haven't called Run() yet) + err = htw.ChangeFile("") + So(err, ShouldBeNil) // Should not return an error + + // Verify that the store was cleared + ok, present = htp.Authenticate(username, password) + So(ok, ShouldBeFalse) // Authentication should fail + So(present, ShouldBeFalse) // User should not be present + + // Test that we can still load a file after clearing + err = htw.ChangeFile(htpasswdPath) + So(err, ShouldBeNil) + + // Verify user is loaded again + ok, present = htp.Authenticate(username, password) + So(ok, ShouldBeTrue) + So(present, ShouldBeTrue) + }) + }) +} diff --git a/pkg/api/proxy.go b/pkg/api/proxy.go index 407231ab..5d2ff7e6 100644 --- a/pkg/api/proxy.go +++ b/pkg/api/proxy.go @@ -22,11 +22,12 @@ import ( func ClusterProxy(ctrlr *Controller) func(http.HandlerFunc) http.HandlerFunc { return func(next http.HandlerFunc) http.HandlerFunc { return http.HandlerFunc(func(response http.ResponseWriter, request *http.Request) { - config := ctrlr.Config + // Get cluster config safely + clusterConfig := ctrlr.Config.CopyClusterConfig() logger := ctrlr.Log // if no cluster or single-node cluster, handle locally. - if config.Cluster == nil || len(config.Cluster.Members) == 1 { + if !clusterConfig.IsClustered() { next.ServeHTTP(response, request) return @@ -45,13 +46,13 @@ func ClusterProxy(ctrlr *Controller) func(http.HandlerFunc) http.HandlerFunc { // the target member is the only one which should do read/write for the dist-spec APIs // for the given repository. - targetMemberIndex, targetMember := cluster.ComputeTargetMember(config.Cluster.HashKey, config.Cluster.Members, name) + targetMemberIndex, targetMember := cluster.ComputeTargetMember(clusterConfig.HashKey, clusterConfig.Members, name) logger.Debug().Str(constants.RepositoryLogKey, name). Msg(fmt.Sprintf("target member socket: %s index: %d", targetMember, targetMemberIndex)) // if the target member is the same as the local member, the current member should handle the request. // since the instances have the same config, a quick index lookup is sufficient - if targetMemberIndex == config.Cluster.Proxy.LocalMemberClusterSocketIndex { + if targetMemberIndex == clusterConfig.Proxy.LocalMemberClusterSocketIndex { logger.Debug().Str(constants.RepositoryLogKey, name).Msg("handling the request locally") next.ServeHTTP(response, request) @@ -119,8 +120,11 @@ func proxyHTTPRequest(ctx context.Context, req *http.Request, ) (*http.Response, error) { cloneURL := *req.URL + // Get HTTP TLS config safely + httpTLSConfig := ctrlr.Config.CopyTLSConfig() + proxyQueryScheme := "http" - if ctrlr.Config.HTTP.TLS != nil { + if httpTLSConfig != nil { proxyQueryScheme = "https" } @@ -142,12 +146,15 @@ func proxyHTTPRequest(ctx context.Context, req *http.Request, fwdRequest.Header.Set(constants.ScaleOutHopCountHeader, "1") clientOpts := common.HTTPClientOptions{ - TLSEnabled: ctrlr.Config.HTTP.TLS != nil, - VerifyTLS: ctrlr.Config.HTTP.TLS != nil, // for now, always verify TLS when TLS mode is enabled + TLSEnabled: httpTLSConfig != nil, + VerifyTLS: httpTLSConfig != nil, // for now, always verify TLS when TLS mode is enabled Host: targetMember, } - tlsConfig := ctrlr.Config.Cluster.TLS + // Get cluster config safely + clusterConfig := ctrlr.Config.CopyClusterConfig() + tlsConfig := clusterConfig.TLS + if tlsConfig != nil { clientOpts.CertOptions.ClientCertFile = tlsConfig.Cert clientOpts.CertOptions.ClientKeyFile = tlsConfig.Key diff --git a/pkg/api/routes.go b/pkg/api/routes.go index a0d24ad8..1655e634 100644 --- a/pkg/api/routes.go +++ b/pkg/api/routes.go @@ -73,9 +73,13 @@ func (rh *RouteHandler) SetupRoutes() { // first get Auth middleware in order to first setup openid/ldap/htpasswd, before oidc provider routes are setup authHandler := AuthHandler(rh.c) - applyCORSHeaders := getCORSHeadersHandler(rh.c.Config.HTTP.AllowOrigin) + // Get CORS config safely + allowOrigin := rh.c.Config.GetAllowOrigin() + applyCORSHeaders := getCORSHeadersHandler(allowOrigin) - if rh.c.Config.IsOpenIDAuthEnabled() { + // Get auth config for OpenID checks + authConfig := rh.c.Config.CopyAuthConfig() + if authConfig.IsOpenIDAuthEnabled() { // login path for openID rh.c.Router.HandleFunc(constants.LoginPath, rh.AuthURLHandler()) @@ -91,14 +95,15 @@ func (rh *RouteHandler) SetupRoutes() { } } - if rh.c.Config.IsAPIKeyEnabled() { + // Get auth config for API key checks + if authConfig.IsAPIKeyEnabled() { // enable api key management urls apiKeyRouter := rh.c.Router.PathPrefix(constants.APIKeyPath).Subrouter() apiKeyRouter.Use(authHandler) apiKeyRouter.Use(BaseAuthzHandler(rh.c)) // Always use CORSHeadersMiddleware before ACHeadersMiddleware - apiKeyRouter.Use(zcommon.CORSHeadersMiddleware(rh.c.Config.HTTP.AllowOrigin)) + apiKeyRouter.Use(zcommon.CORSHeadersMiddleware(rh.c.Config.GetAllowOrigin())) apiKeyRouter.Use(zcommon.ACHeadersMiddleware(rh.c.Config, http.MethodGet, http.MethodPost, http.MethodDelete, http.MethodOptions)) @@ -109,7 +114,7 @@ func (rh *RouteHandler) SetupRoutes() { /* on every route which may be used by UI we set OPTIONS as allowed METHOD to enable preflight request from UI to backend */ - if rh.c.Config.IsBasicAuthnEnabled() { + if authConfig.IsBasicAuthnEnabled() { // logout path for openID rh.c.Router.HandleFunc(constants.LogoutPath, getUIHeadersHandler(rh.c.Config, http.MethodPost, http.MethodOptions)(applyCORSHeaders(rh.Logout))). @@ -122,8 +127,9 @@ func (rh *RouteHandler) SetupRoutes() { prefixedDistSpecRouter := prefixedRouter.NewRoute().Subrouter() // authz is being enabled if AccessControl is specified // if Authn is not present AccessControl will have only default policies - if rh.c.Config.HTTP.AccessControl != nil { - if rh.c.Config.IsBasicAuthnEnabled() { + accessControlConfig := rh.c.Config.CopyAccessControlConfig() + if accessControlConfig != nil { + if authConfig.IsBasicAuthnEnabled() { rh.c.Log.Info().Msg("access control is being enabled") } else { rh.c.Log.Info().Msg("anonymous policy only access control is being enabled") @@ -241,7 +247,9 @@ func getUIHeadersHandler(config *config.Config, allowedMethods ...string) func(h response.Header().Set("Access-Control-Allow-Headers", "Authorization,content-type,"+constants.SessionClientHeaderName) - if config.IsBasicAuthnEnabled() { + // Get auth config safely + authConfig := config.CopyAuthConfig() + if authConfig.IsBasicAuthnEnabled() { response.Header().Set("Access-Control-Allow-Credentials", "true") } @@ -267,13 +275,17 @@ func (rh *RouteHandler) CheckVersionSupport(response http.ResponseWriter, reques response.Header().Set(constants.DistAPIVersion, "registry/2.0") // NOTE: compatibility workaround - return this header in "allowed-read" mode to allow for clients to // work correctly - if rh.c.Config.IsBasicAuthnEnabled() || rh.c.Config.IsBearerAuthEnabled() { + // Get auth config safely + authConfig := rh.c.Config.CopyAuthConfig() + if authConfig.IsBasicAuthnEnabled() || authConfig.IsBearerAuthEnabled() { // don't send auth headers if request is coming from UI if request.Header.Get(constants.SessionClientHeaderName) != constants.SessionClientHeaderValue { - if rh.c.Config.HTTP.Auth.Bearer != nil { - response.Header().Set("WWW-Authenticate", "bearer realm="+rh.c.Config.HTTP.Auth.Bearer.Realm) + if authConfig.Bearer != nil { + realm := authConfig.Bearer.Realm + response.Header().Set("WWW-Authenticate", "bearer realm="+realm) } else { - response.Header().Set("WWW-Authenticate", "basic realm="+rh.c.Config.HTTP.Realm) + realm := rh.c.Config.GetRealm() + response.Header().Set("WWW-Authenticate", "basic realm="+realm) } } } @@ -502,7 +514,9 @@ type ExtensionList struct { // @Failure 500 {string} string "internal server error" // @Router /v2/{name}/manifests/{reference} [get]. func (rh *RouteHandler) GetManifest(response http.ResponseWriter, request *http.Request) { - if rh.c.Config.IsBasicAuthnEnabled() { + // Get auth config safely + authConfig := rh.c.Config.CopyAuthConfig() + if authConfig.IsBasicAuthnEnabled() { response.Header().Set("Access-Control-Allow-Credentials", "true") } @@ -694,7 +708,9 @@ func (rh *RouteHandler) UpdateManifest(response http.ResponseWriter, request *ht } mediaType := request.Header.Get("Content-Type") - if !storageCommon.IsSupportedMediaType(rh.c.Config.HTTP.Compat, mediaType) { + compatConfig := rh.c.Config.GetCompat() + + if !storageCommon.IsSupportedMediaType(compatConfig, mediaType) { err := apiErr.NewError(apiErr.MANIFEST_INVALID).AddDetail(map[string]string{"mediaType": mediaType}) zcommon.WriteJSON(response, http.StatusUnsupportedMediaType, apiErr.NewErrorList(err)) @@ -949,7 +965,9 @@ func (rh *RouteHandler) CheckBlob(response http.ResponseWriter, request *http.Re } userCanMount := true - if rh.c.Config.IsAuthzEnabled() { + accessControlConfig := rh.c.Config.CopyAccessControlConfig() + + if accessControlConfig.IsAuthzEnabled() { userCanMount, err = canMount(userAc, imgStore, digest) if err != nil { rh.c.Log.Error().Err(err).Msg("unexpected error") @@ -1265,7 +1283,9 @@ func (rh *RouteHandler) CreateBlobUpload(response http.ResponseWriter, request * } userCanMount := true - if rh.c.Config.IsAuthzEnabled() { + accessControlConfig := rh.c.Config.CopyAccessControlConfig() + + if accessControlConfig.IsAuthzEnabled() { userCanMount, err = canMount(userAc, imgStore, mountDigest) if err != nil { rh.c.Log.Error().Err(err).Msg("unexpected error") @@ -2323,7 +2343,8 @@ func getBlobUploadLocation(url *url.URL, name string, digest godigest.Digest) st } func isSyncOnDemandEnabled(ctlr Controller) bool { - if ctlr.Config.IsSyncEnabled() && + extensionsConfig := ctlr.Config.CopyExtensionsConfig() + if extensionsConfig.IsSyncEnabled() && fmt.Sprintf("%v", ctlr.SyncOnDemand) != fmt.Sprintf("%v", nil) { return true } diff --git a/pkg/cli/server/config_reloader.go b/pkg/cli/server/config_reloader.go index 3ddb9aed..9121030d 100644 --- a/pkg/cli/server/config_reloader.go +++ b/pkg/cli/server/config_reloader.go @@ -85,9 +85,10 @@ func (hr *HotReloader) Start() { continue } - if hr.ctlr.Config.HTTP.Auth != nil && hr.ctlr.Config.HTTP.Auth.LDAP != nil && - hr.ctlr.Config.HTTP.Auth.LDAP.CredentialsFile != newConfig.HTTP.Auth.LDAP.CredentialsFile { - err = hr.watcher.Remove(hr.ctlr.Config.HTTP.Auth.LDAP.CredentialsFile) + authConfig := hr.ctlr.Config.CopyAuthConfig() + if authConfig.IsLdapAuthEnabled() && + authConfig.LDAP.CredentialsFile != newConfig.HTTP.Auth.LDAP.CredentialsFile { + err = hr.watcher.Remove(authConfig.LDAP.CredentialsFile) if err != nil && !errors.Is(err, fsnotify.ErrNonExistentWatch) { hr.logger.Error().Err(err).Msg("failed to remove old watch for the credentials file") } diff --git a/pkg/cli/server/extensions_test.go b/pkg/cli/server/extensions_test.go index 385a9107..ecb0798d 100644 --- a/pkg/cli/server/extensions_test.go +++ b/pkg/cli/server/extensions_test.go @@ -883,7 +883,7 @@ func TestServeScrubExtension(t *testing.T) { // Even if in config we specified scrub interval=1h, the minimum interval is 2h dataStr := string(data) - So(dataStr, ShouldContainSubstring, "\"Scrub\":{\"Enable\":true,\"Interval\":3600000000000}") + So(dataStr, ShouldContainSubstring, "\"Scrub\":{\"Enable\":true,\"Interval\":7200000000000}") So(dataStr, ShouldContainSubstring, "scrub interval set to too-short interval < 2h, changing scrub duration to 2 hours and continuing.") }) diff --git a/pkg/cli/server/root.go b/pkg/cli/server/root.go index 23263b33..a9b1316d 100644 --- a/pkg/cli/server/root.go +++ b/pkg/cli/server/root.go @@ -202,8 +202,9 @@ func NewServerRootCmd() *cobra.Command { logger := zlog.NewLogger("info", "") if showVersion { - logger.Info().Str("distribution-spec", distspec.Version).Str("commit", config.Commit). - Str("binary-type", config.BinaryType).Str("go version", config.GoVersion).Msg("version") + commit, binaryType, goVersion, _ := conf.GetVersionInfo() + logger.Info().Str("distribution-spec", distspec.Version).Str("commit", commit). + Str("binary-type", binaryType).Str("go version", goVersion).Msg("version") } else { _ = cmd.Usage() cmd.SilenceErrors = false @@ -228,19 +229,20 @@ func NewServerRootCmd() *cobra.Command { func validateStorageConfig(cfg *config.Config, logger zlog.Logger) error { expConfigMap := make(map[string]config.StorageConfig, 0) - defaultRootDir := cfg.Storage.RootDirectory + storageConfig := cfg.CopyStorageConfig() + defaultRootDir := storageConfig.RootDirectory - for _, storageConfig := range cfg.Storage.SubPaths { - if strings.EqualFold(defaultRootDir, storageConfig.RootDirectory) { + for _, subStorageConfig := range storageConfig.SubPaths { + if strings.EqualFold(defaultRootDir, subStorageConfig.RootDirectory) { msg := "invalid storage config, storage subpaths cannot use default storage root directory" logger.Error().Err(zerr.ErrBadConfig).Msg(msg) return fmt.Errorf("%w: %s", zerr.ErrBadConfig, msg) } - expConfig, ok := expConfigMap[storageConfig.RootDirectory] + expConfig, ok := expConfigMap[subStorageConfig.RootDirectory] if ok { - equal := expConfig.ParamsEqual(storageConfig) + equal := expConfig.ParamsEqual(subStorageConfig) if !equal { msg := "invalid storage config, storage config with same root directory should have same parameters" logger.Error().Err(zerr.ErrBadConfig).Msg(msg) @@ -248,7 +250,7 @@ func validateStorageConfig(cfg *config.Config, logger zlog.Logger) error { return fmt.Errorf("%w: %s", zerr.ErrBadConfig, msg) } } else { - expConfigMap[storageConfig.RootDirectory] = storageConfig + expConfigMap[subStorageConfig.RootDirectory] = subStorageConfig } } @@ -257,20 +259,21 @@ func validateStorageConfig(cfg *config.Config, logger zlog.Logger) error { func validateCacheConfig(cfg *config.Config, logger zlog.Logger) error { // global + storageConfig := cfg.CopyStorageConfig() // dedupe true, remote storage, remoteCache true, but no cacheDriver (remote) //nolint: lll - if cfg.Storage.Dedupe && cfg.Storage.StorageDriver != nil && cfg.Storage.RemoteCache && cfg.Storage.CacheDriver == nil { + if storageConfig.Dedupe && storageConfig.StorageDriver != nil && storageConfig.RemoteCache && storageConfig.CacheDriver == nil { msg := "invalid database config, dedupe set to true with remote storage and database, but no remote database configured" logger.Error().Err(zerr.ErrBadConfig).Msg(msg) return fmt.Errorf("%w: %s", zerr.ErrBadConfig, msg) } - if cfg.Storage.CacheDriver != nil && cfg.Storage.RemoteCache { + if storageConfig.CacheDriver != nil && storageConfig.RemoteCache { // local storage with remote database // redis is supported with both local and S3 storage, while dynamodb is only supported with S3 // redis is only supported with local storage in a non-clustering scenario with a single zot instance, - if cfg.Storage.StorageDriver == nil && cfg.Storage.CacheDriver["name"] != storageConstants.RedisDriverName { + if storageConfig.StorageDriver == nil && storageConfig.CacheDriver["name"] != storageConstants.RedisDriverName { msg := "invalid database config, cannot have local storage driver with remote database!" logger.Error().Err(zerr.ErrBadConfig).Msg(msg) @@ -278,23 +281,23 @@ func validateCacheConfig(cfg *config.Config, logger zlog.Logger) error { } // unsupported database driver - if cfg.Storage.CacheDriver["name"] != storageConstants.DynamoDBDriverName && - cfg.Storage.CacheDriver["name"] != storageConstants.RedisDriverName { + if storageConfig.CacheDriver["name"] != storageConstants.DynamoDBDriverName && + storageConfig.CacheDriver["name"] != storageConstants.RedisDriverName { msg := "invalid database config, unsupported database driver" - logger.Error().Err(zerr.ErrBadConfig).Interface("cacheDriver", cfg.Storage.CacheDriver["name"]).Msg(msg) + logger.Error().Err(zerr.ErrBadConfig).Interface("cacheDriver", storageConfig.CacheDriver["name"]).Msg(msg) return fmt.Errorf("%w: %s", zerr.ErrBadConfig, msg) } } - if !cfg.Storage.RemoteCache && cfg.Storage.CacheDriver != nil { - logger.Warn().Err(zerr.ErrBadConfig).Str("directory", cfg.Storage.RootDirectory). + if !storageConfig.RemoteCache && storageConfig.CacheDriver != nil { + logger.Warn().Err(zerr.ErrBadConfig).Str("directory", storageConfig.RootDirectory). Msg("invalid database config, remoteCache set to false but cacheDriver config (remote database)" + " provided for directory will ignore and use local database") } // subpaths - for _, subPath := range cfg.Storage.SubPaths { + for _, subPath := range storageConfig.SubPaths { // dedupe true, remote storage, remoteCache true, but no cacheDriver (remote) //nolint: lll if subPath.Dedupe && subPath.StorageDriver != nil && subPath.RemoteCache && subPath.CacheDriver == nil { @@ -316,14 +319,14 @@ func validateCacheConfig(cfg *config.Config, logger zlog.Logger) error { // unsupported cache driver if subPath.CacheDriver["name"] != storageConstants.DynamoDBDriverName { msg := "invalid database config, unsupported database driver" - logger.Error().Err(zerr.ErrBadConfig).Interface("cacheDriver", cfg.Storage.CacheDriver["name"]).Msg(msg) + logger.Error().Err(zerr.ErrBadConfig).Interface("cacheDriver", subPath.CacheDriver["name"]).Msg(msg) return fmt.Errorf("%w: %s", zerr.ErrBadConfig, msg) } } if !subPath.RemoteCache && subPath.CacheDriver != nil { - logger.Warn().Err(zerr.ErrBadConfig).Str("directory", cfg.Storage.RootDirectory). + logger.Warn().Err(zerr.ErrBadConfig).Str("directory", subPath.RootDirectory). Msg("invalid database config, remoteCache set to false but cacheDriver config (remote database)" + "provided for directory, will ignore and use local database") } @@ -335,11 +338,12 @@ func validateCacheConfig(cfg *config.Config, logger zlog.Logger) error { func validateRemoteSessionStoreConfig(cfg *config.Config, logger zlog.Logger) error { // it is okay for the session driver config to be nil // this is backwards compatible for older configs - if cfg.HTTP.Auth.SessionDriver == nil { + authConfig := cfg.CopyAuthConfig() + if authConfig == nil || authConfig.SessionDriver == nil { return nil } - sessionDriverName, ok := cfg.HTTP.Auth.SessionDriver["name"] + sessionDriverName, ok := authConfig.SessionDriver["name"] if !ok { msg := "must provide session driver name!" logger.Error().Err(zerr.ErrBadConfig).Msg(msg) @@ -373,7 +377,7 @@ func validateRemoteSessionStoreConfig(cfg *config.Config, logger zlog.Logger) er // as redis session store does not support these yet. if sessionDriverName == storageConstants.RedisDriverName { - if cfg.HTTP.Auth.SessionKeysFile != "" { + if authConfig.SessionKeysFile != "" { msg := "session keys not supported when redis session driver is used!" logger.Error().Err(zerr.ErrBadConfig).Msg(msg) @@ -385,19 +389,20 @@ func validateRemoteSessionStoreConfig(cfg *config.Config, logger zlog.Logger) er } func validateExtensionsConfig(cfg *config.Config, logger zlog.Logger) error { - if cfg.Extensions != nil && cfg.Extensions.Mgmt != nil { + extensionsConfig := cfg.CopyExtensionsConfig() + if extensionsConfig != nil && extensionsConfig.Mgmt != nil { logger.Warn().Msg("mgmt extensions configuration option has been made redundant and will be ignored.") } - if cfg.Extensions != nil && cfg.Extensions.APIKey != nil { + if extensionsConfig != nil && extensionsConfig.APIKey != nil { logger.Warn().Msg("apikey extension configuration will be ignored as API keys " + "are now configurable in the HTTP settings.") } - if cfg.Extensions != nil && cfg.Extensions.UI != nil && cfg.Extensions.UI.Enable != nil && *cfg.Extensions.UI.Enable { + if extensionsConfig.IsUIEnabled() { // it would make sense to also check for mgmt and user prefs to be enabled, // but those are both enabled by having the search and ui extensions enabled - if cfg.Extensions.Search == nil || !*cfg.Extensions.Search.Enable { + if !extensionsConfig.IsSearchEnabled() { msg := "failed to enable ui, search extension must be enabled" logger.Error().Err(zerr.ErrBadConfig).Msg(msg) @@ -406,18 +411,17 @@ func validateExtensionsConfig(cfg *config.Config, logger zlog.Logger) error { } //nolint:lll - if cfg.Storage.StorageDriver != nil && cfg.Extensions != nil && cfg.Extensions.Search != nil && - cfg.Extensions.Search.Enable != nil && *cfg.Extensions.Search.Enable && cfg.Extensions.Search.CVE != nil { + storageConfig := cfg.CopyStorageConfig() + if storageConfig.StorageDriver != nil && extensionsConfig.IsCveScanningEnabled() { msg := "failed to enable cve scanning due to incompatibility with remote storage, please disable cve" logger.Error().Err(zerr.ErrBadConfig).Msg(msg) return fmt.Errorf("%w: %s", zerr.ErrBadConfig, msg) } - for _, subPath := range cfg.Storage.SubPaths { + for _, subPath := range storageConfig.SubPaths { //nolint:lll - if subPath.StorageDriver != nil && cfg.Extensions != nil && cfg.Extensions.Search != nil && - cfg.Extensions.Search.Enable != nil && *cfg.Extensions.Search.Enable && cfg.Extensions.Search.CVE != nil { + if subPath.StorageDriver != nil && extensionsConfig.IsCveScanningEnabled() { msg := "failed to enable cve scanning due to incompatibility with remote storage, please disable cve" logger.Error().Err(zerr.ErrBadConfig).Msg(msg) @@ -466,24 +470,27 @@ func validateConfiguration(config *config.Config, logger zlog.Logger) error { } // check authorization config, it should have basic auth enabled or ldap, api keys or OpenID - if config.HTTP.AccessControl != nil { + accessControlConfig := config.CopyAccessControlConfig() + if accessControlConfig != nil { // checking for anonymous policy only authorization config: no users, no policies but anonymous policy if err := validateAuthzPolicies(config, logger); err != nil { return err } } - if len(config.Storage.StorageDriver) != 0 { + storageConfig := config.CopyStorageConfig() + if len(storageConfig.StorageDriver) != 0 { // enforce s3 driver in case of using storage driver - if config.Storage.StorageDriver["name"] != storageConstants.S3StorageDriverName { + if storageConfig.StorageDriver["name"] != storageConstants.S3StorageDriverName { msg := "unsupported storage driver" - logger.Error().Err(zerr.ErrBadConfig).Interface("cacheDriver", config.Storage.StorageDriver["name"]).Msg(msg) + logger.Error().Err(zerr.ErrBadConfig).Interface("cacheDriver", storageConfig.StorageDriver["name"]).Msg(msg) return fmt.Errorf("%w: %s", zerr.ErrBadConfig, msg) } // enforce tmpDir in case sync + s3 - if config.Extensions != nil && config.Extensions.Sync != nil && config.Extensions.Sync.DownloadDir == "" { + extensionsConfig := config.CopyExtensionsConfig() + if extensionsConfig.IsSyncEnabled() && extensionsConfig.Sync.DownloadDir == "" { msg := "using both sync and remote storage features needs config.Extensions.Sync.DownloadDir to be specified" logger.Error().Err(zerr.ErrBadConfig).Msg(msg) @@ -492,22 +499,23 @@ func validateConfiguration(config *config.Config, logger zlog.Logger) error { } // enforce s3 driver on subpaths in case of using storage driver - if config.Storage.SubPaths != nil { - if len(config.Storage.SubPaths) > 0 { - subPaths := config.Storage.SubPaths + if storageConfig.SubPaths != nil { + if len(storageConfig.SubPaths) > 0 { + subPaths := storageConfig.SubPaths - for route, storageConfig := range subPaths { - if len(storageConfig.StorageDriver) != 0 { - if storageConfig.StorageDriver["name"] != storageConstants.S3StorageDriverName { + for route, subStorageConfig := range subPaths { + if len(subStorageConfig.StorageDriver) != 0 { + if subStorageConfig.StorageDriver["name"] != storageConstants.S3StorageDriverName { msg := "unsupported storage driver" logger.Error().Err(zerr.ErrBadConfig).Str("subpath", route).Interface("storageDriver", - storageConfig.StorageDriver["name"]).Msg(msg) + subStorageConfig.StorageDriver["name"]).Msg(msg) return fmt.Errorf("%w: %s", zerr.ErrBadConfig, msg) } // enforce tmpDir in case sync + s3 - if config.Extensions != nil && config.Extensions.Sync != nil && config.Extensions.Sync.DownloadDir == "" { + extensionsConfig := config.CopyExtensionsConfig() + if extensionsConfig.IsSyncEnabled() && extensionsConfig.Sync.DownloadDir == "" { msg := "using both sync and remote storage features needs config.Extensions.Sync.DownloadDir to be specified" logger.Error().Err(zerr.ErrBadConfig).Msg(msg) @@ -519,8 +527,8 @@ func validateConfiguration(config *config.Config, logger zlog.Logger) error { } // check glob patterns in authz config are compilable - if config.HTTP.AccessControl != nil { - for pattern := range config.HTTP.AccessControl.Repositories { + if accessControlConfig != nil { + for pattern := range accessControlConfig.Repositories { ok := glob.ValidatePattern(pattern) if !ok { msg := "failed to compile authorization pattern" @@ -540,8 +548,10 @@ func validateConfiguration(config *config.Config, logger zlog.Logger) error { } func validateOpenIDConfig(cfg *config.Config, logger zlog.Logger) error { - if cfg.HTTP.Auth != nil && cfg.HTTP.Auth.OpenID != nil { - for provider, providerConfig := range cfg.HTTP.Auth.OpenID.Providers { + authConfig := cfg.CopyAuthConfig() + // can't check with IsOpenIDAuthEnabled(), because it can't test invalid providers + if authConfig != nil && authConfig.OpenID != nil && len(authConfig.OpenID.Providers) > 0 { + for provider, providerConfig := range authConfig.OpenID.Providers { //nolint: gocritic if config.IsOpenIDSupported(provider) { if providerConfig.ClientID == "" || providerConfig.Issuer == "" || @@ -571,8 +581,12 @@ func validateOpenIDConfig(cfg *config.Config, logger zlog.Logger) error { } func validateAuthzPolicies(config *config.Config, logger zlog.Logger) error { - if (config.HTTP.Auth == nil || (config.HTTP.Auth.HTPasswd.Path == "" && config.HTTP.Auth.LDAP == nil && - config.HTTP.Auth.OpenID == nil)) && !authzContainsOnlyAnonymousPolicy(config) { + authConfig := config.CopyAuthConfig() + accessControlConfig := config.CopyAccessControlConfig() + + logger.Info().Msg("checking if anonymous authorization is the only type of authorization policy configured") + + if !authConfig.IsBasicAuthnEnabled() && !accessControlConfig.ContainsOnlyAnonymousPolicy() { msg := "access control config requires one of httpasswd, ldap or openid authentication " + "or using only 'anonymousPolicy' policies" logger.Error().Err(zerr.ErrBadConfig).Msg(msg) @@ -696,6 +710,15 @@ func applyDefaultValues(config *config.Config, viperInstance *viper.Viper, logge if config.Extensions.Scrub.Interval == 0 { config.Extensions.Scrub.Interval = 24 * time.Hour //nolint:mnd } + + // Validate minimum scrub interval + minScrubInterval, _ := time.ParseDuration("2h") + if config.Extensions.Scrub.Interval < minScrubInterval { + config.Extensions.Scrub.Interval = minScrubInterval + + logger.Warn().Msg("scrub interval set to too-short interval < 2h, " + + "changing scrub duration to 2 hours and continuing.") + } } if config.Extensions.UI != nil { @@ -1076,52 +1099,11 @@ func readSecretFile(path string, v any, checkUnsetFields bool) error { //nolint: return nil } -func authzContainsOnlyAnonymousPolicy(cfg *config.Config) bool { - logger := zlog.NewLogger("info", "") - - adminPolicy := cfg.HTTP.AccessControl.AdminPolicy - anonymousPolicyPresent := false - - logger.Info().Msg("checking if anonymous authorization is the only type of authorization policy configured") - - if len(adminPolicy.Actions)+len(adminPolicy.Users) > 0 { - logger.Info().Msg("admin policy detected, anonymous authorization is not the only authorization policy configured") - - return false - } - - for _, repository := range cfg.HTTP.AccessControl.Repositories { - if len(repository.DefaultPolicy) > 0 { - logger.Info().Interface("repository", repository). - Msg("default policy detected, anonymous authorization is not the only authorization policy configured") - - return false - } - - if len(repository.AnonymousPolicy) > 0 { - logger.Info().Msg("anonymous authorization detected") - - anonymousPolicyPresent = true - } - - for _, policy := range repository.Policies { - if len(policy.Actions)+len(policy.Users) > 0 { - logger.Info().Interface("repository", repository). - Msg("repository with non-empty policy detected, " + - "anonymous authorization is not the only authorization policy configured") - - return false - } - } - } - - return anonymousPolicyPresent -} - func validateLDAP(config *config.Config, logger zlog.Logger) error { // LDAP mandatory configuration - if config.HTTP.Auth != nil && config.HTTP.Auth.LDAP != nil { - ldap := config.HTTP.Auth.LDAP + authConfig := config.CopyAuthConfig() + if authConfig.IsLdapAuthEnabled() { + ldap := authConfig.LDAP if ldap.UserAttribute == "" { msg := "invalid LDAP configuration, missing mandatory key: userAttribute" logger.Error().Str("userAttribute", ldap.UserAttribute).Msg(msg) @@ -1148,12 +1130,13 @@ func validateLDAP(config *config.Config, logger zlog.Logger) error { } func validateHTTP(config *config.Config, logger zlog.Logger) error { - if config.HTTP.Port != "" { - port, err := strconv.ParseInt(config.HTTP.Port, 10, 64) - if err != nil || (port < 0 || port > 65535) { - logger.Error().Str("port", config.HTTP.Port).Msg("invalid port") + port := config.GetHTTPPort() + if port != "" { + portInt, err := strconv.ParseInt(port, 10, 64) + if err != nil || (portInt < 0 || portInt > 65535) { + logger.Error().Str("port", port).Msg("invalid port") - return fmt.Errorf("%w: invalid port %s", zerr.ErrBadConfig, config.HTTP.Port) + return fmt.Errorf("%w: invalid port %s", zerr.ErrBadConfig, port) } } @@ -1162,40 +1145,41 @@ func validateHTTP(config *config.Config, logger zlog.Logger) error { func validateGC(config *config.Config, logger zlog.Logger) error { // enforce GC params - if config.Storage.GCDelay < 0 { - logger.Error().Err(zerr.ErrBadConfig).Dur("delay", config.Storage.GCDelay). + storageConfig := config.CopyStorageConfig() + if storageConfig.GCDelay < 0 { + logger.Error().Err(zerr.ErrBadConfig).Dur("delay", storageConfig.GCDelay). Msg("invalid garbage-collect delay specified") return fmt.Errorf("%w: invalid garbage-collect delay specified %s", - zerr.ErrBadConfig, config.Storage.GCDelay) + zerr.ErrBadConfig, storageConfig.GCDelay) } - if config.Storage.GCInterval < 0 { - logger.Error().Err(zerr.ErrBadConfig).Dur("interval", config.Storage.GCInterval). + if storageConfig.GCInterval < 0 { + logger.Error().Err(zerr.ErrBadConfig).Dur("interval", storageConfig.GCInterval). Msg("invalid garbage-collect interval specified") return fmt.Errorf("%w: invalid garbage-collect interval specified %s", - zerr.ErrBadConfig, config.Storage.GCInterval) + zerr.ErrBadConfig, storageConfig.GCInterval) } - if !config.Storage.GC { - if config.Storage.GCDelay != 0 { + if !storageConfig.GC { + if storageConfig.GCDelay != 0 { logger.Warn().Err(zerr.ErrBadConfig). Msg("garbage-collect delay specified without enabling garbage-collect, will be ignored") } - if config.Storage.GCInterval != 0 { + if storageConfig.GCInterval != 0 { logger.Warn().Err(zerr.ErrBadConfig). Msg("periodic garbage-collect interval specified without enabling garbage-collect, will be ignored") } } - if err := validateGCRules(config.Storage.Retention, logger); err != nil { + if err := validateGCRules(storageConfig.Retention, logger); err != nil { return err } // subpaths - for name, subPath := range config.Storage.SubPaths { + for name, subPath := range storageConfig.SubPaths { if subPath.GC && subPath.GCDelay <= 0 { logger.Error().Err(zerr.ErrBadConfig). Str("subPath", name). @@ -1245,13 +1229,15 @@ func validateGCRules(retention config.ImageRetention, logger zlog.Logger) error func validateSync(config *config.Config, logger zlog.Logger) error { // check glob patterns in sync config are compilable - if config.Extensions != nil && config.Extensions.Sync != nil { - for regID, regCfg := range config.Extensions.Sync.Registries { + extensionsConfig := config.CopyExtensionsConfig() + // can't check with IsSyncEnabled(), because it can't test invalid sync configs + if extensionsConfig != nil && extensionsConfig.Sync != nil && len(extensionsConfig.Sync.Registries) > 0 { + for regID, regCfg := range extensionsConfig.Sync.Registries { // check retry options are configured for sync if regCfg.MaxRetries != nil && regCfg.RetryDelay == nil { msg := "retryDelay is required when using maxRetries" logger.Error().Err(zerr.ErrBadConfig).Int("id", regID).Interface("extensions.sync.registries[id]", - config.Extensions.Sync.Registries[regID]).Msg(msg) + extensionsConfig.Sync.Registries[regID]).Msg(msg) return fmt.Errorf("%w: %s", zerr.ErrBadConfig, msg) } @@ -1260,7 +1246,7 @@ func validateSync(config *config.Config, logger zlog.Logger) error { if regCfg.PreserveDigest && !config.IsCompatEnabled() { msg := "can not use PreserveDigest option without enabling http.Compat" logger.Error().Err(zerr.ErrBadConfig).Int("id", regID).Interface("extensions.sync.registries[id]", - config.Extensions.Sync.Registries[regID]).Msg(msg) + extensionsConfig.Sync.Registries[regID]).Msg(msg) return fmt.Errorf("%w: %s", zerr.ErrBadConfig, msg) } @@ -1314,8 +1300,9 @@ func validateSync(config *config.Config, logger zlog.Logger) error { } func validateClusterConfig(config *config.Config, logger zlog.Logger) error { - if config.Cluster != nil { - if len(config.Cluster.Members) == 0 { + clusterConfig := config.CopyClusterConfig() + if clusterConfig != nil { + if len(clusterConfig.Members) == 0 { msg := "cannot have 0 members in a scale out cluster" logger.Error().Err(zerr.ErrBadConfig).Msg(msg) @@ -1325,10 +1312,10 @@ func validateClusterConfig(config *config.Config, logger zlog.Logger) error { // the allowed length is 16 as the siphash requires a 128 bit key. // that translates to 16 characters * 8 bits each. allowedHashKeyLength := 16 - if len(config.Cluster.HashKey) != allowedHashKeyLength { + if len(clusterConfig.HashKey) != allowedHashKeyLength { msg := fmt.Sprintf("hashKey for scale out cluster must have %d characters", allowedHashKeyLength) logger.Error().Err(zerr.ErrBadConfig). - Str("hashkey", config.Cluster.HashKey). + Str("hashkey", clusterConfig.HashKey). Msg(msg) return fmt.Errorf("%w: %s", zerr.ErrBadConfig, msg) diff --git a/pkg/common/http_server.go b/pkg/common/http_server.go index 69b57fca..e29edf0d 100644 --- a/pkg/common/http_server.go +++ b/pkg/common/http_server.go @@ -38,7 +38,9 @@ func ACHeadersMiddleware(config *config.Config, allowedMethods ...string) mux.Mi resp.Header().Set("Access-Control-Allow-Methods", allowedMethodsValue) resp.Header().Set("Access-Control-Allow-Headers", "Authorization,content-type,"+constants.SessionClientHeaderName) - if config.IsBasicAuthnEnabled() { + // Get auth config safely + authConfig := config.CopyAuthConfig() + if authConfig.IsBasicAuthnEnabled() { resp.Header().Set("Access-Control-Allow-Credentials", "true") } @@ -73,23 +75,28 @@ func AddCORSHeaders(allowOrigin string, response http.ResponseWriter) { func AuthzOnlyAdminsMiddleware(conf *config.Config) mux.MiddlewareFunc { return func(next http.Handler) http.Handler { return http.HandlerFunc(func(response http.ResponseWriter, request *http.Request) { - if !conf.IsBasicAuthnEnabled() { + // Get auth config safely + authConfig := conf.CopyAuthConfig() + if !authConfig.IsBasicAuthnEnabled() { next.ServeHTTP(response, request) return } + realm := conf.GetRealm() + failDelay := authConfig.GetFailDelay() + // get userAccessControl built in previous authn/authz middlewares userAc, err := reqCtx.UserAcFromContext(request.Context()) if err != nil { // should not happen as this has been previously checked for errors - AuthzFail(response, request, userAc.GetUsername(), conf.HTTP.Realm, conf.HTTP.Auth.FailDelay) + AuthzFail(response, request, userAc.GetUsername(), realm, failDelay) return } // reject non-admin access if authentication is enabled if userAc != nil && !userAc.IsAdmin() { - AuthzFail(response, request, userAc.GetUsername(), conf.HTTP.Realm, conf.HTTP.Auth.FailDelay) + AuthzFail(response, request, userAc.GetUsername(), realm, failDelay) return } diff --git a/pkg/extensions/config/config.go b/pkg/extensions/config/config.go index 8f139433..20bf8b7f 100644 --- a/pkg/extensions/config/config.go +++ b/pkg/extensions/config/config.go @@ -46,12 +46,11 @@ type LintConfig struct { type SearchConfig struct { BaseConfig `mapstructure:",squash"` - // CVE search - CVE *CVEConfig + CVE *CVEConfig } type CVEConfig struct { - UpdateInterval time.Duration // should be 2 hours or more, if not specified default be kept as 24 hours + UpdateInterval time.Duration // should be 2 hours or more, if not specified default be kept as 2 hours Trivy *TrivyConfig } @@ -77,3 +76,166 @@ type ScrubConfig struct { type UIConfig struct { BaseConfig `mapstructure:",squash"` } + +// isSearchEnabledInternal checks if search is enabled (internal use only). +func (e *ExtensionConfig) isSearchEnabledInternal() bool { + return e != nil && e.Search != nil && e.Search.Enable != nil && *e.Search.Enable +} + +// isUIEnabledInternal checks if UI is enabled (internal use only). +func (e *ExtensionConfig) isUIEnabledInternal() bool { + return e != nil && e.UI != nil && e.UI.Enable != nil && *e.UI.Enable +} + +// IsCveScanningEnabled checks if CVE scanning is enabled in this extensions config. +func (e *ExtensionConfig) IsCveScanningEnabled() bool { + if e == nil { + return false + } + + return e.Search != nil && e.Search.Enable != nil && *e.Search.Enable && + e.Search.CVE != nil && e.Search.CVE.Trivy != nil +} + +// IsEventRecorderEnabled checks if event recording is enabled in this extensions config. +func (e *ExtensionConfig) IsEventRecorderEnabled() bool { + if e == nil { + return false + } + + return e.Events != nil && e.Events.Enable != nil && *e.Events.Enable +} + +// IsSearchEnabled checks if search is enabled in this extensions config. +func (e *ExtensionConfig) IsSearchEnabled() bool { + return e.isSearchEnabledInternal() +} + +// IsSyncEnabled checks if sync is enabled in this extensions config. +func (e *ExtensionConfig) IsSyncEnabled() bool { + if e == nil { + return false + } + + // Sync is enabled if either: + // 1. Explicitly enabled (Enable == true), OR + // 2. There are registries configured (enabled by default when registries exist) + // This matches the behavior in root.go where Sync.Enable defaults to true when registries are present + return e.Sync != nil && ((e.Sync.Enable != nil && *e.Sync.Enable) || len(e.Sync.Registries) > 0) +} + +// IsScrubEnabled checks if scrub is enabled in this extensions config. +func (e *ExtensionConfig) IsScrubEnabled() bool { + if e == nil { + return false + } + + return e.Scrub != nil && e.Scrub.Enable != nil && *e.Scrub.Enable +} + +// IsMetricsEnabled checks if metrics are enabled in this extensions config. +func (e *ExtensionConfig) IsMetricsEnabled() bool { + if e == nil { + return false + } + + return e.Metrics != nil && e.Metrics.Enable != nil && *e.Metrics.Enable +} + +// IsCosignEnabled checks if Cosign is enabled in this extensions config. +func (e *ExtensionConfig) IsCosignEnabled() bool { + if e == nil { + return false + } + + return e.Trust != nil && e.Trust.Enable != nil && *e.Trust.Enable && e.Trust.Cosign +} + +// IsNotationEnabled checks if Notation is enabled in this extensions config. +func (e *ExtensionConfig) IsNotationEnabled() bool { + if e == nil { + return false + } + + return e.Trust != nil && e.Trust.Enable != nil && *e.Trust.Enable && e.Trust.Notation +} + +// IsImageTrustEnabled checks if image trust is enabled in this extensions config. +func (e *ExtensionConfig) IsImageTrustEnabled() bool { + if e == nil { + return false + } + + return e.Trust != nil && e.Trust.Enable != nil && *e.Trust.Enable +} + +// IsUIEnabled checks if UI is enabled in this extensions config. +func (e *ExtensionConfig) IsUIEnabled() bool { + return e.isUIEnabledInternal() +} + +// AreUserPrefsEnabled checks if user preferences are enabled in this extensions config. +func (e *ExtensionConfig) AreUserPrefsEnabled() bool { + if e == nil { + return false + } + + return e.isSearchEnabledInternal() && e.isUIEnabledInternal() +} + +// GetSearchCVEConfig returns the search CVE config. +func (e *ExtensionConfig) GetSearchCVEConfig() *CVEConfig { + if e == nil { + return nil + } + + if e.Search != nil { + return e.Search.CVE + } + + return nil +} + +// GetScrubInterval returns the scrub interval. +func (e *ExtensionConfig) GetScrubInterval() time.Duration { + if e == nil { + return 0 + } + + if e.Scrub != nil { + return e.Scrub.Interval + } + + return 0 +} + +// GetSyncConfig returns the sync config. +func (e *ExtensionConfig) GetSyncConfig() *sync.Config { + if e == nil { + return nil + } + + return e.Sync +} + +// GetMetricsPrometheusConfig returns the metrics prometheus config. +func (e *ExtensionConfig) GetMetricsPrometheusConfig() *PrometheusConfig { + if e == nil { + return nil + } + + if e.Metrics != nil { + return e.Metrics.Prometheus + } + + return nil +} + +// GetEventsConfig returns the events config. +func (e *ExtensionConfig) GetEventsConfig() *events.Config { + if e == nil { + return nil + } + + return e.Events +} diff --git a/pkg/extensions/config/config_test.go b/pkg/extensions/config/config_test.go new file mode 100644 index 00000000..a43069dd --- /dev/null +++ b/pkg/extensions/config/config_test.go @@ -0,0 +1,564 @@ +package config_test + +import ( + "errors" + "testing" + + . "github.com/smartystreets/goconvey/convey" + + "zotregistry.dev/zot/v2/pkg/extensions/config" + "zotregistry.dev/zot/v2/pkg/extensions/config/events" + "zotregistry.dev/zot/v2/pkg/extensions/config/sync" +) + +var ( + errIsSearchEnabledExpectedTrue = errors.New("expected IsSearchEnabled to return true, got false") + errIsUIEnabledExpectedTrue = errors.New("expected IsUIEnabled to return true, got false") + errAreUserPrefsEnabledExpectedTrue = errors.New("expected AreUserPrefsEnabled to return true, got false") + errPanicRecovered = errors.New("panic recovered") +) + +// Config builder functions for different extension types. +func buildSearchConfig(enabled bool) *config.ExtensionConfig { + ext := &config.ExtensionConfig{} + ext.Search = &config.SearchConfig{ + BaseConfig: config.BaseConfig{ + Enable: &enabled, + }, + } + + return ext +} + +func buildSearchConfigWithCVE(enabled bool) *config.ExtensionConfig { + ext := &config.ExtensionConfig{} + ext.Search = &config.SearchConfig{ + BaseConfig: config.BaseConfig{ + Enable: &enabled, + }, + CVE: &config.CVEConfig{ + Trivy: &config.TrivyConfig{}, + }, + } + + return ext +} + +func buildEventsConfig(enabled bool) *config.ExtensionConfig { + ext := &config.ExtensionConfig{} + ext.Events = &events.Config{ + Enable: &enabled, + } + + return ext +} + +func buildSyncConfig(enabled bool) *config.ExtensionConfig { + ext := &config.ExtensionConfig{} + ext.Sync = &sync.Config{ + Enable: &enabled, + } + + return ext +} + +func buildScrubConfig(enabled bool) *config.ExtensionConfig { + ext := &config.ExtensionConfig{} + ext.Scrub = &config.ScrubConfig{ + BaseConfig: config.BaseConfig{ + Enable: &enabled, + }, + } + + return ext +} + +func buildMetricsConfig(enabled bool) *config.ExtensionConfig { + ext := &config.ExtensionConfig{} + ext.Metrics = &config.MetricsConfig{ + BaseConfig: config.BaseConfig{ + Enable: &enabled, + }, + Prometheus: &config.PrometheusConfig{ + Path: "/metrics", + }, + } + + return ext +} + +func buildTrustConfig(enabled bool) *config.ExtensionConfig { + ext := &config.ExtensionConfig{} + ext.Trust = &config.ImageTrustConfig{ + BaseConfig: config.BaseConfig{ + Enable: &enabled, + }, + } + + return ext +} + +func buildUIConfig(enabled bool) *config.ExtensionConfig { + ext := &config.ExtensionConfig{} + ext.UI = &config.UIConfig{ + BaseConfig: config.BaseConfig{ + Enable: &enabled, + }, + } + + return ext +} + +func buildSearchAndUIConfig(enabled bool) *config.ExtensionConfig { + ext := &config.ExtensionConfig{} + ext.Search = &config.SearchConfig{ + BaseConfig: config.BaseConfig{ + Enable: &enabled, + }, + } + ext.UI = &config.UIConfig{ + BaseConfig: config.BaseConfig{ + Enable: &enabled, + }, + } + + return ext +} + +func buildTrustConfigWithCosign(enabled bool) *config.ExtensionConfig { + ext := &config.ExtensionConfig{} + ext.Trust = &config.ImageTrustConfig{ + BaseConfig: config.BaseConfig{ + Enable: &enabled, + }, + Cosign: true, + } + + return ext +} + +func buildTrustConfigWithNotation(enabled bool) *config.ExtensionConfig { + ext := &config.ExtensionConfig{} + ext.Trust = &config.ImageTrustConfig{ + BaseConfig: config.BaseConfig{ + Enable: &enabled, + }, + Notation: true, + } + + return ext +} + +// Test helper functions to reduce code duplication + +// testMethodWithNilConfig tests a method with nil ExtensionConfig. +func testMethodWithNilConfig(testFunc func(*config.ExtensionConfig) bool) { + Convey("Test with nil ExtensionConfig", func() { + var extensionConfig *config.ExtensionConfig = nil + + So(testFunc(extensionConfig), ShouldBeFalse) + }) +} + +// testMethodWithNilSubConfig tests a method when ExtensionConfig exists but the relevant sub-config is nil. +func testMethodWithNilSubConfig(subConfigName string, testFunc func(*config.ExtensionConfig) bool) { + Convey("Test with ExtensionConfig but nil "+subConfigName, func() { + extensionConfig := &config.ExtensionConfig{} + + So(testFunc(extensionConfig), ShouldBeFalse) + }) +} + +// testMethodWithNilEnable tests a method when ExtensionConfig and sub-config exist but Enable is nil. +func testMethodWithNilEnable(subConfigName string, testFunc func(*config.ExtensionConfig) bool) { + Convey("Test with ExtensionConfig and "+subConfigName+" but nil Enable", func() { + extensionConfig := &config.ExtensionConfig{} + + So(testFunc(extensionConfig), ShouldBeFalse) + }) +} + +// testMethodWithDisabledEnable tests a method when Enable is explicitly set to false. +func testMethodWithDisabledEnable( + subConfigName string, + testFunc func(*config.ExtensionConfig) bool, + configBuilder func(bool) *config.ExtensionConfig, +) { + Convey("Test with ExtensionConfig and "+subConfigName+" and Enable but disabled", func() { + disabled := false + extensionConfig := configBuilder(disabled) + So(testFunc(extensionConfig), ShouldBeFalse) + }) +} + +// testMethodWithEnabledEnable tests a method when Enable is explicitly set to true. +func testMethodWithEnabledEnable( + subConfigName string, + testFunc func(*config.ExtensionConfig) bool, + configBuilder func(bool) *config.ExtensionConfig, +) { + Convey("Test with ExtensionConfig and "+subConfigName+" and Enable enabled", func() { + enabled := true + extensionConfig := configBuilder(enabled) + So(testFunc(extensionConfig), ShouldBeTrue) + }) +} + +// testConcurrentAccessWithConfig tests concurrent access to a method with a properly configured ExtensionConfig. +func testConcurrentAccessWithConfig( + methodName string, + testFunc func(*config.ExtensionConfig) bool, + expectedError error, + extensionConfig *config.ExtensionConfig, +) { + Convey("Test concurrent access to "+methodName, func() { + // Test concurrent access to verify thread-safety + done := make(chan bool, 10) + errors := make(chan error, 10) + + for i := 0; i < 10; i++ { + go func() { + defer func() { + if r := recover(); r != nil { + if err, ok := r.(error); ok { + errors <- err + } else { + errors <- errPanicRecovered + } + } + done <- true + }() + + for j := 0; j < 100; j++ { + result := testFunc(extensionConfig) + if !result { + errors <- expectedError + + return + } + } + }() + } + + // Wait for all goroutines to complete + for i := 0; i < 10; i++ { + <-done + } + + // Check for errors + close(errors) + + for err := range errors { + So(err, ShouldBeNil) + } + }) +} + +// testGetterWithNilConfig tests a getter method with nil ExtensionConfig. +func testGetterWithNilConfig[T any](testFunc func(*config.ExtensionConfig) T, expected T) { + Convey("Test with nil ExtensionConfig", func() { + var extensionConfig *config.ExtensionConfig = nil + + result := testFunc(extensionConfig) + So(result, ShouldEqual, expected) + }) +} + +// testGetterWithNilSubConfig tests a getter method when ExtensionConfig exists but the relevant sub-config is nil. +func testGetterWithNilSubConfig[T any](subConfigName string, testFunc func(*config.ExtensionConfig) T, expected T) { + Convey("Test with ExtensionConfig but nil "+subConfigName, func() { + extensionConfig := &config.ExtensionConfig{} + + result := testFunc(extensionConfig) + So(result, ShouldEqual, expected) + }) +} + +// testGetterWithValidConfig tests a getter method with valid configuration. +func testGetterWithValidConfig[T any]( + subConfigName string, + testFunc func(*config.ExtensionConfig) T, + configBuilder func(bool) *config.ExtensionConfig, +) { + Convey("Test with valid "+subConfigName+" configuration", func() { + enabled := true + extensionConfig := configBuilder(enabled) + + result := testFunc(extensionConfig) + So(result, ShouldNotBeNil) + }) +} + +func TestExtensionConfig(t *testing.T) { + Convey("Test public methods", t, func() { + Convey("Test IsCveScanningEnabled()", func() { + testMethodWithNilConfig((*config.ExtensionConfig).IsCveScanningEnabled) + testMethodWithNilSubConfig("Search", (*config.ExtensionConfig).IsCveScanningEnabled) + testMethodWithNilEnable("Search", (*config.ExtensionConfig).IsCveScanningEnabled) + testMethodWithDisabledEnable("Search", (*config.ExtensionConfig).IsCveScanningEnabled, buildSearchConfig) + testMethodWithEnabledEnable("Search", (*config.ExtensionConfig).IsCveScanningEnabled, buildSearchConfigWithCVE) + }) + + Convey("Test IsEventRecorderEnabled()", func() { + testMethodWithNilConfig((*config.ExtensionConfig).IsEventRecorderEnabled) + testMethodWithNilSubConfig("Events", (*config.ExtensionConfig).IsEventRecorderEnabled) + testMethodWithNilEnable("Events", (*config.ExtensionConfig).IsEventRecorderEnabled) + testMethodWithDisabledEnable("Events", (*config.ExtensionConfig).IsEventRecorderEnabled, buildEventsConfig) + testMethodWithEnabledEnable("Events", (*config.ExtensionConfig).IsEventRecorderEnabled, buildEventsConfig) + }) + + Convey("Test IsSearchEnabled()", func() { + testMethodWithNilConfig((*config.ExtensionConfig).IsSearchEnabled) + testMethodWithNilSubConfig("Search", (*config.ExtensionConfig).IsSearchEnabled) + testMethodWithNilEnable("Search", (*config.ExtensionConfig).IsSearchEnabled) + testMethodWithDisabledEnable("Search", (*config.ExtensionConfig).IsSearchEnabled, buildSearchConfig) + testMethodWithEnabledEnable("Search", (*config.ExtensionConfig).IsSearchEnabled, buildSearchConfig) + }) + + Convey("Test IsSyncEnabled()", func() { + testMethodWithNilConfig((*config.ExtensionConfig).IsSyncEnabled) + testMethodWithNilSubConfig("Sync", (*config.ExtensionConfig).IsSyncEnabled) + testMethodWithNilEnable("Sync", (*config.ExtensionConfig).IsSyncEnabled) + testMethodWithDisabledEnable("Sync", (*config.ExtensionConfig).IsSyncEnabled, buildSyncConfig) + testMethodWithEnabledEnable("Sync", (*config.ExtensionConfig).IsSyncEnabled, buildSyncConfig) + }) + + Convey("Test IsScrubEnabled()", func() { + testMethodWithNilConfig((*config.ExtensionConfig).IsScrubEnabled) + testMethodWithNilSubConfig("Scrub", (*config.ExtensionConfig).IsScrubEnabled) + testMethodWithNilEnable("Scrub", (*config.ExtensionConfig).IsScrubEnabled) + testMethodWithDisabledEnable("Scrub", (*config.ExtensionConfig).IsScrubEnabled, buildScrubConfig) + testMethodWithEnabledEnable("Scrub", (*config.ExtensionConfig).IsScrubEnabled, buildScrubConfig) + }) + + Convey("Test IsMetricsEnabled()", func() { + testMethodWithNilConfig((*config.ExtensionConfig).IsMetricsEnabled) + testMethodWithNilSubConfig("Metrics", (*config.ExtensionConfig).IsMetricsEnabled) + testMethodWithNilEnable("Metrics", (*config.ExtensionConfig).IsMetricsEnabled) + testMethodWithDisabledEnable("Metrics", (*config.ExtensionConfig).IsMetricsEnabled, buildMetricsConfig) + testMethodWithEnabledEnable("Metrics", (*config.ExtensionConfig).IsMetricsEnabled, buildMetricsConfig) + }) + + Convey("Test IsCosignEnabled()", func() { + testMethodWithNilConfig((*config.ExtensionConfig).IsCosignEnabled) + testMethodWithNilSubConfig("Trust", (*config.ExtensionConfig).IsCosignEnabled) + testMethodWithNilEnable("Trust", (*config.ExtensionConfig).IsCosignEnabled) + testMethodWithDisabledEnable("Trust", (*config.ExtensionConfig).IsCosignEnabled, buildTrustConfig) + testMethodWithEnabledEnable("Trust", (*config.ExtensionConfig).IsCosignEnabled, buildTrustConfigWithCosign) + }) + + Convey("Test IsNotationEnabled()", func() { + testMethodWithNilConfig((*config.ExtensionConfig).IsNotationEnabled) + testMethodWithNilSubConfig("Trust", (*config.ExtensionConfig).IsNotationEnabled) + testMethodWithNilEnable("Trust", (*config.ExtensionConfig).IsNotationEnabled) + testMethodWithDisabledEnable("Trust", (*config.ExtensionConfig).IsNotationEnabled, buildTrustConfig) + testMethodWithEnabledEnable("Trust", (*config.ExtensionConfig).IsNotationEnabled, buildTrustConfigWithNotation) + }) + + Convey("Test IsImageTrustEnabled()", func() { + testMethodWithNilConfig((*config.ExtensionConfig).IsImageTrustEnabled) + testMethodWithNilSubConfig("Trust", (*config.ExtensionConfig).IsImageTrustEnabled) + testMethodWithNilEnable("Trust", (*config.ExtensionConfig).IsImageTrustEnabled) + testMethodWithDisabledEnable("Trust", (*config.ExtensionConfig).IsImageTrustEnabled, buildTrustConfig) + testMethodWithEnabledEnable("Trust", (*config.ExtensionConfig).IsImageTrustEnabled, buildTrustConfig) + }) + + Convey("Test IsUIEnabled()", func() { + testMethodWithNilConfig((*config.ExtensionConfig).IsUIEnabled) + testMethodWithNilSubConfig("UI", (*config.ExtensionConfig).IsUIEnabled) + testMethodWithNilEnable("UI", (*config.ExtensionConfig).IsUIEnabled) + testMethodWithDisabledEnable("UI", (*config.ExtensionConfig).IsUIEnabled, buildUIConfig) + testMethodWithEnabledEnable("UI", (*config.ExtensionConfig).IsUIEnabled, buildUIConfig) + }) + + Convey("Test AreUserPrefsEnabled()", func() { + testMethodWithNilConfig((*config.ExtensionConfig).AreUserPrefsEnabled) + testMethodWithNilSubConfig("Search", (*config.ExtensionConfig).AreUserPrefsEnabled) + testMethodWithNilEnable("UI", (*config.ExtensionConfig).AreUserPrefsEnabled) + testMethodWithDisabledEnable("Search", (*config.ExtensionConfig).AreUserPrefsEnabled, buildSearchConfig) + testMethodWithEnabledEnable("Search", (*config.ExtensionConfig).AreUserPrefsEnabled, buildSearchAndUIConfig) + }) + }) + + // Additional tests to verify thread-safety and internal method behavior + Convey("Test thread-safety and internal method coverage", t, func() { + // Create properly configured ExtensionConfigs for concurrent testing + searchEnabled := true + uiEnabled := true + searchConfig := &config.ExtensionConfig{} + searchConfig.Search = &config.SearchConfig{ + BaseConfig: config.BaseConfig{ + Enable: &searchEnabled, + }, + } + uiConfig := &config.ExtensionConfig{} + uiConfig.UI = &config.UIConfig{ + BaseConfig: config.BaseConfig{ + Enable: &uiEnabled, + }, + } + searchAndUIConfig := &config.ExtensionConfig{} + searchAndUIConfig.Search = &config.SearchConfig{ + BaseConfig: config.BaseConfig{ + Enable: &searchEnabled, + }, + } + searchAndUIConfig.UI = &config.UIConfig{ + BaseConfig: config.BaseConfig{ + Enable: &uiEnabled, + }, + } + + testConcurrentAccessWithConfig( + "IsSearchEnabled", + (*config.ExtensionConfig).IsSearchEnabled, + errIsSearchEnabledExpectedTrue, + searchConfig, + ) + testConcurrentAccessWithConfig( + "IsUIEnabled", + (*config.ExtensionConfig).IsUIEnabled, + errIsUIEnabledExpectedTrue, + uiConfig, + ) + testConcurrentAccessWithConfig( + "AreUserPrefsEnabled", + (*config.ExtensionConfig).AreUserPrefsEnabled, + errAreUserPrefsEnabledExpectedTrue, + searchAndUIConfig, + ) + + Convey("Test mixed concurrent access to all methods", func() { + searchEnabled := true + uiEnabled := true + extensionConfig := &config.ExtensionConfig{} + extensionConfig.Search = &config.SearchConfig{ + BaseConfig: config.BaseConfig{ + Enable: &searchEnabled, + }, + } + extensionConfig.UI = &config.UIConfig{ + BaseConfig: config.BaseConfig{ + Enable: &uiEnabled, + }, + } + + // Test mixed concurrent access to verify thread-safety across all methods + done := make(chan bool, 15) + errors := make(chan error, 15) + + // Launch goroutines for each method + for i := 0; i < 5; i++ { + go func() { + defer func() { + if r := recover(); r != nil { + if err, ok := r.(error); ok { + errors <- err + } else { + errors <- errPanicRecovered + } + } + done <- true + }() + + for j := 0; j < 50; j++ { + result := extensionConfig.IsSearchEnabled() + if !result { + errors <- errIsSearchEnabledExpectedTrue + + return + } + } + }() + } + + for i := 0; i < 5; i++ { + go func() { + defer func() { + if r := recover(); r != nil { + if err, ok := r.(error); ok { + errors <- err + } else { + errors <- errPanicRecovered + } + } + done <- true + }() + + for j := 0; j < 50; j++ { + result := extensionConfig.IsUIEnabled() + if !result { + errors <- errIsUIEnabledExpectedTrue + + return + } + } + }() + } + + for i := 0; i < 5; i++ { + go func() { + defer func() { + if r := recover(); r != nil { + if err, ok := r.(error); ok { + errors <- err + } else { + errors <- errPanicRecovered + } + } + done <- true + }() + + for j := 0; j < 50; j++ { + result := extensionConfig.AreUserPrefsEnabled() + if !result { + errors <- errAreUserPrefsEnabledExpectedTrue + + return + } + } + }() + } + + // Wait for all goroutines to complete + for i := 0; i < 15; i++ { + <-done + } + + // Check for errors + close(errors) + + for err := range errors { + So(err, ShouldBeNil) + } + }) + + Convey("Test GetSearchCVEConfig()", func() { + testGetterWithNilConfig((*config.ExtensionConfig).GetSearchCVEConfig, nil) + testGetterWithNilSubConfig("Search", (*config.ExtensionConfig).GetSearchCVEConfig, nil) + testGetterWithValidConfig("Search", (*config.ExtensionConfig).GetSearchCVEConfig, buildSearchConfigWithCVE) + }) + + Convey("Test GetScrubInterval()", func() { + testGetterWithNilConfig((*config.ExtensionConfig).GetScrubInterval, 0) + testGetterWithNilSubConfig("Scrub", (*config.ExtensionConfig).GetScrubInterval, 0) + testGetterWithValidConfig("Scrub", (*config.ExtensionConfig).GetScrubInterval, buildScrubConfig) + }) + + Convey("Test GetSyncConfig()", func() { + testGetterWithNilConfig((*config.ExtensionConfig).GetSyncConfig, nil) + testGetterWithValidConfig("Sync", (*config.ExtensionConfig).GetSyncConfig, buildSyncConfig) + }) + + Convey("Test GetMetricsPrometheusConfig()", func() { + testGetterWithNilConfig((*config.ExtensionConfig).GetMetricsPrometheusConfig, nil) + testGetterWithNilSubConfig("Metrics", (*config.ExtensionConfig).GetMetricsPrometheusConfig, nil) + testGetterWithValidConfig("Metrics", (*config.ExtensionConfig).GetMetricsPrometheusConfig, buildMetricsConfig) + }) + + Convey("Test GetEventsConfig()", func() { + testGetterWithNilConfig((*config.ExtensionConfig).GetEventsConfig, nil) + testGetterWithValidConfig("Events", (*config.ExtensionConfig).GetEventsConfig, buildEventsConfig) + }) + }) +} diff --git a/pkg/extensions/extension_events.go b/pkg/extensions/extension_events.go index fffca7b0..4e58ec95 100644 --- a/pkg/extensions/extension_events.go +++ b/pkg/extensions/extension_events.go @@ -12,15 +12,16 @@ import ( ) func NewEventRecorder(config *config.Config, log log.Logger) (events.Recorder, error) { - if !config.IsEventRecorderEnabled() { + // Get extensions config safely + extensionsConfig := config.CopyExtensionsConfig() + if !extensionsConfig.IsEventRecorderEnabled() { log.Info().Msg("events disabled in configuration") return nil, zerr.ErrExtensionNotEnabled } - eventConfig := config.Extensions.Events - - if eventConfig.Sinks == nil || len(eventConfig.Sinks) == 0 { + eventConfig := extensionsConfig.GetEventsConfig() + if eventConfig == nil || eventConfig.Sinks == nil || len(eventConfig.Sinks) == 0 { log.Info().Msg("no sinks provided, skipping events extension setup") return nil, zerr.ErrExtensionNotEnabled diff --git a/pkg/extensions/extension_events_disabled.go b/pkg/extensions/extension_events_disabled.go index 9c5bbf06..81d73d14 100644 --- a/pkg/extensions/extension_events_disabled.go +++ b/pkg/extensions/extension_events_disabled.go @@ -11,7 +11,9 @@ import ( ) func NewEventRecorder(config *config.Config, log log.Logger) (events.Recorder, error) { - if !config.IsEventRecorderEnabled() { + // Get extensions config safely + extensionsConfig := config.CopyExtensionsConfig() + if !extensionsConfig.IsEventRecorderEnabled() { log.Info().Msg("events disabled in configuration") return nil, zerr.ErrExtensionNotEnabled diff --git a/pkg/extensions/extension_image_trust.go b/pkg/extensions/extension_image_trust.go index 6713edf4..1301fe46 100644 --- a/pkg/extensions/extension_image_trust.go +++ b/pkg/extensions/extension_image_trust.go @@ -27,7 +27,9 @@ func IsBuiltWithImageTrustExtension() bool { } func SetupImageTrustRoutes(conf *config.Config, router *mux.Router, metaDB mTypes.MetaDB, log log.Logger) { - if !conf.IsImageTrustEnabled() || (!conf.IsCosignEnabled() && !conf.IsNotationEnabled()) { + extensionsConfig := conf.CopyExtensionsConfig() + if !extensionsConfig.IsImageTrustEnabled() || + (!extensionsConfig.IsCosignEnabled() && !extensionsConfig.IsNotationEnabled()) { log.Info().Msg("skip enabling the image trust routes as the config prerequisites are not met") return @@ -39,7 +41,7 @@ func SetupImageTrustRoutes(conf *config.Config, router *mux.Router, metaDB mType trust := ImageTrust{Conf: conf, ImageTrustStore: imgTrustStore, Log: log} allowedMethods := zcommon.AllowedMethods(http.MethodPost) - if conf.IsNotationEnabled() { + if extensionsConfig.IsNotationEnabled() { log.Info().Msg("setting up notation route") notationRouter := router.PathPrefix(constants.ExtNotation).Subrouter() @@ -51,7 +53,7 @@ func SetupImageTrustRoutes(conf *config.Config, router *mux.Router, metaDB mType notationRouter.Methods(allowedMethods...).HandlerFunc(trust.HandleNotationCertificateUpload) } - if conf.IsCosignEnabled() { + if extensionsConfig.IsCosignEnabled() { log.Info().Msg("setting up cosign route") cosignRouter := router.PathPrefix(constants.ExtCosign).Subrouter() @@ -153,7 +155,8 @@ func (trust *ImageTrust) HandleNotationCertificateUpload(response http.ResponseW func EnableImageTrustVerification(conf *config.Config, taskScheduler *scheduler.Scheduler, metaDB mTypes.MetaDB, log log.Logger, ) { - if !conf.IsImageTrustEnabled() { + extensionsConfig := conf.CopyExtensionsConfig() + if !extensionsConfig.IsImageTrustEnabled() { return } @@ -165,7 +168,8 @@ func EnableImageTrustVerification(conf *config.Config, taskScheduler *scheduler. } func SetupImageTrustExtension(conf *config.Config, metaDB mTypes.MetaDB, log log.Logger) error { - if !conf.IsImageTrustEnabled() { + extensionsConfig := conf.CopyExtensionsConfig() + if !extensionsConfig.IsImageTrustEnabled() { return nil } diff --git a/pkg/extensions/extension_metrics.go b/pkg/extensions/extension_metrics.go index 64beb284..8de32158 100644 --- a/pkg/extensions/extension_metrics.go +++ b/pkg/extensions/extension_metrics.go @@ -13,13 +13,10 @@ import ( ) func EnableMetricsExtension(config *config.Config, log log.Logger, rootDir string) { - if config.IsMetricsEnabled() && - config.Extensions.Metrics.Prometheus != nil { - if config.Extensions.Metrics.Prometheus.Path == "" { - config.Extensions.Metrics.Prometheus.Path = "/metrics" - - log.Warn().Msg("prometheus instrumentation path not set, changing to '/metrics'.") - } + // Get extensions config safely + extensionsConfig := config.CopyExtensionsConfig() + if extensionsConfig.IsMetricsEnabled() { + log.Info().Msg("metrics extension enabled") } else { log.Info().Msg("metrics config not provided, skipping metrics config update") } @@ -30,10 +27,15 @@ func SetupMetricsRoutes(config *config.Config, router *mux.Router, ) { log.Info().Msg("setting up metrics routes") - if config.IsMetricsEnabled() { - extRouter := router.PathPrefix(config.Extensions.Metrics.Prometheus.Path).Subrouter() - extRouter.Use(authnFunc) - extRouter.Use(authzFunc) - extRouter.Methods("GET").Handler(promhttp.Handler()) + // Get extensions config safely + extensionsConfig := config.CopyExtensionsConfig() + if extensionsConfig.IsMetricsEnabled() { + prometheusConfig := extensionsConfig.GetMetricsPrometheusConfig() + if prometheusConfig != nil { + extRouter := router.PathPrefix(prometheusConfig.Path).Subrouter() + extRouter.Use(authnFunc) + extRouter.Use(authzFunc) + extRouter.Methods("GET").Handler(promhttp.Handler()) + } } } diff --git a/pkg/extensions/extension_mgmt.go b/pkg/extensions/extension_mgmt.go index 3da8afb6..db99a87c 100644 --- a/pkg/extensions/extension_mgmt.go +++ b/pkg/extensions/extension_mgmt.go @@ -85,7 +85,8 @@ func (auth Auth) MarshalJSON() ([]byte, error) { } func SetupMgmtRoutes(conf *config.Config, router *mux.Router, log log.Logger) { - if !conf.IsMgmtEnabled() { + extensionsConfig := conf.CopyExtensionsConfig() + if !extensionsConfig.IsSearchEnabled() { log.Info().Msg("skip enabling the mgmt route as the config prerequisites are not met") return diff --git a/pkg/extensions/extension_scrub.go b/pkg/extensions/extension_scrub.go index 6998eee6..9811d451 100644 --- a/pkg/extensions/extension_scrub.go +++ b/pkg/extensions/extension_scrub.go @@ -4,8 +4,6 @@ package extensions import ( - "time" - "zotregistry.dev/zot/v2/pkg/api/config" "zotregistry.dev/zot/v2/pkg/extensions/scrub" "zotregistry.dev/zot/v2/pkg/log" @@ -18,15 +16,10 @@ import ( func EnableScrubExtension(config *config.Config, log log.Logger, storeController storage.StoreController, sch *scheduler.Scheduler, ) { - if config.Extensions.Scrub != nil && - *config.Extensions.Scrub.Enable { - minScrubInterval, _ := time.ParseDuration("2h") - - if config.Extensions.Scrub.Interval < minScrubInterval { - config.Extensions.Scrub.Interval = minScrubInterval - - log.Warn().Msg("scrub interval set to too-short interval < 2h, changing scrub duration to 2 hours and continuing.") //nolint:lll // gofumpt conflicts with lll - } + // Get extensions config safely + extensionsConfig := config.CopyExtensionsConfig() + if extensionsConfig.IsScrubEnabled() { + scrubInterval := extensionsConfig.GetScrubInterval() processedRepos := make(map[string]struct{}) @@ -36,10 +29,12 @@ func EnableScrubExtension(config *config.Config, log log.Logger, storeController processedRepos: processedRepos, } - sch.SubmitGenerator(generator, config.Extensions.Scrub.Interval, scheduler.LowPriority) + sch.SubmitGenerator(generator, scrubInterval, scheduler.LowPriority) - if config.Storage.SubPaths != nil { - for route := range config.Storage.SubPaths { + // Get storage config safely + storageConfig := config.CopyStorageConfig() + if storageConfig.SubPaths != nil { + for route := range storageConfig.SubPaths { processedRepos := make(map[string]struct{}) generator := &taskGenerator{ @@ -48,7 +43,7 @@ func EnableScrubExtension(config *config.Config, log log.Logger, storeController processedRepos: processedRepos, } - sch.SubmitGenerator(generator, config.Extensions.Scrub.Interval, scheduler.LowPriority) + sch.SubmitGenerator(generator, scrubInterval, scheduler.LowPriority) } } } else { diff --git a/pkg/extensions/extension_search.go b/pkg/extensions/extension_search.go index d45d22c0..8c5458ae 100644 --- a/pkg/extensions/extension_search.go +++ b/pkg/extensions/extension_search.go @@ -33,12 +33,15 @@ func IsBuiltWithSearchExtension() bool { func GetCveScanner(conf *config.Config, storeController storage.StoreController, metaDB mTypes.MetaDB, log log.Logger, ) CveScanner { - if !conf.IsCveScanningEnabled() { + // Get extensions config safely + extensionsConfig := conf.CopyExtensionsConfig() + if !extensionsConfig.IsCveScanningEnabled() { return nil } - dbRepository := conf.Extensions.Search.CVE.Trivy.DBRepository - javaDBRepository := conf.Extensions.Search.CVE.Trivy.JavaDBRepository + cveConfig := extensionsConfig.GetSearchCVEConfig() + dbRepository := cveConfig.Trivy.DBRepository + javaDBRepository := cveConfig.Trivy.JavaDBRepository return cveinfo.NewScanner(storeController, metaDB, dbRepository, javaDBRepository, log) } @@ -46,8 +49,11 @@ func GetCveScanner(conf *config.Config, storeController storage.StoreController, func EnableSearchExtension(conf *config.Config, storeController storage.StoreController, metaDB mTypes.MetaDB, taskScheduler *scheduler.Scheduler, cveScanner CveScanner, log log.Logger, ) { - if conf.IsCveScanningEnabled() { - updateInterval := conf.Extensions.Search.CVE.UpdateInterval + // Get extensions config safely + extensionsConfig := conf.CopyExtensionsConfig() + if extensionsConfig.IsCveScanningEnabled() { + cveConfig := extensionsConfig.GetSearchCVEConfig() + updateInterval := cveConfig.UpdateInterval downloadTrivyDB(updateInterval, taskScheduler, cveScanner, log) startScanner(scanInterval, metaDB, taskScheduler, cveScanner, log) @@ -75,7 +81,8 @@ func startScanner(interval time.Duration, metaDB mTypes.MetaDB, sch *scheduler.S func SetupSearchRoutes(conf *config.Config, router *mux.Router, storeController storage.StoreController, metaDB mTypes.MetaDB, cveScanner CveScanner, log log.Logger, ) { - if !conf.IsSearchEnabled() { + extensionsConfig := conf.CopyExtensionsConfig() + if !extensionsConfig.IsSearchEnabled() { log.Info().Msg("skip enabling the search route as the config prerequisites are not met") return @@ -84,7 +91,7 @@ func SetupSearchRoutes(conf *config.Config, router *mux.Router, storeController log.Info().Msg("setting up search routes") var cveInfo cveinfo.CveInfo - if conf.IsCveScanningEnabled() { + if extensionsConfig.IsCveScanningEnabled() { cveInfo = cveinfo.NewCVEInfo(cveScanner, metaDB, log) } else { cveInfo = nil diff --git a/pkg/extensions/extension_sync.go b/pkg/extensions/extension_sync.go index 17619b39..9ffaf29f 100644 --- a/pkg/extensions/extension_sync.go +++ b/pkg/extensions/extension_sync.go @@ -21,13 +21,18 @@ import ( func EnableSyncExtension(config *config.Config, metaDB mTypes.MetaDB, storeController storage.StoreController, sch *scheduler.Scheduler, log log.Logger, ) (*sync.BaseOnDemand, error) { - if config.Extensions.Sync != nil && *config.Extensions.Sync.Enable { - onDemand := sync.NewOnDemand(log) + // Get extensions config safely + extensionsConfig := config.CopyExtensionsConfig() + httpAddress := config.GetHTTPAddress() + httpPort := config.GetHTTPPort() - for _, registryConfig := range config.Extensions.Sync.Registries { - registryConfig := registryConfig + if extensionsConfig.IsSyncEnabled() { + onDemand := sync.NewOnDemand(log) + syncConfig := extensionsConfig.GetSyncConfig() + + for _, registryConfig := range syncConfig.Registries { if len(registryConfig.URLs) > 1 { - if err := removeSelfURLs(config, ®istryConfig, log); err != nil { + if err := removeSelfURLs(httpAddress, httpPort, ®istryConfig, log); err != nil { return nil, err } } @@ -45,11 +50,12 @@ func EnableSyncExtension(config *config.Config, metaDB mTypes.MetaDB, continue } - tmpDir := config.Extensions.Sync.DownloadDir - credsPath := config.Extensions.Sync.CredentialsFile - clusterCfg := config.Cluster + tmpDir := syncConfig.DownloadDir + credsPath := syncConfig.CredentialsFile + // Get cluster config safely + clusterConfig := config.CopyClusterConfig() - service, err := sync.New(registryConfig, credsPath, clusterCfg, tmpDir, storeController, metaDB, log) + service, err := sync.New(registryConfig, credsPath, clusterConfig, tmpDir, storeController, metaDB, log) if err != nil { log.Error().Err(err).Msg("failed to initialize sync extension") @@ -102,10 +108,9 @@ func getLocalIPs() ([]string, error) { return localIPs, nil } -func removeSelfURLs(config *config.Config, registryConfig *syncconf.RegistryConfig, log log.Logger) error { +func removeSelfURLs(httpAddress, httpPort string, registryConfig *syncconf.RegistryConfig, log log.Logger) error { // get IP from config - port := config.HTTP.Port - selfAddress := net.JoinHostPort(config.HTTP.Address, port) + selfAddress := net.JoinHostPort(httpAddress, httpPort) // get all local IPs from interfaces localIPs, err := getLocalIPs() @@ -148,8 +153,8 @@ func removeSelfURLs(config *config.Config, registryConfig *syncconf.RegistryConf for _, localIP := range localIPs { // if ip resolved from hostname/dns is equal with any local ip for _, ip := range ips { - if (ip.IsLoopback() && (url.Port() == port)) || - (net.JoinHostPort(ip.String(), url.Port()) == net.JoinHostPort(localIP, port)) { + if (ip.IsLoopback() && (url.Port() == httpPort)) || + (net.JoinHostPort(ip.String(), url.Port()) == net.JoinHostPort(localIP, httpPort)) { registryConfig.URLs = append(registryConfig.URLs[:idx], registryConfig.URLs[idx+1:]...) removed = true diff --git a/pkg/extensions/extension_ui.go b/pkg/extensions/extension_ui.go index a4aa2741..cf4880e7 100644 --- a/pkg/extensions/extension_ui.go +++ b/pkg/extensions/extension_ui.go @@ -63,7 +63,8 @@ func addUISecurityHeaders(h http.Handler) http.HandlerFunc { //nolint:varnamelen func SetupUIRoutes(conf *config.Config, router *mux.Router, log log.Logger, ) { - if !conf.IsUIEnabled() { + extensionsConfig := conf.CopyExtensionsConfig() + if !extensionsConfig.IsUIEnabled() { log.Info().Msg("skip enabling the ui route as the config prerequisites are not met") return diff --git a/pkg/extensions/extension_userprefs.go b/pkg/extensions/extension_userprefs.go index 7823e135..bad09662 100644 --- a/pkg/extensions/extension_userprefs.go +++ b/pkg/extensions/extension_userprefs.go @@ -29,7 +29,8 @@ func IsBuiltWithUserPrefsExtension() bool { func SetupUserPreferencesRoutes(conf *config.Config, router *mux.Router, metaDB mTypes.MetaDB, log log.Logger, ) { - if !conf.AreUserPrefsEnabled() { + extensionsConfig := conf.CopyExtensionsConfig() + if !extensionsConfig.AreUserPrefsEnabled() { log.Info().Msg("skip enabling the user preferences route as the config prerequisites are not met") return diff --git a/pkg/extensions/extensions_test.go b/pkg/extensions/extensions_test.go index 401f600d..e0c7c667 100644 --- a/pkg/extensions/extensions_test.go +++ b/pkg/extensions/extensions_test.go @@ -107,8 +107,7 @@ func TestMetricsExtension(t *testing.T) { data, _ := os.ReadFile(logFile.Name()) - So(string(data), ShouldContainSubstring, - "prometheus instrumentation path not set, changing to '/metrics'.") + So(string(data), ShouldContainSubstring, "metrics extension enabled") }) } diff --git a/pkg/extensions/get_extensions.go b/pkg/extensions/get_extensions.go index 1fe54c43..f59894af 100644 --- a/pkg/extensions/get_extensions.go +++ b/pkg/extensions/get_extensions.go @@ -16,23 +16,24 @@ func GetExtensions(config *config.Config) distext.ExtensionList { endpoints := []string{} extensions := []distext.Extension{} - if config.IsNotationEnabled() && IsBuiltWithImageTrustExtension() { + extensionsConfig := config.CopyExtensionsConfig() + if extensionsConfig.IsNotationEnabled() && IsBuiltWithImageTrustExtension() { endpoints = append(endpoints, constants.FullNotation) } - if config.IsCosignEnabled() && IsBuiltWithImageTrustExtension() { + if extensionsConfig.IsCosignEnabled() && IsBuiltWithImageTrustExtension() { endpoints = append(endpoints, constants.FullCosign) } - if config.IsSearchEnabled() && IsBuiltWithSearchExtension() { + if extensionsConfig.IsSearchEnabled() && IsBuiltWithSearchExtension() { endpoints = append(endpoints, constants.FullSearchPrefix) } - if config.AreUserPrefsEnabled() && IsBuiltWithUserPrefsExtension() { + if extensionsConfig.AreUserPrefsEnabled() && IsBuiltWithUserPrefsExtension() { endpoints = append(endpoints, constants.FullUserPrefs) } - if config.IsMgmtEnabled() && IsBuiltWithMGMTExtension() { + if extensionsConfig.IsSearchEnabled() && IsBuiltWithMGMTExtension() { endpoints = append(endpoints, constants.FullMgmt) } diff --git a/pkg/extensions/monitoring/common.go b/pkg/extensions/monitoring/common.go index 5968e43f..e9be85d4 100644 --- a/pkg/extensions/monitoring/common.go +++ b/pkg/extensions/monitoring/common.go @@ -14,6 +14,8 @@ type MetricServer interface { ForceSendMetric(interface{}) ReceiveMetrics() interface{} IsEnabled() bool + // Stop gracefully shuts down the metrics server + Stop() } func GetDirSize(path string) (int64, error) { diff --git a/pkg/extensions/monitoring/extension.go b/pkg/extensions/monitoring/extension.go index b05d1f6c..d9a9b66f 100644 --- a/pkg/extensions/monitoring/extension.go +++ b/pkg/extensions/monitoring/extension.go @@ -137,6 +137,11 @@ type metricServer struct { log log.Logger } +// Stop gracefully shuts down the metrics server (no-op for this implementation). +func (ms *metricServer) Stop() { + // This is a no-op implementation for the disabled metrics server +} + func GetDefaultBuckets() []float64 { return []float64{.05, .5, 1, 5, 30, 60, 600} } diff --git a/pkg/extensions/monitoring/minimal.go b/pkg/extensions/monitoring/minimal.go index 4b90554c..5b8b2ec2 100644 --- a/pkg/extensions/monitoring/minimal.go +++ b/pkg/extensions/monitoring/minimal.go @@ -49,6 +49,7 @@ type metricServer struct { bucketsF2S map[float64]string // float64 to string conversion of buckets label log log.Logger lock *sync.RWMutex + stopChan chan struct{} // Channel to signal shutdown } type MetricsInfo struct { @@ -142,19 +143,35 @@ func (ms *metricServer) IsEnabled() bool { return ms.enabled } +// Stop gracefully shuts down the metrics server. +func (ms *metricServer) Stop() { + close(ms.stopChan) +} + func (ms *metricServer) Run() { sendAfter := make(chan time.Duration, 1) // periodically send a notification to the metric server to check if we can disable metrics go func() { for { - t := metricsScrapeCheckInterval - time.Sleep(t) - sendAfter <- t + select { + case <-ms.stopChan: + return + default: + t := metricsScrapeCheckInterval + time.Sleep(t) + select { + case sendAfter <- t: + case <-ms.stopChan: + return + } + } } }() for { select { + case <-ms.stopChan: + return case <-ms.cacheChan: ms.lastCheck = time.Now() // make a copy of cache values to prevent data race @@ -239,6 +256,7 @@ func NewMetricsServer(enabled bool, log log.Logger) MetricServer { bucketsF2S: bucketsFloat2String, log: log, lock: &sync.RWMutex{}, + stopChan: make(chan struct{}), } go ms.Run() diff --git a/pkg/extensions/monitoring/monitoring_test.go b/pkg/extensions/monitoring/monitoring_test.go index d179717b..20698d4f 100644 --- a/pkg/extensions/monitoring/monitoring_test.go +++ b/pkg/extensions/monitoring/monitoring_test.go @@ -48,6 +48,11 @@ func TestExtensionMetrics(t *testing.T) { ctlr := api.NewController(conf) So(ctlr, ShouldNotBeNil) + // Write image before starting controller to avoid race condition with garbage collection + srcStorageCtlr := ociutils.GetDefaultStoreController(rootDir, ctlr.Log) + err := WriteImageToFileSystem(CreateDefaultImage(), "alpine", "0.0.1", srcStorageCtlr) + So(err, ShouldBeNil) + cm := test.NewControllerManager(ctlr) cm.StartAndWait(port) defer cm.StopServer() @@ -64,10 +69,6 @@ func TestExtensionMetrics(t *testing.T) { monitoring.IncDownloadCounter(ctlr.Metrics, "alpine") monitoring.IncUploadCounter(ctlr.Metrics, "alpine") - srcStorageCtlr := ociutils.GetDefaultStoreController(rootDir, ctlr.Log) - err := WriteImageToFileSystem(CreateDefaultImage(), "alpine", "0.0.1", srcStorageCtlr) - So(err, ShouldBeNil) - monitoring.SetStorageUsage(ctlr.Metrics, rootDir, "alpine") monitoring.ObserveStorageLockLatency(ctlr.Metrics, time.Millisecond, rootDir, "RWLock") diff --git a/pkg/extensions/search/search_test.go b/pkg/extensions/search/search_test.go index 7ae097bc..4c76d527 100644 --- a/pkg/extensions/search/search_test.go +++ b/pkg/extensions/search/search_test.go @@ -602,7 +602,7 @@ func TestRepoListWithNewestImage(t *testing.T) { // Delete config blob and try. err = os.Remove(path.Join(subRootDir, "a/zot-test/blobs/sha256", configDigest.Encoded())) - if err != nil { + if err != nil && !os.IsNotExist(err) { panic(err) } @@ -614,7 +614,7 @@ func TestRepoListWithNewestImage(t *testing.T) { err = os.Remove(path.Join(subRootDir, "a/zot-test/blobs/sha256", manifestDigest.Encoded())) - if err != nil { + if err != nil && !os.IsNotExist(err) { panic(err) } @@ -625,7 +625,7 @@ func TestRepoListWithNewestImage(t *testing.T) { So(resp.StatusCode(), ShouldEqual, 200) err = os.Remove(path.Join(rootDir, "zot-test/blobs/sha256", configDigest.Encoded())) - if err != nil { + if err != nil && !os.IsNotExist(err) { panic(err) } @@ -637,7 +637,7 @@ func TestRepoListWithNewestImage(t *testing.T) { // Delete manifest blob also and try err = os.Remove(path.Join(rootDir, "zot-test/blobs/sha256", manifestDigest.Encoded())) - if err != nil { + if err != nil && !os.IsNotExist(err) { panic(err) } diff --git a/pkg/extensions/sync/sync_internal_test.go b/pkg/extensions/sync/sync_internal_test.go index b8773d0b..fbc285b7 100644 --- a/pkg/extensions/sync/sync_internal_test.go +++ b/pkg/extensions/sync/sync_internal_test.go @@ -175,8 +175,6 @@ func TestService(t *testing.T) { onDemand.Add(service) ctx := context.Background() - // === IMAGE SYNC CONTINUE PATH TEST === - // Step 1: Verify empty requestStore initially initialImageCount := 0 onDemand.requestStore.Range(func(key, value interface{}) bool { @@ -219,8 +217,6 @@ func TestService(t *testing.T) { So(exists, ShouldBeTrue) So(value, ShouldEqual, struct{}{}) // Should still be pre-populated value - // === REFERRER SYNC CONTINUE PATH TEST === - // Step 7: Verify current state before referrer test - we should have 1 request initialReferrerCount := 0 onDemand.requestStore.Range(func(key, value interface{}) bool { diff --git a/pkg/extensions/sync/sync_test.go b/pkg/extensions/sync/sync_test.go index 997dbe4c..603cc657 100644 --- a/pkg/extensions/sync/sync_test.go +++ b/pkg/extensions/sync/sync_test.go @@ -814,11 +814,14 @@ func TestOnDemandWithScaleOutCluster(t *testing.T) { Registries: []syncconf.RegistryConfig{syncRegistryConfig}, } + // Get dynamic ports for cluster members + clusterPorts := test.GetFreePorts(2) + // cluster config for member 1. clusterCfgDownstream1 := config.ClusterConfig{ Members: []string{ - "127.0.0.1:43222", - "127.0.0.1:43223", + "127.0.0.1:" + clusterPorts[0], + "127.0.0.1:" + clusterPorts[1], }, HashKey: "loremipsumdolors", } @@ -827,11 +830,11 @@ func TestOnDemandWithScaleOutCluster(t *testing.T) { clusterCfgDownstream2 := clusterCfgDownstream1 dctrl1, dctrl1BaseURL, destDir1, dstClient1 := makeInsecureDownstreamServerFixedPort( - t, "43222", syncConfig, &clusterCfgDownstream1) + t, clusterPorts[0], syncConfig, &clusterCfgDownstream1) dctrl1Scm := test.NewControllerManager(dctrl1) dctrl2, dctrl2BaseURL, destDir2, dstClient2 := makeInsecureDownstreamServerFixedPort( - t, "43223", syncConfig, &clusterCfgDownstream2) + t, clusterPorts[1], syncConfig, &clusterCfgDownstream2) dctrl2Scm := test.NewControllerManager(dctrl2) dctrl1Scm.StartAndWait(dctrl1.Config.HTTP.Port) @@ -983,11 +986,14 @@ func TestOnDemandWithScaleOutClusterWithReposNotAddedForSync(t *testing.T) { Registries: []syncconf.RegistryConfig{syncRegistryConfig}, } + // Get dynamic ports for cluster members + clusterPorts := test.GetFreePorts(2) + // cluster config for member 1. clusterCfgDownstream1 := config.ClusterConfig{ Members: []string{ - "127.0.0.1:43222", - "127.0.0.1:43223", + "127.0.0.1:" + clusterPorts[0], + "127.0.0.1:" + clusterPorts[1], }, HashKey: "loremipsumdolors", } @@ -996,11 +1002,11 @@ func TestOnDemandWithScaleOutClusterWithReposNotAddedForSync(t *testing.T) { clusterCfgDownstream2 := clusterCfgDownstream1 dctrl1, dctrl1BaseURL, destDir1, dstClient1 := makeInsecureDownstreamServerFixedPort( - t, "43222", syncConfig, &clusterCfgDownstream1) + t, clusterPorts[0], syncConfig, &clusterCfgDownstream1) dctrl1Scm := test.NewControllerManager(dctrl1) dctrl2, dctrl2BaseURL, destDir2, dstClient2 := makeInsecureDownstreamServerFixedPort( - t, "43223", syncConfig, &clusterCfgDownstream2) + t, clusterPorts[1], syncConfig, &clusterCfgDownstream2) dctrl2Scm := test.NewControllerManager(dctrl2) dctrl1Scm.StartAndWait(dctrl1.Config.HTTP.Port) @@ -1878,15 +1884,20 @@ func TestPeriodicallyWithScaleOutCluster(t *testing.T) { // zot-test is managed by member index 1. // zot-cve-test is managed by member index 0. // zot-alpine-test is managed by member index 1. + + // Get dynamic ports for cluster members + clusterPorts := test.GetFreePorts(2) + clusterCfg := config.ClusterConfig{ Members: []string{ - "127.0.0.1:100", - "127.0.0.1:42000", + "127.0.0.1:" + clusterPorts[0], + "127.0.0.1:" + clusterPorts[1], }, HashKey: "loremipsumdolors", } - dctlr, destBaseURL, destDir, destClient := makeInsecureDownstreamServerFixedPort(t, "42000", syncConfig, &clusterCfg) + dctlr, destBaseURL, destDir, destClient := makeInsecureDownstreamServerFixedPort(t, + clusterPorts[1], syncConfig, &clusterCfg) dcm := test.NewControllerManager(dctlr) dcm.StartAndWait(dctlr.Config.HTTP.Port) diff --git a/pkg/scheduler/scheduler.go b/pkg/scheduler/scheduler.go index ea2ab9ed..08efc18d 100644 --- a/pkg/scheduler/scheduler.go +++ b/pkg/scheduler/scheduler.go @@ -549,8 +549,9 @@ func (scheduler *Scheduler) SubmitGenerator(taskGenerator TaskGenerator, interva } func getNumWorkers(cfg *config.Config) int { - if cfg.Scheduler != nil && cfg.Scheduler.NumWorkers != 0 { - return cfg.Scheduler.NumWorkers + schedulerConfig := cfg.CopySchedulerConfig() + if schedulerConfig != nil && schedulerConfig.NumWorkers != 0 { + return schedulerConfig.NumWorkers } return runtime.NumCPU() * NumWorkersMultiplier diff --git a/pkg/storage/gc/gc_test.go b/pkg/storage/gc/gc_test.go index 592a5fb5..bb2ddc85 100644 --- a/pkg/storage/gc/gc_test.go +++ b/pkg/storage/gc/gc_test.go @@ -63,6 +63,7 @@ func TestGarbageCollectAndRetentionMetaDB(t *testing.T) { audit := zlog.NewAuditLogger("debug", "/dev/null") metrics := monitoring.NewMetricsServer(false, log) + defer metrics.Stop() // Clean up metrics server to prevent resource leaks trueVal := true @@ -1317,6 +1318,7 @@ func TestGarbageCollectDeletion(t *testing.T) { audit := zlog.NewAuditLogger("debug", "/dev/null") metrics := monitoring.NewMetricsServer(false, log) + defer metrics.Stop() // Clean up metrics server to prevent resource leaks trueVal := true falseVal := false @@ -1755,6 +1757,7 @@ func TestGarbageCollectAndRetentionNoMetaDB(t *testing.T) { audit := zlog.NewAuditLogger("debug", "/dev/null") metrics := monitoring.NewMetricsServer(false, log) + defer metrics.Stop() // Clean up metrics server to prevent resource leaks trueVal := true diff --git a/pkg/test/common/utils.go b/pkg/test/common/utils.go index 507043d9..5f9ff390 100644 --- a/pkg/test/common/utils.go +++ b/pkg/test/common/utils.go @@ -1,14 +1,18 @@ package common import ( + "bytes" "errors" "fmt" + "io" "math/rand" "net/http" "net/url" "os" "path" "strconv" + "strings" + "sync" "time" "github.com/phayes/freeport" @@ -143,6 +147,23 @@ func GetFreePort() string { return strconv.Itoa(port) } +// GetFreePorts returns multiple unique free ports, useful for cluster tests. +func GetFreePorts(count int) []string { + // Use the freeport library's GetFreePorts function which guarantees uniqueness + intPorts, err := freeport.GetFreePorts(count) + if err != nil { + panic(err) + } + + // Convert to strings + ports := make([]string, count) + for i, port := range intPorts { + ports[i] = strconv.Itoa(port) + } + + return ports +} + func GetBaseURL(port string) string { return fmt.Sprintf(BaseURL, port) } @@ -233,3 +254,78 @@ func ContainSameElements[T comparable](list1, list2 []T) bool { return true } + +// ThreadSafeLogBuffer is a thread-safe wrapper around bytes.Buffer for concurrent log capture. +type ThreadSafeLogBuffer struct { + buffer *bytes.Buffer + mutex sync.RWMutex +} + +// NewThreadSafeLogBuffer creates a new thread-safe log buffer. +func NewThreadSafeLogBuffer() *ThreadSafeLogBuffer { + return &ThreadSafeLogBuffer{ + buffer: &bytes.Buffer{}, + } +} + +// Write implements io.Writer interface with thread safety. +func (tsb *ThreadSafeLogBuffer) Write(p []byte) (int, error) { + tsb.mutex.Lock() + defer tsb.mutex.Unlock() + + return tsb.buffer.Write(p) +} + +// String returns the buffer contents as a string with thread safety. +func (tsb *ThreadSafeLogBuffer) String() string { + tsb.mutex.RLock() + defer tsb.mutex.RUnlock() + + return tsb.buffer.String() +} + +// WaitForLogMessages waits for a specific number of log messages to appear in the log buffer +// within the given timeout. This is useful for verifying goroutine termination or other +// asynchronous operations that log specific messages. +// +// Parameters: +// - logBuffer: A ThreadSafeLogBuffer that captures log output +// - message: The log message to search for (e.g., "htpasswd watcher terminating...") +// - minCount: Minimum number of occurrences to wait for +// - timeout: Maximum time to wait for the messages +// +// Returns: +// - true if at least minCount messages were found within the timeout +// - false if the timeout was reached before finding enough messages +func WaitForLogMessages(logBuffer *ThreadSafeLogBuffer, message string, minCount int, timeout time.Duration) bool { + deadline := time.Now().Add(timeout) + for time.Now().Before(deadline) { + logOutput := logBuffer.String() + actualCount := strings.Count(logOutput, message) + + if actualCount >= minCount { + return true + } + + time.Sleep(10 * time.Millisecond) + } + + return false +} + +// CreateLogCapturingWriter creates a multi-writer that captures log output to a thread-safe buffer +// while also writing to the original writer (typically os.Stdout). This is useful for +// tests that need to programmatically verify log messages. +// +// Parameters: +// - originalWriter: The original writer to continue writing to (e.g., os.Stdout) +// +// Returns: +// - A ThreadSafeLogBuffer that captures the log output +// - An io.Writer that writes to both the original writer and the buffer +func CreateLogCapturingWriter(originalWriter io.Writer) (*ThreadSafeLogBuffer, io.Writer) { + logBuffer := NewThreadSafeLogBuffer() + multiWriter := io.MultiWriter(originalWriter, logBuffer) + + return logBuffer, multiWriter +} diff --git a/pkg/test/common/utils_test.go b/pkg/test/common/utils_test.go index 28df644f..f90558db 100644 --- a/pkg/test/common/utils_test.go +++ b/pkg/test/common/utils_test.go @@ -58,3 +58,143 @@ func TestControllerManager(t *testing.T) { So(func() { ctlrManager.RunServer() }, ShouldPanic) }) } + +func TestWaitForLogMessages(t *testing.T) { + Convey("Test WaitForLogMessages", t, func() { + Convey("should return true when message count reaches minimum", func() { + logBuffer := tcommon.NewThreadSafeLogBuffer() + + // Write some log messages + _, _ = logBuffer.Write([]byte("Starting server\n")) + _, _ = logBuffer.Write([]byte("Server started successfully\n")) + _, _ = logBuffer.Write([]byte("Starting server\n")) + _, _ = logBuffer.Write([]byte("Processing request\n")) + _, _ = logBuffer.Write([]byte("Starting server\n")) + + // Wait for "Starting server" message to appear at least 3 times + result := tcommon.WaitForLogMessages(logBuffer, "Starting server", 3, 100*time.Millisecond) + + So(result, ShouldBeTrue) + }) + + Convey("should return false when message count never reaches minimum", func() { + logBuffer := tcommon.NewThreadSafeLogBuffer() + + // Write some log messages (only 1 occurrence of target message) + _, _ = logBuffer.Write([]byte("Starting server\n")) + _, _ = logBuffer.Write([]byte("Server started successfully\n")) + _, _ = logBuffer.Write([]byte("Processing request\n")) + + // Wait for "Starting server" message to appear at least 3 times + result := tcommon.WaitForLogMessages(logBuffer, "Starting server", 3, 50*time.Millisecond) + + So(result, ShouldBeFalse) + }) + + Convey("should return true immediately when count already meets requirement", func() { + logBuffer := tcommon.NewThreadSafeLogBuffer() + + // Write messages before calling WaitForLogMessages + _, _ = logBuffer.Write([]byte("Starting server\n")) + _, _ = logBuffer.Write([]byte("Starting server\n")) + _, _ = logBuffer.Write([]byte("Starting server\n")) + + // Wait for "Starting server" message to appear at least 3 times + result := tcommon.WaitForLogMessages(logBuffer, "Starting server", 3, 100*time.Millisecond) + + So(result, ShouldBeTrue) + }) + + Convey("should handle empty log buffer", func() { + logBuffer := tcommon.NewThreadSafeLogBuffer() + + // Wait for any message in empty buffer + result := tcommon.WaitForLogMessages(logBuffer, "Starting server", 1, 50*time.Millisecond) + + So(result, ShouldBeFalse) + }) + + Convey("should handle partial message matches", func() { + logBuffer := tcommon.NewThreadSafeLogBuffer() + + // Write messages with partial matches + _, _ = logBuffer.Write([]byte("Starting server process\n")) + _, _ = logBuffer.Write([]byte("Starting server\n")) + _, _ = logBuffer.Write([]byte("Starting server instance\n")) + + // Wait for exact "Starting server" message (not partial matches) + result := tcommon.WaitForLogMessages(logBuffer, "Starting server", 2, 100*time.Millisecond) + + So(result, ShouldBeTrue) + }) + + Convey("should timeout after specified duration", func() { + logBuffer := tcommon.NewThreadSafeLogBuffer() + + // Write only one message + _, _ = logBuffer.Write([]byte("Starting server\n")) + + // Wait for 3 occurrences with short timeout + start := time.Now() + result := tcommon.WaitForLogMessages(logBuffer, "Starting server", 3, 10*time.Millisecond) + duration := time.Since(start) + + So(result, ShouldBeFalse) + So(duration, ShouldBeGreaterThanOrEqualTo, 10*time.Millisecond) + So(duration, ShouldBeLessThan, 50*time.Millisecond) // Should timeout quickly + }) + + Convey("should handle concurrent writes", func() { + logBuffer := tcommon.NewThreadSafeLogBuffer() + + // Simulate concurrent writes + go func() { + for i := 0; i < 5; i++ { + _, _ = logBuffer.Write([]byte("Starting server\n")) + + time.Sleep(5 * time.Millisecond) + } + }() + + // Wait for messages to appear + result := tcommon.WaitForLogMessages(logBuffer, "Starting server", 3, 100*time.Millisecond) + + So(result, ShouldBeTrue) + }) + + Convey("should handle case-sensitive message matching", func() { + logBuffer := tcommon.NewThreadSafeLogBuffer() + + // Write messages with different cases + _, _ = logBuffer.Write([]byte("Starting server\n")) + _, _ = logBuffer.Write([]byte("starting server\n")) + _, _ = logBuffer.Write([]byte("STARTING SERVER\n")) + + // Wait for exact case match + result := tcommon.WaitForLogMessages(logBuffer, "Starting server", 2, 100*time.Millisecond) + + So(result, ShouldBeFalse) // Only 1 exact match + }) + + Convey("should handle zero minimum count", func() { + logBuffer := tcommon.NewThreadSafeLogBuffer() + + // Wait for 0 occurrences (should always return true) + result := tcommon.WaitForLogMessages(logBuffer, "Starting server", 0, 100*time.Millisecond) + + So(result, ShouldBeTrue) + }) + + Convey("should handle very short timeout", func() { + logBuffer := tcommon.NewThreadSafeLogBuffer() + + // Write a message + _, _ = logBuffer.Write([]byte("Starting server\n")) + + // Wait with very short timeout + result := tcommon.WaitForLogMessages(logBuffer, "Starting server", 1, 1*time.Millisecond) + + So(result, ShouldBeTrue) // Should find it immediately + }) + }) +}