From 08fae9104dea0ee8ded87898642b431239bb4b5f Mon Sep 17 00:00:00 2001 From: Andrei Aaron Date: Thu, 11 Dec 2025 20:08:32 +0200 Subject: [PATCH] feat: support mTLS-only authn/authz with AccessControl and allow combining mTLS with other auth mechanisms (#3624) * feat: support mTLS-only authn/authz with AccessControl and allow combining mTLS with other auth mechanisms Signed-off-by: Ivan Arkhipov * refactor: improve authentication logic and TLS certificate generation - Fix mTLS authentication to use only leaf certificate instead of iterating through all certificates in the chain - Reject Authorization headers when corresponding auth method is disabled, regardless of mTLS status (security improvement) - Simplify authentication switch statement ordering and logic - Move ErrUserDataNotFound error handling into sessionAuthn method - Refactor TLS certificate generation to use Options pattern with CertificateOptions struct for better extensibility - Consolidate duplicate certificate generation code into helper functions (generateCertificate, parseCA, initializeTemplate, applyOptions) - Rename certificate generation functions for clarity: - GenerateCertWithCN -> GenerateClientCert - GenerateSelfSignedCertWithCN -> GenerateClientSelfSignedCert - Add support for SAN settings including email addresses in certificates - Update tests to reflect new authentication behavior and certificate API This commit improves both the security posture (rejecting disabled auth methods) and code maintainability (consolidated certificate generation). Signed-off-by: Andrei Aaron * fix: guard against multiple Authorization headers Signed-off-by: Andrei Aaron --------- Signed-off-by: Ivan Arkhipov Signed-off-by: Andrei Aaron Co-authored-by: Ivan Arkhipov --- .github/workflows/test.yaml | 2 +- pkg/api/authn.go | 214 ++++---- pkg/api/authn_test.go | 968 ++++++++++++++++++++++++++++++++++ pkg/api/config/config.go | 56 +- pkg/api/config/config_test.go | 313 ----------- pkg/api/controller.go | 9 +- pkg/api/controller_test.go | 187 ++++--- pkg/cli/client/client_test.go | 3 +- pkg/cli/server/root.go | 5 +- pkg/test/tls/tls.go | 407 ++++++++++++++ pkg/test/tls/tls_test.go | 438 +++++++++++++++ 11 files changed, 2066 insertions(+), 536 deletions(-) create mode 100644 pkg/test/tls/tls.go create mode 100644 pkg/test/tls/tls_test.go diff --git a/.github/workflows/test.yaml b/.github/workflows/test.yaml index c16083e4..99b7e0ed 100644 --- a/.github/workflows/test.yaml +++ b/.github/workflows/test.yaml @@ -174,7 +174,7 @@ jobs: make covhtml mkdir unified-coverage cp coverage.txt coverage.html unified-coverage/ - - name: upload html coverage + - name: upload unified-coverage as build artifact uses: actions/upload-artifact@v5 with: name: unified-coverage diff --git a/pkg/api/authn.go b/pkg/api/authn.go index d7681552..b5864296 100644 --- a/pkg/api/authn.go +++ b/pkg/api/authn.go @@ -91,6 +91,11 @@ func (amw *AuthnMiddleware) sessionAuthn(ctlr *Controller, userAc *reqCtx.UserAc if err != nil { ctlr.Log.Err(err).Str("identity", identity).Msg("failed to get user profile in DB") + if errors.Is(err, zerr.ErrUserDataNotFound) { + // we handle this case as an authentication failure, not an internal server error + err = nil + } + return false, err } @@ -100,6 +105,57 @@ func (amw *AuthnMiddleware) sessionAuthn(ctlr *Controller, userAc *reqCtx.UserAc return true, nil } +func (amw *AuthnMiddleware) mTLSAuthn(ctlr *Controller, userAc *reqCtx.UserAccessControl, + request *http.Request, +) (bool, error) { + // Check if mTLS is configured and client certificates are present + if request.TLS == nil || len(request.TLS.PeerCertificates) == 0 { + return false, nil + } + + // Check if client certificate has verified chain + verifiedChains := request.TLS.VerifiedChains + if len(verifiedChains) == 0 || len(verifiedChains[0]) == 0 { + ctlr.Log.Debug().Msg("mTLS authentication failed - user provided certificate not signed by CA") + + return false, nil + } + + // Extract identity from certificate + leafCert := request.TLS.PeerCertificates[0] + + identity := leafCert.Subject.CommonName + if identity == "" { + return false, nil + } + + // Process request with mTLS identity + var groups []string + + accessControl := ctlr.Config.CopyAccessControlConfig() + if accessControl != nil { + ac := NewAccessController(ctlr.Config) + groups = ac.getUserGroups(identity) + } + + userAc.SetUsername(identity) + userAc.AddGroups(groups) + userAc.SaveOnRequest(request) + + // Update user groups in MetaDB if available + if ctlr.MetaDB != nil { + if err := ctlr.MetaDB.SetUserGroups(request.Context(), groups); err != nil { + ctlr.Log.Error().Err(err).Str("identity", identity).Msg("failed to update user profile") + + return false, err + } + } + + ctlr.Log.Debug().Str("identity", identity).Msg("mTLS authentication successful") + + return true, nil +} + func (amw *AuthnMiddleware) basicAuthn(ctlr *Controller, userAc *reqCtx.UserAccessControl, response http.ResponseWriter, request *http.Request, ) (bool, error) { @@ -263,14 +319,6 @@ func (amw *AuthnMiddleware) tryAuthnHandlers(ctlr *Controller) mux.MiddlewareFun // Get auth config once to avoid multiple calls authConfig := ctlr.Config.CopyAuthConfig() - // no password based authN, if neither LDAP nor HTTP BASIC is enabled - if !authConfig.IsBasicAuthnEnabled() { - return noPasswdAuth(ctlr) - } - - delay := authConfig.GetFailDelay() - realm := ctlr.Config.GetRealm() - // ldap and htpasswd based authN if authConfig.IsLdapAuthEnabled() { ldapConfig := authConfig.LDAP @@ -357,6 +405,11 @@ func (amw *AuthnMiddleware) tryAuthnHandlers(ctlr *Controller) mux.MiddlewareFun isMgmtRequested := request.RequestURI == constants.FullMgmt + // Get auth config safely + authConfig := ctlr.Config.CopyAuthConfig() + delay := authConfig.GetFailDelay() + realm := ctlr.Config.GetRealm() + // Get access control config safely accessControlConfig := ctlr.Config.CopyAccessControlConfig() allowAnonymous := accessControlConfig != nil && accessControlConfig.AnonymousPolicyExists() @@ -366,57 +419,61 @@ func (amw *AuthnMiddleware) tryAuthnHandlers(ctlr *Controller) mux.MiddlewareFun // if it will not be populated by authn handlers, this represents an anonymous user userAc.SaveOnRequest(request) - // try basic auth if authorization header is given - if !isAuthorizationHeaderEmpty(request) { //nolint: gocritic - //nolint: contextcheck - authenticated, err := amw.basicAuthn(ctlr, userAc, response, request) + authenticated := false + + var err error + + // Switch authentication methods based on provided request context + switch { + // Reject requests with multiple Authorization headers as a security measure + case hasMultipleAuthorizationHeaders(request): + authenticated = false + + // The authorization header presence is an explicit attempt to use basic authentication + case !isAuthorizationHeaderEmpty(request) && authConfig.IsBasicAuthnEnabled(): + authenticated, err = amw.basicAuthn(ctlr, userAc, response, request) + + // The authorization header is given but basic auth is not enabled + case !isAuthorizationHeaderEmpty(request) && !authConfig.IsBasicAuthnEnabled(): + authenticated = false + + // The session header is an explicit attempt to use session authentication + case hasSessionHeader(request): + authenticated, err = amw.sessionAuthn(ctlr, userAc, response, request) if err != nil { - response.WriteHeader(http.StatusInternalServerError) - - return + break } - if authenticated { - next.ServeHTTP(response, request) + // If session authentication fails, but anonymous or management access is allowed, + // treat the request as authenticated. This fallback is necessary because the session + // header may be present for anonymous or management requests. + authenticated = authenticated || allowAnonymous || isMgmtRequested - return - } - } else if hasSessionHeader(request) { - // try session auth - //nolint: contextcheck - authenticated, err := amw.sessionAuthn(ctlr, userAc, response, request) - if err != nil { - if errors.Is(err, zerr.ErrUserDataNotFound) { - ctlr.Log.Err(err).Msg("failed to find user profile in DB") + // Try mTLS authentication if client certificates are present + case ctlr.Config.IsMTLSAuthEnabled() && request.TLS != nil && len(request.TLS.PeerCertificates) > 0: + authenticated, err = amw.mTLSAuthn(ctlr, userAc, request) - authFail(response, request, realm, delay) - } + // If no auth methods enabled at all - then just authenticate anything + case !authConfig.IsBasicAuthnEnabled() && !ctlr.Config.IsMTLSAuthEnabled(): + authenticated = true - response.WriteHeader(http.StatusInternalServerError) + // If no credentials provided - check for anonymous / mgmt requests + case allowAnonymous || isMgmtRequested: + authenticated = true + } - return - } - - if authenticated { - next.ServeHTTP(response, request) - - return - } - - // the session header can be present also for anonymous calls - if allowAnonymous || isMgmtRequested { - next.ServeHTTP(response, request) - - return - } - } else if allowAnonymous || isMgmtRequested { - // try anonymous auth only if basic auth/session was not given - next.ServeHTTP(response, request) + // If error occurred during authn process - return 500 error + if err != nil { + response.WriteHeader(http.StatusInternalServerError) return } - authFail(response, request, realm, delay) + if authenticated { + next.ServeHTTP(response, request) + } else { + authFail(response, request, realm, delay) + } }) } } @@ -447,6 +504,15 @@ func bearerAuthHandler(ctlr *Controller) mux.MiddlewareFunc { return } + // Reject requests with multiple Authorization headers as a security measure + if hasMultipleAuthorizationHeaders(request) { + ctlr.Log.Error().Msg("failed to parse Authorization header: multiple Authorization headers detected") + response.Header().Set("Content-Type", "application/json") + zcommon.WriteJSON(response, http.StatusUnauthorized, apiErr.NewError(apiErr.UNSUPPORTED)) + + return + } + acCtrlr := NewAccessController(ctlr.Config) // we want to bypass auth for mgmt route @@ -504,50 +570,6 @@ func bearerAuthHandler(ctlr *Controller) mux.MiddlewareFunc { } } -func noPasswdAuth(ctlr *Controller) mux.MiddlewareFunc { - return func(next http.Handler) http.Handler { - return http.HandlerFunc(func(response http.ResponseWriter, request *http.Request) { - if request.Method == http.MethodOptions { - next.ServeHTTP(response, request) - response.WriteHeader(http.StatusNoContent) - - return - } - - userAc := reqCtx.NewUserAccessControl() - - // if no basic auth enabled then try to get identity from mTLS auth - if request.TLS != nil { - verifiedChains := request.TLS.VerifiedChains - if len(verifiedChains) > 0 && len(verifiedChains[0]) > 0 { - for _, cert := range request.TLS.PeerCertificates { - identity := cert.Subject.CommonName - if identity != "" { - // assign identity to authz context, needed for extensions - userAc.SetUsername(identity) - } - } - } - } - - if ctlr.Config.IsMTLSAuthEnabled() && userAc.IsAnonymous() { - authConfig := ctlr.Config.CopyAuthConfig() - failDelay := authConfig.GetFailDelay() - realm := ctlr.Config.GetRealm() - - authFail(response, request, realm, failDelay) - - return - } - - userAc.SaveOnRequest(request) - - // Process request - next.ServeHTTP(response, request) - }) - } -} - func (rh *RouteHandler) AuthURLHandler() http.HandlerFunc { return func(w http.ResponseWriter, r *http.Request) { query := r.URL.Query() @@ -732,6 +754,14 @@ func isAuthorizationHeaderEmpty(request *http.Request) bool { return false } +// hasMultipleAuthorizationHeaders checks if the request has multiple Authorization headers. +// This is a security concern as it could be used to bypass authentication or cause confusion. +func hasMultipleAuthorizationHeaders(request *http.Request) bool { + authHeaders := request.Header.Values("Authorization") + + return len(authHeaders) > 1 +} + func hasSessionHeader(request *http.Request) bool { clientHeader := request.Header.Get(constants.SessionClientHeaderName) diff --git a/pkg/api/authn_test.go b/pkg/api/authn_test.go index 56407634..20e8c093 100644 --- a/pkg/api/authn_test.go +++ b/pkg/api/authn_test.go @@ -6,7 +6,9 @@ import ( "context" "crypto/tls" "crypto/x509" + "encoding/base64" "encoding/json" + "encoding/pem" "errors" "io/fs" "net/http" @@ -18,6 +20,7 @@ import ( "github.com/alicebob/miniredis/v2" guuid "github.com/gofrs/uuid" + "github.com/golang-jwt/jwt/v5" godigest "github.com/opencontainers/go-digest" "github.com/project-zot/mockoidc" . "github.com/smartystreets/goconvey/convey" @@ -39,6 +42,7 @@ import ( authutils "zotregistry.dev/zot/v2/pkg/test/auth" test "zotregistry.dev/zot/v2/pkg/test/common" "zotregistry.dev/zot/v2/pkg/test/mocks" + tlsutils "zotregistry.dev/zot/v2/pkg/test/tls" ) var ErrUnexpectedError = errors.New("error: unexpected error") @@ -864,6 +868,970 @@ func TestAPIKeys(t *testing.T) { }) } +func TestMTLSAuthentication(t *testing.T) { + // Create temporary directory for certificates + tempDir := t.TempDir() + + // Generate CA certificate + caCert, caKey, err := tlsutils.GenerateCACert() + if err != nil { + panic(err) + } + caCertPath := path.Join(tempDir, "ca.crt") + err = os.WriteFile(caCertPath, caCert, 0o600) + if err != nil { + panic(err) + } + + // Generate server certificate + serverCertPath := path.Join(tempDir, "server.crt") + serverKeyPath := path.Join(tempDir, "server.key") + opts := &tlsutils.CertificateOptions{ + Hostname: "localhost", + } + err = tlsutils.GenerateServerCertToFile(caCert, caKey, serverCertPath, serverKeyPath, opts) + if err != nil { + panic(err) + } + + // Generate valid client certificate for "testuser" user + clientCertPath := path.Join(tempDir, "client.crt") + clientKeyPath := path.Join(tempDir, "client.key") + clientOpts := &tlsutils.CertificateOptions{ + CommonName: "testuser", + } + err = tlsutils.GenerateClientCertToFile(caCert, caKey, clientCertPath, clientKeyPath, clientOpts) + if err != nil { + panic(err) + } + + // Generate self-signed client cert for "testuser" user + selfSignedClientCertPath := path.Join(tempDir, "client-selfsigned.crt") + selfSignedClientKeyPath := path.Join(tempDir, "client-selfsigned.key") + selfSignedOpts := &tlsutils.CertificateOptions{ + CommonName: "testuser", + } + err = tlsutils.GenerateClientSelfSignedCertToFile(selfSignedClientCertPath, selfSignedClientKeyPath, selfSignedOpts) + if err != nil { + panic(err) + } + + // Create htpasswd file with sample "httpuser" + htpasswdPath := test.MakeHtpasswdFileFromString(t, test.GetBcryptCredString("httpuser", "httppass")) + defer os.Remove(htpasswdPath) + + Convey("Test mTLS-only authentication", t, func() { + // Set up server + conf := config.New() + port := test.GetFreePort() + baseURL := test.GetSecureBaseURL(port) + + conf.HTTP.Port = port + conf.HTTP.TLS = &config.TLSConfig{ + Cert: serverCertPath, + Key: serverKeyPath, + CACert: caCertPath, + } + conf.HTTP.AccessControl = &config.AccessControlConfig{ + Groups: config.Groups{ + "mtls-users": config.Group{ + Users: []string{"testuser"}, + }, + }, + Repositories: config.Repositories{ + "**": config.PolicyGroup{ // Default restrict all + AnonymousPolicy: make([]string, 0), + Policies: make([]config.Policy, 0), + }, + "test-repo": config.PolicyGroup{ + Policies: []config.Policy{ + { + Users: []string{"testuser"}, + Actions: []string{"read", "create"}, + }, + }, + }, + }, + } + conf.Storage.RootDirectory = t.TempDir() + + ctlr := api.NewController(conf) + cm := test.NewControllerManager(ctlr) + + cm.StartAndWait(port) + defer cm.StopServer() + + // Test without client certificate - should fail + caCertPEM, err := os.ReadFile(caCertPath) + So(err, ShouldBeNil) + + caCertPool := x509.NewCertPool() + caCertPool.AppendCertsFromPEM(caCertPEM) + + client := resty.New() + client.SetTLSClientConfig(&tls.Config{RootCAs: caCertPool, MinVersion: tls.VersionTLS13}) + resp, err := client.R().Get(baseURL + "/v2/test-repo/tags/list") + So(err, ShouldBeNil) + So(resp.StatusCode(), ShouldEqual, http.StatusUnauthorized) + + // Test with valid client certificate - should succeed + clientCert, err := tls.LoadX509KeyPair(clientCertPath, clientKeyPath) + So(err, ShouldBeNil) + + client = resty.New() + client.SetTLSClientConfig(&tls.Config{ + MinVersion: tls.VersionTLS13, + Certificates: []tls.Certificate{clientCert}, + RootCAs: caCertPool, + }) + + resp, err = client.R().Get(baseURL + "/v2/test-repo/tags/list") + So(err, ShouldBeNil) + So(resp.StatusCode(), ShouldEqual, http.StatusNotFound) // 404 meaning we successfully passed auth + + // Test with self-signed client certificate - should fail + selfSignedClientCert, err := tls.LoadX509KeyPair(selfSignedClientCertPath, selfSignedClientKeyPath) + So(err, ShouldBeNil) + + client = resty.New() + client.SetTLSClientConfig(&tls.Config{ + MinVersion: tls.VersionTLS13, + Certificates: []tls.Certificate{selfSignedClientCert}, + RootCAs: caCertPool, + }) + + resp, err = client.R().Get(baseURL + "/v2/test-selfsigned-repo/tags/list") + So(err, ShouldBeNil) + So(resp.StatusCode(), ShouldEqual, http.StatusUnauthorized) + }) + + Convey("Test mTLS with basic auth and user/group access policies", t, func() { + // Set up server + conf := config.New() + port := test.GetFreePort() + baseURL := test.GetSecureBaseURL(port) + + conf.HTTP.Port = port + conf.HTTP.TLS = &config.TLSConfig{ + Cert: serverCertPath, + Key: serverKeyPath, + CACert: caCertPath, + } + conf.HTTP.Auth = &config.AuthConfig{ + HTPasswd: config.AuthHTPasswd{ + Path: htpasswdPath, + }, + } + conf.HTTP.AccessControl = &config.AccessControlConfig{ + Groups: config.Groups{ + "mtls-users": config.Group{ + Users: []string{"testuser"}, + }, + }, + Repositories: config.Repositories{ + "**": config.PolicyGroup{ // Default restrict all + AnonymousPolicy: make([]string, 0), + Policies: make([]config.Policy, 0), + }, + "group-repo": config.PolicyGroup{ + Policies: []config.Policy{ + { + Groups: []string{"mtls-users"}, + Actions: []string{"read", "create"}, + }, + }, + }, + "test-repo": config.PolicyGroup{ + Policies: []config.Policy{ + { + Users: []string{"testuser"}, + Actions: []string{"read", "create"}, + }, + }, + }, + "htpasswd-repo": config.PolicyGroup{ + Policies: []config.Policy{ + { + Users: []string{"httpuser"}, + Actions: []string{"read", "create"}, + }, + }, + }, + }, + } + conf.Storage.RootDirectory = t.TempDir() + + ctlr := api.NewController(conf) + cm := test.NewControllerManager(ctlr) + + cm.StartAndWait(port) + defer cm.StopServer() + + // Load server CA certificate + caCertPEM, err := os.ReadFile(caCertPath) + So(err, ShouldBeNil) + + // Load self-signed client certificate + selfSignedClientCert, err := tls.LoadX509KeyPair(selfSignedClientCertPath, selfSignedClientKeyPath) + So(err, ShouldBeNil) + + // Load valid client certificate with CN "testuser" + clientCert, err := tls.LoadX509KeyPair(clientCertPath, clientKeyPath) + So(err, ShouldBeNil) + + caCertPool := x509.NewCertPool() + caCertPool.AppendCertsFromPEM(caCertPEM) + + // Tests without client certificate + client := resty.New() + client.SetTLSClientConfig(&tls.Config{RootCAs: caCertPool, MinVersion: tls.VersionTLS13}) + resp, err := client.R().SetBasicAuth("httpuser", "httppass").Get(baseURL + "/v2/htpasswd-repo/tags/list") + // Test without client CA but with htpasswd credentials - should pass because of valid htpasswd credentials + So(err, ShouldBeNil) + So(resp.StatusCode(), ShouldEqual, http.StatusNotFound) // 404 meaning we successfully passed auth + + // Tests with self-signed (== non-acceptable by server) client certificate + client = resty.New() + client.SetTLSClientConfig(&tls.Config{ + MinVersion: tls.VersionTLS13, + Certificates: []tls.Certificate{selfSignedClientCert}, + RootCAs: caCertPool, + }) + + // Test with self-signed client certificate - should still pass because of correct htpasswd auth + resp, err = client.R().SetBasicAuth("httpuser", "httppass").Get(baseURL + "/v2/htpasswd-repo/tags/list") + So(err, ShouldBeNil) + So(resp.StatusCode(), ShouldEqual, http.StatusNotFound) // 404 meaning we successfully passed auth + + // Tests with valid client certificate + client = resty.New() + client.SetTLSClientConfig(&tls.Config{ + MinVersion: tls.VersionTLS13, + Certificates: []tls.Certificate{clientCert}, + RootCAs: caCertPool, + }) + // Tests with valid client cert and creds - should fail with 403 due to no permissions for user from basic auth + // This validates that identity from basic auth has higher priority over mTLS identity + resp, err = client.R().SetBasicAuth("httpuser", "httppass").Get(baseURL + "/v2/test-repo/tags/list") + So(err, ShouldBeNil) + So(resp.StatusCode(), ShouldEqual, http.StatusForbidden) + + // Test with correct auth credentials and different basic auth username from client certificate CN - should success + // This validates that identity from basic auth has higher priority over mTLS identity + resp, err = client.R().SetBasicAuth("httpuser", "httppass").Get(baseURL + "/v2/htpasswd-repo/tags/list") + So(err, ShouldBeNil) + So(resp.StatusCode(), ShouldEqual, http.StatusNotFound) // 404 meaning we successfully passed auth + + // Should have access to test-repo for identity from client-cert + resp, err = client.R().Get(baseURL + "/v2/test-repo/tags/list") + So(err, ShouldBeNil) + So(resp.StatusCode(), ShouldEqual, http.StatusNotFound) // 404 meaning we successfully passed auth + + // Should not have access to other repos for identity from client-cert + resp, err = client.R().Get(baseURL + "/v2/unauthorized-repo/tags/list") + So(err, ShouldBeNil) + So(resp.StatusCode(), ShouldEqual, http.StatusForbidden) + + // Should have access to group-repo through group membership for identity from client-cert + resp, err = client.R().Get(baseURL + "/v2/group-repo/tags/list") + So(err, ShouldBeNil) + So(resp.StatusCode(), ShouldEqual, http.StatusNotFound) // 404 meaning we successfully passed auth + }) +} + +func TestMTLSAuthenticationWithCertificateChain(t *testing.T) { + // Create temporary directory for certificates + tempDir := t.TempDir() + + Convey("Test mTLS with certificate chain - uses leaf certificate identity", t, func() { + // Create certificate chain: Root CA -> Intermediate CA -> Client Certificate + // Generate root CA + rootCACert, rootCAKey, err := tlsutils.GenerateCACert() + So(err, ShouldBeNil) + rootCACertPath := path.Join(tempDir, "root-ca.crt") + err = os.WriteFile(rootCACertPath, rootCACert, 0o600) + So(err, ShouldBeNil) + + // Generate intermediate CA (signed by root CA) + intermediateCAOpts := &tlsutils.CertificateOptions{ + CommonName: "Intermediate CA", + } + intermediateCACert, intermediateCAKeyPEM, err := tlsutils.GenerateIntermediateCACert( + rootCACert, rootCAKey, intermediateCAOpts) + So(err, ShouldBeNil) + + // Generate client certificate with CN signed by intermediate CA + clientWithCNOpts := &tlsutils.CertificateOptions{ + CommonName: "clientuser", + } + clientCertWithCN, clientKeyWithCN, err := tlsutils.GenerateClientCert( + intermediateCACert, intermediateCAKeyPEM, clientWithCNOpts) + So(err, ShouldBeNil) + + // Generate client certificate without CN signed by intermediate CA + clientWithoutCNOpts := &tlsutils.CertificateOptions{ + // No CommonName - empty to test that identity is not taken from intermediate CA + } + clientCertWithoutCN, clientKeyWithoutCNPEM, err := tlsutils.GenerateClientCert( + intermediateCACert, intermediateCAKeyPEM, clientWithoutCNOpts) + So(err, ShouldBeNil) + + // Generate server certificate signed by root CA for this test + serverCertForChainPath := path.Join(tempDir, "server-chain.crt") + serverKeyForChainPath := path.Join(tempDir, "server-chain.key") + serverOpts := &tlsutils.CertificateOptions{ + Hostname: "localhost", + } + err = tlsutils.GenerateServerCertToFile( + rootCACert, rootCAKey, serverCertForChainPath, serverKeyForChainPath, serverOpts) + So(err, ShouldBeNil) + + // Set up server with root CA + conf := config.New() + port := test.GetFreePort() + baseURL := test.GetSecureBaseURL(port) + + conf.HTTP.Port = port + conf.HTTP.TLS = &config.TLSConfig{ + Cert: serverCertForChainPath, + Key: serverKeyForChainPath, + CACert: rootCACertPath, // Server trusts root CA + } + conf.HTTP.AccessControl = &config.AccessControlConfig{ + Repositories: config.Repositories{ + "**": config.PolicyGroup{ + AnonymousPolicy: make([]string, 0), + Policies: make([]config.Policy, 0), + }, + "client-repo": config.PolicyGroup{ + Policies: []config.Policy{ + { + Users: []string{"clientuser"}, + Actions: []string{"read", "create"}, + }, + }, + }, + }, + } + conf.Storage.RootDirectory = t.TempDir() + + ctlr := api.NewController(conf) + cm := test.NewControllerManager(ctlr) + + cm.StartAndWait(port) + defer cm.StopServer() + + caCertPool := x509.NewCertPool() + caCertPool.AppendCertsFromPEM(rootCACert) + + // Test 1: Client cert with CN in chain - should use client cert CN, not intermediate CA CN + clientCertWithCNPath := path.Join(tempDir, "client-with-cn.crt") + clientKeyWithCNPath := path.Join(tempDir, "client-with-cn.key") + err = os.WriteFile(clientCertWithCNPath, clientCertWithCN, 0o600) + So(err, ShouldBeNil) + err = os.WriteFile(clientKeyWithCNPath, clientKeyWithCN, 0o600) + So(err, ShouldBeNil) + + // Create certificate chain file (client cert + intermediate CA) + chainCertPath := path.Join(tempDir, "client-with-cn-chain.crt") + err = tlsutils.WriteCertificateChainToFile(chainCertPath, clientCertWithCN, intermediateCACert) + So(err, ShouldBeNil) + + // Load certificate chain + clientCertChain, err := tls.LoadX509KeyPair(chainCertPath, clientKeyWithCNPath) + So(err, ShouldBeNil) + + client := resty.New() + client.SetTLSClientConfig(&tls.Config{ + MinVersion: tls.VersionTLS13, + Certificates: []tls.Certificate{clientCertChain}, + RootCAs: caCertPool, + }) + + // Should succeed because client cert has CN "clientuser" which matches policy + resp, err := client.R().Get(baseURL + "/v2/client-repo/tags/list") + So(err, ShouldBeNil) + So(resp.StatusCode(), ShouldEqual, http.StatusNotFound) // 404 means auth passed + + // Test 2: Client cert without CN in chain - should fail, not use intermediate CA CN + clientCertWithoutCNPath := path.Join(tempDir, "client-without-cn.crt") + clientKeyWithoutCNPath := path.Join(tempDir, "client-without-cn.key") + err = os.WriteFile(clientCertWithoutCNPath, clientCertWithoutCN, 0o600) + So(err, ShouldBeNil) + err = os.WriteFile(clientKeyWithoutCNPath, clientKeyWithoutCNPEM, 0o600) + So(err, ShouldBeNil) + + // Create certificate chain file (client cert without CN + intermediate CA) + chainCertWithoutCNPath := path.Join(tempDir, "client-without-cn-chain.crt") + err = tlsutils.WriteCertificateChainToFile(chainCertWithoutCNPath, clientCertWithoutCN, intermediateCACert) + So(err, ShouldBeNil) + + // Load certificate chain + clientCertChainWithoutCN, err := tls.LoadX509KeyPair(chainCertWithoutCNPath, clientKeyWithoutCNPath) + So(err, ShouldBeNil) + + client = resty.New() + client.SetTLSClientConfig(&tls.Config{ + MinVersion: tls.VersionTLS13, + Certificates: []tls.Certificate{clientCertChainWithoutCN}, + RootCAs: caCertPool, + }) + + // Should fail because client cert has no CN, even though intermediate CA has CN + resp, err = client.R().Get(baseURL + "/v2/client-repo/tags/list") + So(err, ShouldBeNil) + So(resp.StatusCode(), ShouldEqual, http.StatusUnauthorized) + }) +} + +func TestMTLSAuthenticationWithExpiredCertificate(t *testing.T) { + // Create temporary directory for certificates + tempDir := t.TempDir() + + Convey("Test mTLS authentication with expired certificate", t, func() { + // Generate CA certificate + caCert, caKey, err := tlsutils.GenerateCACert() + So(err, ShouldBeNil) + caCertPath := path.Join(tempDir, "ca.crt") + err = os.WriteFile(caCertPath, caCert, 0o600) + So(err, ShouldBeNil) + + // Generate server certificate + serverCertPath := path.Join(tempDir, "server.crt") + serverKeyPath := path.Join(tempDir, "server.key") + opts := &tlsutils.CertificateOptions{ + Hostname: "localhost", + } + err = tlsutils.GenerateServerCertToFile(caCert, caKey, serverCertPath, serverKeyPath, opts) + So(err, ShouldBeNil) + + // Generate expired client certificate (NotAfter is in the past) + expiredClientCertPath := path.Join(tempDir, "client-expired.crt") + expiredClientKeyPath := path.Join(tempDir, "client-expired.key") + expiredOpts := &tlsutils.CertificateOptions{ + CommonName: "testuser", + NotBefore: time.Now().Add(-365 * 24 * time.Hour), // 1 year ago + NotAfter: time.Now().Add(-24 * time.Hour), // 1 day ago (expired) + } + err = tlsutils.GenerateClientCertToFile(caCert, caKey, expiredClientCertPath, expiredClientKeyPath, expiredOpts) + So(err, ShouldBeNil) + + // Set up server + conf := config.New() + port := test.GetFreePort() + baseURL := test.GetSecureBaseURL(port) + + conf.HTTP.Port = port + conf.HTTP.TLS = &config.TLSConfig{ + Cert: serverCertPath, + Key: serverKeyPath, + CACert: caCertPath, + } + conf.HTTP.AccessControl = &config.AccessControlConfig{ + Repositories: config.Repositories{ + "**": config.PolicyGroup{ + AnonymousPolicy: make([]string, 0), + Policies: make([]config.Policy, 0), + }, + "test-repo": config.PolicyGroup{ + Policies: []config.Policy{ + { + Users: []string{"testuser"}, + Actions: []string{"read", "create"}, + }, + }, + }, + }, + } + conf.Storage.RootDirectory = t.TempDir() + + ctlr := api.NewController(conf) + cm := test.NewControllerManager(ctlr) + + cm.StartAndWait(port) + defer cm.StopServer() + + // Set up client with expired certificate + caCertPEM, err := os.ReadFile(caCertPath) + So(err, ShouldBeNil) + + caCertPool := x509.NewCertPool() + caCertPool.AppendCertsFromPEM(caCertPEM) + + expiredClientCert, err := tls.LoadX509KeyPair(expiredClientCertPath, expiredClientKeyPath) + So(err, ShouldBeNil) + + client := resty.New() + client.SetTLSClientConfig(&tls.Config{ + MinVersion: tls.VersionTLS13, + Certificates: []tls.Certificate{expiredClientCert}, + RootCAs: caCertPool, + }) + + // Expired certificate should be rejected at TLS handshake level + // The TLS stack will reject it before it reaches the application layer + _, err = client.R().Get(baseURL + "/v2/test-repo/tags/list") + // Error is expected - TLS handshake fails with expired certificate + So(err, ShouldNotBeNil) + So(err.Error(), ShouldContainSubstring, "expired certificate") + }) +} + +func TestMTLSAuthenticationWithUnknownCA(t *testing.T) { + // Create temporary directory for certificates + tempDir := t.TempDir() + + Convey("Test mTLS authentication with certificate signed by unknown CA", t, func() { + // Generate server CA and certificate + serverCACert, serverCAKey, err := tlsutils.GenerateCACert() + So(err, ShouldBeNil) + serverCACertPath := path.Join(tempDir, "server-ca.crt") + err = os.WriteFile(serverCACertPath, serverCACert, 0o600) + So(err, ShouldBeNil) + + serverCertPath := path.Join(tempDir, "server.crt") + serverKeyPath := path.Join(tempDir, "server.key") + opts := &tlsutils.CertificateOptions{ + Hostname: "localhost", + } + err = tlsutils.GenerateServerCertToFile(serverCACert, serverCAKey, serverCertPath, serverKeyPath, opts) + So(err, ShouldBeNil) + + // Generate a different CA (unknown to the server) and client certificate + unknownCACert, unknownCAKey, err := tlsutils.GenerateCACert() + So(err, ShouldBeNil) + + unknownClientCertPath := path.Join(tempDir, "client-unknown-ca.crt") + unknownClientKeyPath := path.Join(tempDir, "client-unknown-ca.key") + clientOpts := &tlsutils.CertificateOptions{ + CommonName: "testuser", + } + err = tlsutils.GenerateClientCertToFile(unknownCACert, unknownCAKey, unknownClientCertPath, + unknownClientKeyPath, clientOpts) + So(err, ShouldBeNil) + + // Set up server with server CA (doesn't know about unknown CA) + conf := config.New() + port := test.GetFreePort() + baseURL := test.GetSecureBaseURL(port) + + conf.HTTP.Port = port + conf.HTTP.TLS = &config.TLSConfig{ + Cert: serverCertPath, + Key: serverKeyPath, + CACert: serverCACertPath, // Server only trusts serverCACert, not unknownCACert + } + conf.HTTP.AccessControl = &config.AccessControlConfig{ + Repositories: config.Repositories{ + "**": config.PolicyGroup{ + AnonymousPolicy: make([]string, 0), + Policies: make([]config.Policy, 0), + }, + "test-repo": config.PolicyGroup{ + Policies: []config.Policy{ + { + Users: []string{"testuser"}, + Actions: []string{"read", "create"}, + }, + }, + }, + }, + } + conf.Storage.RootDirectory = t.TempDir() + + ctlr := api.NewController(conf) + cm := test.NewControllerManager(ctlr) + + cm.StartAndWait(port) + defer cm.StopServer() + + // Set up client with certificate signed by unknown CA + serverCACertPEM, err := os.ReadFile(serverCACertPath) + So(err, ShouldBeNil) + + caCertPool := x509.NewCertPool() + caCertPool.AppendCertsFromPEM(serverCACertPEM) + + unknownClientCert, err := tls.LoadX509KeyPair(unknownClientCertPath, unknownClientKeyPath) + So(err, ShouldBeNil) + + client := resty.New() + client.SetTLSClientConfig(&tls.Config{ + MinVersion: tls.VersionTLS13, + Certificates: []tls.Certificate{unknownClientCert}, + RootCAs: caCertPool, + }) + + // Certificate signed by unknown CA should be rejected at TLS handshake level + // The TLS stack will reject it before it reaches the application layer + _, err = client.R().Get(baseURL + "/v2/test-repo/tags/list") + // Error is expected - TLS handshake fails with unknown certificate authority + So(err, ShouldNotBeNil) + So(err.Error(), ShouldContainSubstring, "unknown certificate authority") + }) +} + +func TestMultipleAuthorizationHeaders(t *testing.T) { + Convey("Test rejection of multiple Authorization headers", t, func() { + Convey("Test multiple and single Authorization headers in basic auth handler", func() { + conf := config.New() + port := test.GetFreePort() + baseURL := test.GetBaseURL(port) + + username, _ := test.GenerateRandomString() + password, _ := test.GenerateRandomString() + + htpasswdPath := test.MakeHtpasswdFileFromString(t, test.GetBcryptCredString(username, password)) + + conf.HTTP.Port = port + conf.HTTP.Auth = &config.AuthConfig{ + HTPasswd: config.AuthHTPasswd{ + Path: htpasswdPath, + }, + } + conf.Storage.RootDirectory = t.TempDir() + + ctlr := api.NewController(conf) + cm := test.NewControllerManager(ctlr) + + cm.StartAndWait(port) + defer cm.StopServer() + + Convey("Multiple Authorization headers should be rejected - basic first", func() { + // Create a request with multiple Authorization headers + req, err := http.NewRequestWithContext(context.Background(), http.MethodGet, baseURL+"/v2/_catalog", nil) + So(err, ShouldBeNil) + + // Add multiple Authorization headers + basicAuth := base64.StdEncoding.EncodeToString([]byte(username + ":" + password)) + req.Header.Add("Authorization", "Basic "+basicAuth) + req.Header.Add("Authorization", "Bearer token123") + + client := &http.Client{} + resp, err := client.Do(req) + So(err, ShouldBeNil) + defer resp.Body.Close() + + // Should be rejected with 401 Unauthorized + So(resp.StatusCode, ShouldEqual, http.StatusUnauthorized) + }) + + Convey("Multiple Authorization headers should be rejected - bearer first", func() { + // Create a request with multiple Authorization headers + req, err := http.NewRequestWithContext(context.Background(), http.MethodGet, baseURL+"/v2/_catalog", nil) + So(err, ShouldBeNil) + + // Add multiple Authorization headers + req.Header.Add("Authorization", "Bearer token123") + basicAuth := base64.StdEncoding.EncodeToString([]byte(username + ":" + password)) + req.Header.Add("Authorization", "Basic "+basicAuth) + + client := &http.Client{} + resp, err := client.Do(req) + So(err, ShouldBeNil) + defer resp.Body.Close() + + // Should be rejected with 401 Unauthorized + So(resp.StatusCode, ShouldEqual, http.StatusUnauthorized) + }) + + Convey("Multiple Authorization headers should be rejected - basic twice", func() { + // Create a request with multiple Authorization headers + req, err := http.NewRequestWithContext(context.Background(), http.MethodGet, baseURL+"/v2/_catalog", nil) + So(err, ShouldBeNil) + + // Add multiple Authorization headers with correct values + basicAuth := base64.StdEncoding.EncodeToString([]byte(username + ":" + password)) + req.Header.Add("Authorization", "Basic "+basicAuth) + req.Header.Add("Authorization", "Basic "+basicAuth) + + client := &http.Client{} + resp, err := client.Do(req) + So(err, ShouldBeNil) + defer resp.Body.Close() + + // Should be rejected with 401 Unauthorized + So(resp.StatusCode, ShouldEqual, http.StatusUnauthorized) + }) + + Convey("Single Authorization header should work", func() { + // Create a request with single Authorization header + req, err := http.NewRequestWithContext(context.Background(), http.MethodGet, baseURL+"/v2/_catalog", nil) + So(err, ShouldBeNil) + + // Add single Authorization header + basicAuth := base64.StdEncoding.EncodeToString([]byte(username + ":" + password)) + req.Header.Add("Authorization", "Basic "+basicAuth) + + client := &http.Client{} + resp, err := client.Do(req) + So(err, ShouldBeNil) + defer resp.Body.Close() + + // Should succeed + So(resp.StatusCode, ShouldEqual, http.StatusOK) + }) + }) + + Convey("Test multiple Authorization headers in bearer auth handler", func() { + tempDir := t.TempDir() + + // Generate CA certificate + caCert, caKey, err := tlsutils.GenerateCACert() + So(err, ShouldBeNil) + + // Generate server certificate for bearer auth + serverCertPath := path.Join(tempDir, "server.cert") + serverKeyPath := path.Join(tempDir, "server.key") + opts := &tlsutils.CertificateOptions{ + Hostname: "localhost", + } + err = tlsutils.GenerateServerCertToFile(caCert, caKey, serverCertPath, serverKeyPath, opts) + So(err, ShouldBeNil) + + conf := config.New() + port := test.GetFreePort() + baseURL := test.GetBaseURL(port) + + conf.HTTP.Port = port + conf.HTTP.Auth = &config.AuthConfig{ + Bearer: &config.BearerConfig{ + Cert: serverCertPath, + Realm: "test-realm", + Service: "test-service", + }, + } + conf.Storage.RootDirectory = t.TempDir() + + ctlr := api.NewController(conf) + cm := test.NewControllerManager(ctlr) + + cm.StartAndWait(port) + defer cm.StopServer() + + // Load the private key to sign the token + keyBytes, err := os.ReadFile(serverKeyPath) + So(err, ShouldBeNil) + + keyBlock, _ := pem.Decode(keyBytes) + So(keyBlock, ShouldNotBeNil) + + privateKey, err := x509.ParsePKCS1PrivateKey(keyBlock.Bytes) + So(err, ShouldBeNil) + + // Create a valid JWT token with proper claims + // For /v2/_catalog, the requestedAccess will have Name="" (no repository name in URL) + // So we need to provide access to repository with empty name or use wildcard + claims := &api.ClaimsWithAccess{ + RegisteredClaims: jwt.RegisteredClaims{ + IssuedAt: jwt.NewNumericDate(time.Now()), + ExpiresAt: jwt.NewNumericDate(time.Now().Add(time.Hour)), + }, + Access: []api.ResourceAccess{ + { + Type: "repository", + Name: "", // Empty name matches /v2/_catalog + Actions: []string{"pull"}, + }, + }, + } + + token := jwt.NewWithClaims(jwt.SigningMethodRS256, claims) + validTokenString, err := token.SignedString(privateKey) + So(err, ShouldBeNil) + + Convey("Multiple Authorization headers should be rejected - bearer and basic - bearer first", func() { + // Create a request with multiple Authorization headers + req, err := http.NewRequestWithContext(context.Background(), http.MethodGet, baseURL+"/v2/_catalog", nil) + So(err, ShouldBeNil) + + // Add multiple Authorization headers + req.Header.Add("Authorization", "Bearer "+validTokenString) + req.Header.Add("Authorization", "Basic dXNlcjpwYXNz") + + client := &http.Client{} + resp, err := client.Do(req) + So(err, ShouldBeNil) + defer resp.Body.Close() + + // Should be rejected with 401 Unauthorized + So(resp.StatusCode, ShouldEqual, http.StatusUnauthorized) + }) + + Convey("Multiple Authorization headers should be rejected - bearer and basic - basic first", func() { + // Create a request with multiple Authorization headers + req, err := http.NewRequestWithContext(context.Background(), http.MethodGet, baseURL+"/v2/_catalog", nil) + So(err, ShouldBeNil) + + // Add multiple Authorization headers + req.Header.Add("Authorization", "Basic dXNlcjpwYXNz") + req.Header.Add("Authorization", "Bearer "+validTokenString) + + client := &http.Client{} + resp, err := client.Do(req) + So(err, ShouldBeNil) + defer resp.Body.Close() + + // Should be rejected with 401 Unauthorized + So(resp.StatusCode, ShouldEqual, http.StatusUnauthorized) + }) + + Convey("Multiple Authorization headers should be rejected - two bearer headers", func() { + // Create a request with multiple Authorization headers + req, err := http.NewRequestWithContext(context.Background(), http.MethodGet, baseURL+"/v2/_catalog", nil) + So(err, ShouldBeNil) + + // Add multiple bearer Authorization headers + req.Header.Add("Authorization", "Bearer "+validTokenString) + req.Header.Add("Authorization", "Bearer "+validTokenString) + + client := &http.Client{} + resp, err := client.Do(req) + So(err, ShouldBeNil) + defer resp.Body.Close() + + // Should be rejected with 401 Unauthorized + So(resp.StatusCode, ShouldEqual, http.StatusUnauthorized) + }) + + Convey("Single Authorization header should work - invalid token", func() { + // Create a request with single Authorization header + req, err := http.NewRequestWithContext(context.Background(), http.MethodGet, baseURL+"/v2/_catalog", nil) + So(err, ShouldBeNil) + + // Add single bearer Authorization header with invalid token + req.Header.Add("Authorization", "Bearer token123") + + client := &http.Client{} + resp, err := client.Do(req) + So(err, ShouldBeNil) + defer resp.Body.Close() + + // The token is invalid, so we expect 401, but not due to multiple headers + So(resp.StatusCode, ShouldEqual, http.StatusUnauthorized) + }) + + Convey("Single Authorization header should work - correct bearer token", func() { + // Create a request with single Authorization header + req, err := http.NewRequestWithContext(context.Background(), http.MethodGet, baseURL+"/v2/_catalog", nil) + So(err, ShouldBeNil) + + // Add single bearer Authorization header with valid token + req.Header.Add("Authorization", "Bearer "+validTokenString) + + client := &http.Client{} + resp, err := client.Do(req) + So(err, ShouldBeNil) + defer resp.Body.Close() + + // Should succeed with valid token + So(resp.StatusCode, ShouldEqual, http.StatusOK) + }) + }) + }) +} + +func TestMTLSAuthenticationWithMetaDBError(t *testing.T) { + // Create temporary directory for certificates + tempDir := t.TempDir() + + Convey("Test mTLS authentication with MetaDB.SetUserGroups error", t, func() { + // Generate CA certificate + caCert, caKey, err := tlsutils.GenerateCACert() + So(err, ShouldBeNil) + caCertPath := path.Join(tempDir, "ca.crt") + err = os.WriteFile(caCertPath, caCert, 0o600) + So(err, ShouldBeNil) + + // Generate server certificate + serverCertPath := path.Join(tempDir, "server.crt") + serverKeyPath := path.Join(tempDir, "server.key") + opts := &tlsutils.CertificateOptions{ + Hostname: "localhost", + } + err = tlsutils.GenerateServerCertToFile(caCert, caKey, serverCertPath, serverKeyPath, opts) + So(err, ShouldBeNil) + + // Generate valid client certificate for "testuser" user + clientCertPath := path.Join(tempDir, "client.crt") + clientKeyPath := path.Join(tempDir, "client.key") + clientOpts := &tlsutils.CertificateOptions{ + CommonName: "testuser", + } + err = tlsutils.GenerateClientCertToFile(caCert, caKey, clientCertPath, clientKeyPath, clientOpts) + So(err, ShouldBeNil) + + // Set up server + conf := config.New() + port := test.GetFreePort() + baseURL := test.GetSecureBaseURL(port) + + conf.HTTP.Port = port + conf.HTTP.TLS = &config.TLSConfig{ + Cert: serverCertPath, + Key: serverKeyPath, + CACert: caCertPath, + } + conf.HTTP.AccessControl = &config.AccessControlConfig{ + Groups: config.Groups{ + "mtls-users": config.Group{ + Users: []string{"testuser"}, + }, + }, + Repositories: config.Repositories{ + "**": config.PolicyGroup{ + AnonymousPolicy: make([]string, 0), + Policies: make([]config.Policy, 0), + }, + "test-repo": config.PolicyGroup{ + Policies: []config.Policy{ + { + Users: []string{"testuser"}, + Actions: []string{"read", "create"}, + }, + }, + }, + }, + } + conf.Storage.RootDirectory = t.TempDir() + + ctlr := api.NewController(conf) + cm := test.NewControllerManager(ctlr) + + cm.StartAndWait(port) + defer cm.StopServer() + + // Set up client with valid certificate + caCertPEM, err := os.ReadFile(caCertPath) + So(err, ShouldBeNil) + + caCertPool := x509.NewCertPool() + caCertPool.AppendCertsFromPEM(caCertPEM) + + clientCert, err := tls.LoadX509KeyPair(clientCertPath, clientKeyPath) + So(err, ShouldBeNil) + + client := resty.New() + client.SetTLSClientConfig(&tls.Config{ + MinVersion: tls.VersionTLS13, + Certificates: []tls.Certificate{clientCert}, + RootCAs: caCertPool, + }) + + // Mock MetaDB to return error on SetUserGroups + ctlr.MetaDB = mocks.MetaDBMock{ + SetUserGroupsFn: func(ctx context.Context, groups []string) error { + return ErrUnexpectedError + }, + } + + // Should return 500 Internal Server Error due to MetaDB error + resp, err := client.R().Get(baseURL + "/v2/test-repo/tags/list") + So(err, ShouldBeNil) + So(resp.StatusCode(), ShouldEqual, http.StatusInternalServerError) + }) +} + func TestAPIKeysOpenDBError(t *testing.T) { Convey("Test API keys - unable to create database", t, func() { conf := config.New() diff --git a/pkg/api/config/config.go b/pkg/api/config/config.go index 12f305fd..2b9555e9 100644 --- a/pkg/api/config/config.go +++ b/pkg/api/config/config.go @@ -524,57 +524,6 @@ func (c *Config) isTagsRetentionEnabled(tagRetentionPolicy KeepTagsPolicy) bool return false } -// isBasicAuthnEnabled checks if any basic authentication method is enabled (internal, no locking). -func (c *Config) isBasicAuthnEnabled() bool { - if c == nil { - return false - } - - // Check HTPasswd - if c.HTTP.Auth != nil && c.HTTP.Auth.HTPasswd.Path != "" { - return true - } - - // Check LDAP - if c.HTTP.Auth != nil && c.HTTP.Auth.LDAP != nil { - return true - } - - // Check API Key - if c.HTTP.Auth != nil && c.HTTP.Auth.APIKey { - return true - } - - // Check OpenID - if c.HTTP.Auth != nil && c.HTTP.Auth.OpenID != nil { - for provider := range c.HTTP.Auth.OpenID.Providers { - if isOpenIDAuthProviderEnabled(c, provider) { - return true - } - } - } - - return false -} - -// isOpenIDAuthProviderEnabled checks if a specific OpenID provider is enabled (internal use only). -func isOpenIDAuthProviderEnabled(config *Config, provider string) bool { - if providerConfig, ok := config.HTTP.Auth.OpenID.Providers[provider]; ok { - if IsOpenIDSupported(provider) { - if providerConfig.ClientID != "" || providerConfig.Issuer != "" || - len(providerConfig.Scopes) > 0 { - return true - } - } else if IsOauth2Supported(provider) { - if providerConfig.ClientID != "" || len(providerConfig.Scopes) > 0 { - return true - } - } - } - - return false -} - // Sanitize makes a sanitized copy of the config removing any secrets. func (c *Config) Sanitize() *Config { if c == nil { @@ -999,12 +948,11 @@ func (c *Config) IsMTLSAuthEnabled() bool { c.mu.RLock() defer c.mu.RUnlock() + // mTLS is enabled if TLS is configured with client CA certificates if c.HTTP.TLS != nil && c.HTTP.TLS.Key != "" && c.HTTP.TLS.Cert != "" && - c.HTTP.TLS.CACert != "" && - !c.isBasicAuthnEnabled() && - !c.HTTP.AccessControl.AnonymousPolicyExists() { + c.HTTP.TLS.CACert != "" { return true } diff --git a/pkg/api/config/config_test.go b/pkg/api/config/config_test.go index ee5e912a..f33e7771 100644 --- a/pkg/api/config/config_test.go +++ b/pkg/api/config/config_test.go @@ -1760,237 +1760,6 @@ func TestConfig(t *testing.T) { }, } So(cfg.IsMTLSAuthEnabled(), ShouldBeTrue) - - // Test with HTPasswd enabled (should disable mTLS) - cfg = &config.Config{ - HTTP: config.HTTPConfig{ - Auth: &config.AuthConfig{ - HTPasswd: config.AuthHTPasswd{ - Path: "/path/to/htpasswd", - }, - }, - TLS: &config.TLSConfig{ - Cert: "/path/to/cert.pem", - Key: "/path/to/key.pem", - CACert: "/path/to/ca-cert.pem", - }, - }, - } - So(cfg.IsMTLSAuthEnabled(), ShouldBeFalse) // Basic auth enabled, so mTLS disabled - - // Test with LDAP enabled (should disable mTLS) - cfg = &config.Config{ - HTTP: config.HTTPConfig{ - Auth: &config.AuthConfig{ - LDAP: &config.LDAPConfig{}, - }, - TLS: &config.TLSConfig{ - Cert: "/path/to/cert.pem", - Key: "/path/to/key.pem", - CACert: "/path/to/ca-cert.pem", - }, - }, - } - So(cfg.IsMTLSAuthEnabled(), ShouldBeFalse) // Basic auth enabled, so mTLS disabled - - // Test with API Key enabled (should disable mTLS) - cfg = &config.Config{ - HTTP: config.HTTPConfig{ - Auth: &config.AuthConfig{ - APIKey: true, - }, - TLS: &config.TLSConfig{ - Cert: "/path/to/cert.pem", - Key: "/path/to/key.pem", - CACert: "/path/to/ca-cert.pem", - }, - }, - } - So(cfg.IsMTLSAuthEnabled(), ShouldBeFalse) // Basic auth enabled, so mTLS disabled - - // Test with OpenID enabled (valid config - should disable mTLS) - cfg = &config.Config{ - HTTP: config.HTTPConfig{ - Auth: &config.AuthConfig{ - OpenID: &config.OpenIDConfig{ - Providers: map[string]config.OpenIDProviderConfig{ - "google": { - ClientID: "client-id", - Issuer: "", - Scopes: []string{}, - }, - }, - }, - }, - TLS: &config.TLSConfig{ - Cert: "/path/to/cert.pem", - Key: "/path/to/key.pem", - CACert: "/path/to/ca-cert.pem", - }, - }, - } - So(cfg.IsMTLSAuthEnabled(), ShouldBeFalse) // Basic auth enabled, so mTLS disabled - - // Test with OpenID enabled (with Issuer - should disable mTLS) - cfg = &config.Config{ - HTTP: config.HTTPConfig{ - Auth: &config.AuthConfig{ - OpenID: &config.OpenIDConfig{ - Providers: map[string]config.OpenIDProviderConfig{ - "google": { - ClientID: "", - Issuer: "https://accounts.google.com", - Scopes: []string{}, - }, - }, - }, - }, - TLS: &config.TLSConfig{ - Cert: "/path/to/cert.pem", - Key: "/path/to/key.pem", - CACert: "/path/to/ca-cert.pem", - }, - }, - } - So(cfg.IsMTLSAuthEnabled(), ShouldBeFalse) // Basic auth enabled, so mTLS disabled - - // Test with OpenID enabled (with Scopes - should disable mTLS) - cfg = &config.Config{ - HTTP: config.HTTPConfig{ - Auth: &config.AuthConfig{ - OpenID: &config.OpenIDConfig{ - Providers: map[string]config.OpenIDProviderConfig{ - "google": { - ClientID: "", - Issuer: "", - Scopes: []string{"openid", "email"}, - }, - }, - }, - }, - TLS: &config.TLSConfig{ - Cert: "/path/to/cert.pem", - Key: "/path/to/key.pem", - CACert: "/path/to/ca-cert.pem", - }, - }, - } - So(cfg.IsMTLSAuthEnabled(), ShouldBeFalse) // Basic auth enabled, so mTLS disabled - - // Test with OAuth2 provider (github) with ClientID (should disable mTLS) - cfg = &config.Config{ - HTTP: config.HTTPConfig{ - Auth: &config.AuthConfig{ - OpenID: &config.OpenIDConfig{ - Providers: map[string]config.OpenIDProviderConfig{ - "github": { - ClientID: "github-client-id", - Scopes: []string{}, - }, - }, - }, - }, - TLS: &config.TLSConfig{ - Cert: "/path/to/cert.pem", - Key: "/path/to/key.pem", - CACert: "/path/to/ca-cert.pem", - }, - }, - } - So(cfg.IsMTLSAuthEnabled(), ShouldBeFalse) // Basic auth enabled, so mTLS disabled - - // Test with OAuth2 provider (github) with Scopes (should disable mTLS) - cfg = &config.Config{ - HTTP: config.HTTPConfig{ - Auth: &config.AuthConfig{ - OpenID: &config.OpenIDConfig{ - Providers: map[string]config.OpenIDProviderConfig{ - "github": { - ClientID: "", - Scopes: []string{"user:email"}, - }, - }, - }, - }, - TLS: &config.TLSConfig{ - Cert: "/path/to/cert.pem", - Key: "/path/to/key.pem", - CACert: "/path/to/ca-cert.pem", - }, - }, - } - So(cfg.IsMTLSAuthEnabled(), ShouldBeFalse) // Basic auth enabled, so mTLS disabled - - // Test with OpenID but empty config (should enable mTLS) - cfg = &config.Config{ - HTTP: config.HTTPConfig{ - Auth: &config.AuthConfig{ - OpenID: &config.OpenIDConfig{ - Providers: map[string]config.OpenIDProviderConfig{ - "google": { - ClientID: "", - Issuer: "", - Scopes: []string{}, - }, - }, - }, - }, - TLS: &config.TLSConfig{ - Cert: "/path/to/cert.pem", - Key: "/path/to/key.pem", - CACert: "/path/to/ca-cert.pem", - }, - }, - } - So(cfg.IsMTLSAuthEnabled(), ShouldBeTrue) // No basic auth, so mTLS enabled - - // Test with OpenID but unsupported provider (should enable mTLS) - cfg = &config.Config{ - HTTP: config.HTTPConfig{ - Auth: &config.AuthConfig{ - OpenID: &config.OpenIDConfig{ - Providers: map[string]config.OpenIDProviderConfig{ - "unsupported": { - ClientID: "client-id", - Scopes: []string{"scope"}, - }, - }, - }, - }, - TLS: &config.TLSConfig{ - Cert: "/path/to/cert.pem", - Key: "/path/to/key.pem", - CACert: "/path/to/ca-cert.pem", - }, - }, - } - So(cfg.IsMTLSAuthEnabled(), ShouldBeTrue) // No basic auth, so mTLS enabled - - // Test with no authentication methods (should enable mTLS) - cfg = &config.Config{ - HTTP: config.HTTPConfig{ - Auth: &config.AuthConfig{}, - TLS: &config.TLSConfig{ - Cert: "/path/to/cert.pem", - Key: "/path/to/key.pem", - CACert: "/path/to/ca-cert.pem", - }, - }, - } - So(cfg.IsMTLSAuthEnabled(), ShouldBeTrue) // No basic auth, so mTLS enabled - - // Test with nil Auth (should enable mTLS) - cfg = &config.Config{ - HTTP: config.HTTPConfig{ - Auth: nil, - TLS: &config.TLSConfig{ - Cert: "/path/to/cert.pem", - Key: "/path/to/key.pem", - CACert: "/path/to/ca-cert.pem", - }, - }, - } - So(cfg.IsMTLSAuthEnabled(), ShouldBeTrue) // No basic auth, so mTLS enabled }) Convey("Test UseSecureSession()", func() { @@ -2341,88 +2110,6 @@ func TestConfig(t *testing.T) { }) }) - Convey("Test isOpenIDAuthProviderEnabled indirectly via IsMTLSAuthEnabled", t, func() { - Convey("Test with OpenID provider with empty config", func() { - cfg := &config.Config{ - HTTP: config.HTTPConfig{ - TLS: &config.TLSConfig{ - Key: "key", - Cert: "cert", - CACert: "cacert", - }, - Auth: &config.AuthConfig{ - OpenID: &config.OpenIDConfig{ - Providers: map[string]config.OpenIDProviderConfig{ - "google": { - ClientID: "", - Issuer: "", - Scopes: []string{}, - }, - }, - }, - }, - AccessControl: &config.AccessControlConfig{}, - }, - } - // This should return true because isOpenIDAuthProviderEnabled returns false for empty config, - // so isBasicAuthnEnabled returns false, making IsMTLSAuthEnabled return true - So(cfg.IsMTLSAuthEnabled(), ShouldBeTrue) - }) - - Convey("Test with OpenID provider with valid config", func() { - cfg := &config.Config{ - HTTP: config.HTTPConfig{ - TLS: &config.TLSConfig{ - Key: "key", - Cert: "cert", - CACert: "cacert", - }, - Auth: &config.AuthConfig{ - OpenID: &config.OpenIDConfig{ - Providers: map[string]config.OpenIDProviderConfig{ - "google": { - ClientID: "client-id", - Issuer: "", - Scopes: []string{}, - }, - }, - }, - }, - AccessControl: &config.AccessControlConfig{}, - }, - } - // This should return false because isOpenIDAuthProviderEnabled returns true for valid config, - // so isBasicAuthnEnabled returns true, making IsMTLSAuthEnabled return false - So(cfg.IsMTLSAuthEnabled(), ShouldBeFalse) - }) - - Convey("Test with unsupported OpenID provider", func() { - cfg := &config.Config{ - HTTP: config.HTTPConfig{ - TLS: &config.TLSConfig{ - Key: "key", - Cert: "cert", - CACert: "cacert", - }, - Auth: &config.AuthConfig{ - OpenID: &config.OpenIDConfig{ - Providers: map[string]config.OpenIDProviderConfig{ - "unsupported": { - ClientID: "client-id", - Scopes: []string{"scope"}, - }, - }, - }, - }, - AccessControl: &config.AccessControlConfig{}, - }, - } - // This should return true because isOpenIDAuthProviderEnabled returns false for unsupported provider, - // so isBasicAuthnEnabled returns false, making IsMTLSAuthEnabled return true - So(cfg.IsMTLSAuthEnabled(), ShouldBeTrue) - }) - }) - Convey("Test nil receiver coverage for all methods", t, func() { Convey("Test AuthConfig methods with nil receiver", func() { var authConfig *config.AuthConfig = nil diff --git a/pkg/api/controller.go b/pkg/api/controller.go index 7580540b..3efff12a 100644 --- a/pkg/api/controller.go +++ b/pkg/api/controller.go @@ -240,11 +240,6 @@ func (c *Controller) Run() error { } if tlsConfig.CACert != "" { - clientAuth := tls.VerifyClientCertIfGiven - if c.Config.IsMTLSAuthEnabled() { - clientAuth = tls.RequireAndVerifyClientCert - } - caCert, err := os.ReadFile(tlsConfig.CACert) if err != nil { c.Log.Error().Err(err).Str("caCert", tlsConfig.CACert).Msg("failed to read file") @@ -260,7 +255,9 @@ func (c *Controller) Run() error { return errors.ErrBadCACert } - server.TLSConfig.ClientAuth = clientAuth + // Use VerifyClientCertIfGiven even if mTLS is enabled: clients without cert will be treated as anonymous + // You can control permissions for mTLS anonymous requests via accessControl policies + server.TLSConfig.ClientAuth = tls.VerifyClientCertIfGiven server.TLSConfig.ClientCAs = caCertPool } diff --git a/pkg/api/controller_test.go b/pkg/api/controller_test.go index 7c9e4212..7d233caf 100644 --- a/pkg/api/controller_test.go +++ b/pkg/api/controller_test.go @@ -2255,41 +2255,44 @@ func TestMutualTLSAuthWithUserPermissions(t *testing.T) { cert, err := tls.LoadX509KeyPair("../../test/data/client.cert", "../../test/data/client.key") So(err, ShouldBeNil) - resty.SetCertificates(cert) - - defer func() { resty.SetCertificates(tls.Certificate{}) }() + // Use separate resty client with certificates, because we cannot perform cleanup with resty.SetCertificates() + client := resty.New().SetTLSClientConfig(&tls.Config{ + RootCAs: caCertPool, + MinVersion: tls.VersionTLS12, + Certificates: []tls.Certificate{cert}, + }) // with client certs but without creds, should succeed - resp, err = resty.R().Get(secureBaseURL + "/v2/") + resp, err = client.R().Get(secureBaseURL + "/v2/") So(err, ShouldBeNil) So(resp, ShouldNotBeNil) So(resp.StatusCode(), ShouldEqual, http.StatusOK) - resp, err = resty.R().Get(secureBaseURL + "/v2/_catalog") + resp, err = client.R().Get(secureBaseURL + "/v2/_catalog") So(err, ShouldBeNil) So(resp, ShouldNotBeNil) So(resp.StatusCode(), ShouldEqual, http.StatusOK) // with creds, should get expected status code - resp, _ = resty.R().Get(secureBaseURL) + resp, _ = client.R().Get(secureBaseURL) So(resp, ShouldNotBeNil) So(resp.StatusCode(), ShouldEqual, http.StatusNotFound) // reading a repo should not get 403 - resp, err = resty.R().Get(secureBaseURL + "/v2/repo/tags/list") + resp, err = client.R().Get(secureBaseURL + "/v2/repo/tags/list") So(err, ShouldBeNil) So(resp, ShouldNotBeNil) So(resp.StatusCode(), ShouldEqual, http.StatusNotFound) // without creds, writes should fail - resp, err = resty.R().Post(secureBaseURL + "/v2/repo/blobs/uploads/") + resp, err = client.R().Post(secureBaseURL + "/v2/repo/blobs/uploads/") So(err, ShouldBeNil) So(resp.StatusCode(), ShouldEqual, http.StatusForbidden) // empty default authorization and give user the permission to create repoPolicy.Policies[0].Actions = append(repoPolicy.Policies[0].Actions, "create") conf.HTTP.AccessControl.Repositories[test.AuthorizationAllRepos] = repoPolicy - resp, err = resty.R().Post(secureBaseURL + "/v2/repo/blobs/uploads/") + resp, err = client.R().Post(secureBaseURL + "/v2/repo/blobs/uploads/") So(err, ShouldBeNil) So(resp.StatusCode(), ShouldEqual, http.StatusAccepted) }) @@ -2531,12 +2534,15 @@ func TestMutualTLSAuthWithoutCN(t *testing.T) { cert, err := tls.LoadX509KeyPair("../../test/data/noidentity/client.cert", "../../test/data/noidentity/client.key") So(err, ShouldBeNil) - resty.SetCertificates(cert) - - defer func() { resty.SetCertificates(tls.Certificate{}) }() + // Use separate resty client with certificates, because we cannot perform cleanup with resty.SetCertificates() + client := resty.New().SetTLSClientConfig(&tls.Config{ + RootCAs: caCertPool, + MinVersion: tls.VersionTLS12, + Certificates: []tls.Certificate{cert}, + }) // with client certs but without TLS mutual auth setup should get certificate error - resp, _ := resty.R().Get(secureBaseURL + "/v2/_catalog") + resp, _ := client.R().Get(secureBaseURL + "/v2/_catalog") So(resp.StatusCode(), ShouldEqual, http.StatusUnauthorized) }) } @@ -2553,10 +2559,6 @@ func TestTLSMutualAuth(t *testing.T) { baseURL := test.GetBaseURL(port) secureBaseURL := test.GetSecureBaseURL(port) - resty.SetTLSClientConfig(&tls.Config{RootCAs: caCertPool, MinVersion: tls.VersionTLS12}) - - defer func() { resty.SetTLSClientConfig(nil) }() - conf := config.New() conf.HTTP.Port = port conf.HTTP.TLS = &config.TLSConfig{ @@ -2572,46 +2574,78 @@ func TestTLSMutualAuth(t *testing.T) { defer cm.StopServer() + // access without any certificate settings + client := resty.New() + // accessing insecure HTTP site should fail - resp, err := resty.R().Get(baseURL) + resp, err := client.R().Get(baseURL) So(err, ShouldBeNil) So(resp, ShouldNotBeNil) So(resp.StatusCode(), ShouldEqual, http.StatusBadRequest) - // without client certs and creds, should get conn error - _, err = resty.R().Get(secureBaseURL) + // without client certs and creds, should get certificate verification error + _, err = client.R().Get(secureBaseURL) So(err, ShouldNotBeNil) + // without client certs should fail auth + _, err = client.R().Get(secureBaseURL + "/v2/") + So(err, ShouldNotBeNil) + + // Use resty client with certificates, + client = resty.New().SetTLSClientConfig(&tls.Config{ + RootCAs: caCertPool, + MinVersion: tls.VersionTLS12, + }) + + // without client certs should fail auth + resp, err = client.R().Get(secureBaseURL) + So(err, ShouldBeNil) + So(resp.StatusCode(), ShouldEqual, http.StatusNotFound) + + // without client certs should fail auth + resp, _ = client.R().Get(secureBaseURL + "/v2/") + So(resp, ShouldNotBeNil) + So(resp.StatusCode(), ShouldEqual, http.StatusUnauthorized) + username, seedUser := test.GenerateRandomString() password, seedPass := test.GenerateRandomString() ctlr.Log.Info().Int64("seedUser", seedUser).Int64("seedPass", seedPass).Msg("random seed for username & password") - // with creds but without certs, should get conn error - _, err = resty.R().SetBasicAuth(username, password).Get(secureBaseURL) - So(err, ShouldNotBeNil) + + resp, err = client.R().SetBasicAuth(username, password).Get(secureBaseURL) + So(err, ShouldBeNil) + So(resp.StatusCode(), ShouldEqual, http.StatusNotFound) + + // with only creds, should get 401 because basic auth is disabled + // (Authorization header should be rejected when the auth method is disabled, regardless of mTLS) + resp, _ = client.R().SetBasicAuth(username, password).Get(secureBaseURL + "/v2/") + So(resp, ShouldNotBeNil) + So(resp.StatusCode(), ShouldEqual, http.StatusUnauthorized) // setup TLS mutual auth cert, err := tls.LoadX509KeyPair("../../test/data/client.cert", "../../test/data/client.key") So(err, ShouldBeNil) - resty.SetCertificates(cert) - - defer func() { resty.SetCertificates(tls.Certificate{}) }() + client = resty.New().SetTLSClientConfig(&tls.Config{ + RootCAs: caCertPool, + MinVersion: tls.VersionTLS12, + Certificates: []tls.Certificate{cert}, + }) // with client certs but without creds, should succeed - resp, err = resty.R().Get(secureBaseURL + "/v2/") + resp, err = client.R().Get(secureBaseURL + "/v2/") So(err, ShouldBeNil) So(resp, ShouldNotBeNil) So(resp.StatusCode(), ShouldEqual, http.StatusOK) - // with client certs and creds, should get expected status code - resp, _ = resty.R().SetBasicAuth(username, password).Get(secureBaseURL) + resp, _ = client.R().SetBasicAuth(username, password).Get(secureBaseURL) So(resp, ShouldNotBeNil) So(resp.StatusCode(), ShouldEqual, http.StatusNotFound) - // with client certs, creds shouldn't matter - resp, _ = resty.R().SetBasicAuth(username, password).Get(secureBaseURL + "/v2/") + // with client certs and creds, should get 401 because basic auth is disabled + // (Authorization header should be rejected when the auth method is disabled, regardless of mTLS) + resp, _ = client.R().SetBasicAuth(username, password).Get(secureBaseURL + "/v2/") So(resp, ShouldNotBeNil) - So(resp.StatusCode(), ShouldEqual, http.StatusOK) + So(resp.StatusCode(), ShouldEqual, http.StatusUnauthorized) }) } @@ -2718,9 +2752,11 @@ func TestTLSMutualAuthAllowReadAccess(t *testing.T) { baseURL := test.GetBaseURL(port) secureBaseURL := test.GetSecureBaseURL(port) - resty.SetTLSClientConfig(&tls.Config{RootCAs: caCertPool, MinVersion: tls.VersionTLS12}) - - defer func() { resty.SetTLSClientConfig(nil) }() + // Use resty client with certificates, + client := resty.New().SetTLSClientConfig(&tls.Config{ + RootCAs: caCertPool, + MinVersion: tls.VersionTLS12, + }) conf := config.New() conf.HTTP.Port = port @@ -2746,26 +2782,28 @@ func TestTLSMutualAuthAllowReadAccess(t *testing.T) { defer cm.StopServer() // accessing insecure HTTP site should fail - resp, err := resty.R().Get(baseURL) + resp, err := client.R().Get(baseURL) So(err, ShouldBeNil) So(resp, ShouldNotBeNil) So(resp.StatusCode(), ShouldEqual, http.StatusBadRequest) // without client certs and creds, reads are allowed - resp, err = resty.R().Get(secureBaseURL + "/v2/") + resp, err = client.R().Get(secureBaseURL + "/v2/") So(err, ShouldBeNil) So(resp.StatusCode(), ShouldEqual, http.StatusOK) username, seedUser := test.GenerateRandomString() password, seedPass := test.GenerateRandomString() + ctlr.Log.Info().Int64("seedUser", seedUser).Int64("seedPass", seedPass).Msg("random seed for username & password") - // with creds but without certs, reads are allowed - resp, err = resty.R().SetBasicAuth(username, password).Get(secureBaseURL + "/v2/") + // with creds but without certs, reads are not allowed as server does not use basic auth + // and basic auth headers are expected to contain valid credentials + resp, err = client.R().SetBasicAuth(username, password).Get(secureBaseURL + "/v2/") So(err, ShouldBeNil) - So(resp.StatusCode(), ShouldEqual, http.StatusOK) + So(resp.StatusCode(), ShouldEqual, http.StatusUnauthorized) // without creds, writes should fail - resp, err = resty.R().Post(secureBaseURL + "/v2/repo/blobs/uploads/") + resp, err = client.R().Post(secureBaseURL + "/v2/repo/blobs/uploads/") So(err, ShouldBeNil) So(resp.StatusCode(), ShouldEqual, http.StatusUnauthorized) @@ -2773,25 +2811,34 @@ func TestTLSMutualAuthAllowReadAccess(t *testing.T) { cert, err := tls.LoadX509KeyPair("../../test/data/client.cert", "../../test/data/client.key") So(err, ShouldBeNil) - resty.SetCertificates(cert) - - defer func() { resty.SetCertificates(tls.Certificate{}) }() + // Use separate resty client with certificates, because we cannot perform cleanup with resty.SetCertificates() + client = resty.New().SetTLSClientConfig(&tls.Config{ + RootCAs: caCertPool, + MinVersion: tls.VersionTLS12, + Certificates: []tls.Certificate{cert}, + }) // with client certs but without creds, should succeed - resp, err = resty.R().Get(secureBaseURL + "/v2/") + resp, _ = client.R().Get(secureBaseURL) + So(resp, ShouldNotBeNil) + So(resp.StatusCode(), ShouldEqual, http.StatusNotFound) + + // with client certs but without creds, should succeed + resp, err = client.R().Get(secureBaseURL + "/v2/") So(err, ShouldBeNil) So(resp, ShouldNotBeNil) So(resp.StatusCode(), ShouldEqual, http.StatusOK) - // with client certs and creds, should get expected status code - resp, _ = resty.R().SetBasicAuth(username, password).Get(secureBaseURL) + // with client certs and creds, reads are not allowed as server does not use basic auth + // and basic auth headers are expected to contain valid credentials + resp, _ = client.R().SetBasicAuth(username, password).Get(secureBaseURL) So(resp, ShouldNotBeNil) So(resp.StatusCode(), ShouldEqual, http.StatusNotFound) - // with client certs, creds shouldn't matter - resp, _ = resty.R().SetBasicAuth(username, password).Get(secureBaseURL + "/v2/") + // with client certs, reads are not allowed as server does not use basic auth + resp, _ = client.R().SetBasicAuth(username, password).Get(secureBaseURL + "/v2/") So(resp, ShouldNotBeNil) - So(resp.StatusCode(), ShouldEqual, http.StatusOK) + So(resp.StatusCode(), ShouldEqual, http.StatusUnauthorized) }) } @@ -2859,22 +2906,25 @@ func TestTLSMutualAndBasicAuth(t *testing.T) { cert, err := tls.LoadX509KeyPair("../../test/data/client.cert", "../../test/data/client.key") So(err, ShouldBeNil) - resty.SetCertificates(cert) + // Use separate resty client with certificates, because we cannot perform cleanup with resty.SetCertificates() + client := resty.New().SetTLSClientConfig(&tls.Config{ + RootCAs: caCertPool, + MinVersion: tls.VersionTLS12, + Certificates: []tls.Certificate{cert}, + }) - defer func() { resty.SetCertificates(tls.Certificate{}) }() - - // with client certs but without creds, should get access error - resp, err = resty.R().Get(secureBaseURL + "/v2/") + // with client certs but without creds, succeed because mTLS is used for auth when no auth headers provided + resp, err = client.R().Get(secureBaseURL + "/v2/") So(err, ShouldBeNil) So(resp, ShouldNotBeNil) - So(resp.StatusCode(), ShouldEqual, http.StatusUnauthorized) + So(resp.StatusCode(), ShouldEqual, http.StatusOK) // with client certs and creds, should get expected status code - resp, _ = resty.R().SetBasicAuth(username, password).Get(secureBaseURL) + resp, _ = client.R().SetBasicAuth(username, password).Get(secureBaseURL) So(resp, ShouldNotBeNil) So(resp.StatusCode(), ShouldEqual, http.StatusNotFound) - resp, _ = resty.R().SetBasicAuth(username, password).Get(secureBaseURL + "/v2/") + resp, _ = client.R().SetBasicAuth(username, password).Get(secureBaseURL + "/v2/") So(resp, ShouldNotBeNil) So(resp.StatusCode(), ShouldEqual, http.StatusOK) }) @@ -2952,26 +3002,29 @@ func TestTLSMutualAndBasicAuthAllowReadAccess(t *testing.T) { cert, err := tls.LoadX509KeyPair("../../test/data/client.cert", "../../test/data/client.key") So(err, ShouldBeNil) - resty.SetCertificates(cert) - - defer func() { resty.SetCertificates(tls.Certificate{}) }() + // Use separate resty client with certificates, because we cannot perform cleanup with resty.SetCertificates() + client := resty.New().SetTLSClientConfig(&tls.Config{ + RootCAs: caCertPool, + MinVersion: tls.VersionTLS12, + Certificates: []tls.Certificate{cert}, + }) // with client certs but without creds, reads should succeed - resp, err = resty.R().Get(secureBaseURL + "/v2/") + resp, err = client.R().Get(secureBaseURL + "/v2/") So(err, ShouldBeNil) So(resp.StatusCode(), ShouldEqual, http.StatusOK) - // with only client certs, writes should fail - resp, err = resty.R().Post(secureBaseURL + "/v2/repo/blobs/uploads/") + // with only client certs, writes should fail with insufficient permissions + resp, err = client.R().Post(secureBaseURL + "/v2/repo/blobs/uploads/") So(err, ShouldBeNil) - So(resp.StatusCode(), ShouldEqual, http.StatusUnauthorized) + So(resp.StatusCode(), ShouldEqual, http.StatusForbidden) // with client certs and creds, should get expected status code - resp, _ = resty.R().SetBasicAuth(username, password).Get(secureBaseURL) + resp, _ = client.R().SetBasicAuth(username, password).Get(secureBaseURL) So(resp, ShouldNotBeNil) So(resp.StatusCode(), ShouldEqual, http.StatusNotFound) - resp, _ = resty.R().SetBasicAuth(username, password).Get(secureBaseURL + "/v2/") + resp, _ = client.R().SetBasicAuth(username, password).Get(secureBaseURL + "/v2/") So(resp, ShouldNotBeNil) So(resp.StatusCode(), ShouldEqual, http.StatusOK) }) diff --git a/pkg/cli/client/client_test.go b/pkg/cli/client/client_test.go index 5f7410c5..c97ad439 100644 --- a/pkg/cli/client/client_test.go +++ b/pkg/cli/client/client_test.go @@ -102,7 +102,8 @@ func TestTLSWithAuth(t *testing.T) { So(err, ShouldNotBeNil) So(imageBuff.String(), ShouldContainSubstring, "scheme not provided") - args = []string{"list", "--config", "imagetest"} + invalidUser := fmt.Sprintf("%s:%s", "wrong_username", "wrong_password") + args = []string{"-u", invalidUser, "list", "--config", "imagetest"} _ = makeConfigFile(t, fmt.Sprintf(`{"configs":[{"_name":"imagetest","url":"%s%s%s","showspinner":false}]}`, diff --git a/pkg/cli/server/root.go b/pkg/cli/server/root.go index 84fa8809..e335c9a4 100644 --- a/pkg/cli/server/root.go +++ b/pkg/cli/server/root.go @@ -708,8 +708,9 @@ func validateAuthzPolicies(config *config.Config, logger zlog.Logger) error { logger.Info().Msg("checking if anonymous authorization is the only type of authorization policy configured") - if !authConfig.IsBasicAuthnEnabled() && !accessControlConfig.ContainsOnlyAnonymousPolicy() { - msg := "access control config requires one of httpasswd, ldap or openid authentication " + + if !authConfig.IsBasicAuthnEnabled() && !config.IsMTLSAuthEnabled() && + !accessControlConfig.ContainsOnlyAnonymousPolicy() { + msg := "access control config requires one of htpasswd, ldap, openid or mTLS authentication " + "or using only 'anonymousPolicy' policies" logger.Error().Err(zerr.ErrBadConfig).Msg(msg) diff --git a/pkg/test/tls/tls.go b/pkg/test/tls/tls.go new file mode 100644 index 00000000..a776b99d --- /dev/null +++ b/pkg/test/tls/tls.go @@ -0,0 +1,407 @@ +package tls + +import ( + "crypto/rand" + "crypto/rsa" + "crypto/x509" + "crypto/x509/pkix" + "encoding/pem" + "errors" + "fmt" + "net" + "os" + "time" +) + +var ( + ErrDecodeCAPEM = errors.New("failed to decode CA certificate PEM") + ErrInvalidCertificateType = errors.New("invalid certificate type") + ErrCertificateOptionsRequired = errors.New("CertificateOptions is required") + ErrHostnameRequired = errors.New("Hostname is required in CertificateOptions") + ErrNoCertificatesProvided = errors.New("at least one certificate is required") +) + +const ( + certTypeCA = "CA" + certTypeServer = "Server" + certTypeClient = "Client" +) + +// CertificateOptions contains optional settings for certificate generation. +// If a field is nil or zero, default values will be used. +type CertificateOptions struct { + // NotBefore is the certificate validity start time. + // If zero, defaults to time.Now(). + NotBefore time.Time + + // NotAfter is the certificate validity end time. + // If zero, defaults will be used based on certificate type. + NotAfter time.Time + + // DNSNames contains the DNS names for the Subject Alternative Name extension. + // If nil, default values may be used based on certificate type. + DNSNames []string + + // IPAddresses contains the IP addresses for the Subject Alternative Name extension. + // If nil, default values may be used based on certificate type. + IPAddresses []net.IP + + // EmailAddresses contains the email addresses for the Subject Alternative Name extension. + // If nil, no email addresses will be included. + EmailAddresses []string + + // Hostname is the hostname or IP address for server certificates. + // For server certificates, this is required and will be added to DNSNames or IPAddresses + // based on whether it's a valid IP address or a DNS name. + Hostname string + + // CommonName is the CommonName (CN) for client certificates. + // For client certificates, this is optional - if not provided, the certificate will not have a CN. + CommonName string +} + +// generateCertificate is a helper function that generates a certificate and private key. +// If signerCert and signerKey are nil, the certificate will be self-signed. +func generateCertificate( + certType string, + opts *CertificateOptions, + signerCert *x509.Certificate, + signerKey *rsa.PrivateKey, +) ([]byte, []byte, error) { + // Generate private key + privKey, err := rsa.GenerateKey(rand.Reader, 2048) + if err != nil { + return nil, nil, fmt.Errorf("failed to generate private key: %w", err) + } + + // Initialize certificate template + template, err := initializeTemplate(certType) + if err != nil { + return nil, nil, err + } + + // Apply options + applyOptions(template, opts, certType) + + // Determine signer (self-signed if signerCert is nil) + var issuerCert *x509.Certificate + + var issuerKey *rsa.PrivateKey + if signerCert == nil { + // Self-signed + issuerCert = template + issuerKey = privKey + } else { + // Signed by CA + issuerCert = signerCert + issuerKey = signerKey + } + + // Create the certificate + certDER, err := x509.CreateCertificate(rand.Reader, template, issuerCert, &privKey.PublicKey, issuerKey) + if err != nil { + return nil, nil, fmt.Errorf("failed to create certificate: %w", err) + } + + // Encode certificate to PEM + certPEM := pem.EncodeToMemory(&pem.Block{ + Type: "CERTIFICATE", + Bytes: certDER, + }) + + // Encode private key to PEM + keyPEM := pem.EncodeToMemory(&pem.Block{ + Type: "RSA PRIVATE KEY", + Bytes: x509.MarshalPKCS1PrivateKey(privKey), + }) + + return certPEM, keyPEM, nil +} + +// parseCA parses CA certificate and private key from PEM format. +func parseCA(caCertPEM, caKeyPEM []byte) (*x509.Certificate, *rsa.PrivateKey, error) { + // Parse CA certificate + caCertBlock, _ := pem.Decode(caCertPEM) + if caCertBlock == nil { + return nil, nil, ErrDecodeCAPEM + } + + caCert, err := x509.ParseCertificate(caCertBlock.Bytes) + if err != nil { + return nil, nil, fmt.Errorf("failed to parse CA certificate: %w", err) + } + + // Parse CA private key + caKeyBlock, _ := pem.Decode(caKeyPEM) + if caKeyBlock == nil { + return nil, nil, ErrDecodeCAPEM + } + + caPrivKey, err := x509.ParsePKCS1PrivateKey(caKeyBlock.Bytes) + if err != nil { + return nil, nil, fmt.Errorf("failed to parse CA private key: %w", err) + } + + return caCert, caPrivKey, nil +} + +// initializeTemplate creates and initializes a certificate template based on the certificate type. +// certType can be "CA", "Server", or "Client". +func initializeTemplate(certType string) (*x509.Certificate, error) { + template := &x509.Certificate{} + + // Initialize certificate type-specific fields and defaults + switch certType { + case certTypeCA: + template.IsCA = true + template.ExtKeyUsage = []x509.ExtKeyUsage{} + template.KeyUsage = x509.KeyUsageDigitalSignature | x509.KeyUsageCertSign + template.BasicConstraintsValid = true + template.Subject = pkix.Name{ + Organization: []string{"Test CA"}, + Country: []string{"US"}, + Province: []string{""}, + Locality: []string{"San Francisco"}, + StreetAddress: []string{""}, + PostalCode: []string{""}, + } + template.NotBefore = time.Now() + template.NotAfter = time.Now().AddDate(10, 0, 0) // 10 years for CA + case certTypeServer: + template.ExtKeyUsage = []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth} + template.KeyUsage = x509.KeyUsageDigitalSignature + template.Subject = pkix.Name{ + Organization: []string{"Test Server"}, + Country: []string{"US"}, + Province: []string{""}, + Locality: []string{"San Francisco"}, + StreetAddress: []string{""}, + PostalCode: []string{""}, + } + template.NotBefore = time.Now() + template.NotAfter = time.Now().AddDate(1, 0, 0) // 1 year for server + template.IPAddresses = []net.IP{net.ParseIP("127.0.0.1")} // Default IP for Server + case certTypeClient: + template.ExtKeyUsage = []x509.ExtKeyUsage{x509.ExtKeyUsageClientAuth} + template.KeyUsage = x509.KeyUsageDigitalSignature + template.Subject = pkix.Name{ + Organization: []string{"Test Client"}, + Country: []string{"US"}, + Province: []string{""}, + Locality: []string{"San Francisco"}, + StreetAddress: []string{""}, + PostalCode: []string{""}, + } + template.NotBefore = time.Now() + template.NotAfter = time.Now().AddDate(1, 0, 0) // 1 year for client + default: + return nil, fmt.Errorf("%w: %s", ErrInvalidCertificateType, certType) + } + + return template, nil +} + +// applyOptions applies options to the certificate template, using defaults when options are not provided. +// certType can be "CA", "Server", or "Client". +func applyOptions(template *x509.Certificate, opts *CertificateOptions, certType string) { + if opts == nil { + opts = &CertificateOptions{} + } + + // Apply NotBefore if provided in options + if !opts.NotBefore.IsZero() { + template.NotBefore = opts.NotBefore + } + + // Apply NotAfter if provided in options + if !opts.NotAfter.IsZero() { + template.NotAfter = opts.NotAfter + } + + // Apply SAN (Subject Alternative Name) - handle IPAddresses + // Priority: 1) opts.IPAddresses, 2) hostname if IP, 3) keep default from initializeTemplate + if opts.IPAddresses != nil { + template.IPAddresses = opts.IPAddresses + } else if certType == certTypeServer && opts.Hostname != "" { + if ip := net.ParseIP(opts.Hostname); ip != nil { + // Hostname is an IP address, use it + template.IPAddresses = []net.IP{ip} + } + // If hostname is DNS name, keep default IP from initializeTemplate + } + + // Apply SAN (Subject Alternative Name) - handle DNSNames + // Priority: 1) opts.DNSNames, 2) hostname if DNS name + if opts.DNSNames != nil { + template.DNSNames = opts.DNSNames + } else if certType == certTypeServer && opts.Hostname != "" { + if ip := net.ParseIP(opts.Hostname); ip == nil { + // Hostname is a DNS name, use it + template.DNSNames = []string{opts.Hostname} + } + } + + // Apply email addresses + if opts.EmailAddresses != nil { + template.EmailAddresses = opts.EmailAddresses + } + + // Apply CommonName - explicitly set to empty string if not provided to ensure it's empty + if opts.CommonName != "" { + template.Subject.CommonName = opts.CommonName + } else { + template.Subject.CommonName = "" + } +} + +// GenerateCACert generates a CA certificate and private key. +// opts is optional and can be used to customize certificate settings. +func GenerateCACert(opts ...*CertificateOptions) ([]byte, []byte, error) { + var options *CertificateOptions + if len(opts) > 0 && opts[0] != nil { + options = opts[0] + } + + // Self-signed certificate (signerCert and signerKey are nil) + return generateCertificate(certTypeCA, options, nil, nil) +} + +// GenerateIntermediateCACert generates an intermediate CA certificate signed by the provided parent CA. +// opts is optional and can be used to customize certificate settings, including CommonName. +func GenerateIntermediateCACert( + parentCACertPEM, + parentCAKeyPEM []byte, + opts ...*CertificateOptions, +) ([]byte, []byte, error) { + var options *CertificateOptions + if len(opts) > 0 && opts[0] != nil { + options = opts[0] + } else { + options = &CertificateOptions{} + } + + // Parse parent CA certificate and key + parentCACert, parentCAPrivKey, err := parseCA(parentCACertPEM, parentCAKeyPEM) + if err != nil { + return nil, nil, err + } + + // Generate intermediate CA certificate signed by parent CA + return generateCertificate(certTypeCA, options, parentCACert, parentCAPrivKey) +} + +// writeCertAndKeyToFile writes certificate and key bytes to their respective files. +func writeCertAndKeyToFile(certPath, keyPath string, certBytes, keyBytes []byte) error { + err := os.WriteFile(certPath, certBytes, 0o600) + if err != nil { + return err + } + + return os.WriteFile(keyPath, keyBytes, 0o600) +} + +// WriteCertificateChainToFile writes a certificate chain to a file. +// The certificates should be provided in order: leaf certificate first, followed by intermediate CAs. +// All certificates should be in PEM format. +func WriteCertificateChainToFile(certChainPath string, certs ...[]byte) error { + if len(certs) == 0 { + return ErrNoCertificatesProvided + } + + // Calculate total size for pre-allocation + totalSize := 0 + for _, cert := range certs { + totalSize += len(cert) + } + + // Concatenate all certificates + chainPEM := make([]byte, 0, totalSize) + for _, cert := range certs { + chainPEM = append(chainPEM, cert...) + } + + // Write to file + err := os.WriteFile(certChainPath, chainPEM, 0o600) + if err != nil { + return err + } + + return nil +} + +// GenerateServerCert generates a server certificate signed by the provided CA. +// opts is required and must contain a Hostname field. +func GenerateServerCert(caCertPEM, caKeyPEM []byte, opts *CertificateOptions) ([]byte, []byte, error) { + if opts == nil || opts.Hostname == "" { + return nil, nil, ErrHostnameRequired + } + + // Parse CA certificate and key + caCert, caPrivKey, err := parseCA(caCertPEM, caKeyPEM) + if err != nil { + return nil, nil, err + } + + // Generate certificate signed by CA + return generateCertificate(certTypeServer, opts, caCert, caPrivKey) +} + +// GenerateServerCertToFile generates a server certificate signed by the provided CA +// and writes generated key and cert to files. +// opts is required and must contain a Hostname field. +func GenerateServerCertToFile( + caCertPEM, caKeyPEM []byte, + certOutputPath, keyOutputPath string, + opts *CertificateOptions, +) error { + serverCertBytes, serverKeyBytes, err := GenerateServerCert(caCertPEM, caKeyPEM, opts) + if err != nil { + return err + } + + return writeCertAndKeyToFile(certOutputPath, keyOutputPath, serverCertBytes, serverKeyBytes) +} + +// GenerateClientCert generates a client certificate signed by the provided CA. +// opts is optional. CommonName is optional - if not provided, the certificate will not have a CN. +func GenerateClientCert(caCertPEM, caKeyPEM []byte, opts *CertificateOptions) ([]byte, []byte, error) { + // Parse CA certificate and key + caCert, caPrivKey, err := parseCA(caCertPEM, caKeyPEM) + if err != nil { + return nil, nil, err + } + + // Generate certificate signed by CA + return generateCertificate(certTypeClient, opts, caCert, caPrivKey) +} + +// GenerateClientCertToFile generates a client certificate signed by the provided CA +// and writes generated key and cert to files. +// opts is optional. CommonName is optional - if not provided, the certificate will not have a CN. +func GenerateClientCertToFile(caCertPEM, caKeyPEM []byte, certPath, keyPath string, opts *CertificateOptions) error { + clientCertBytes, clientKeyBytes, err := GenerateClientCert(caCertPEM, caKeyPEM, opts) + if err != nil { + return err + } + + return writeCertAndKeyToFile(certPath, keyPath, clientCertBytes, clientKeyBytes) +} + +// GenerateClientSelfSignedCert generates a client certificate not signed by any CA. +// opts is optional. CommonName is optional - if not provided, the certificate will not have a CN. +func GenerateClientSelfSignedCert(opts *CertificateOptions) ([]byte, []byte, error) { + // Self-signed certificate (signerCert and signerKey are nil) + return generateCertificate(certTypeClient, opts, nil, nil) +} + +// GenerateClientSelfSignedCertToFile generates a client certificate not signed by any CA +// and writes generated key and cert to files. +// opts is optional. CommonName is optional - if not provided, the certificate will not have a CN. +func GenerateClientSelfSignedCertToFile(certOutputPath, keyOutputPath string, opts *CertificateOptions) error { + clientCertBytes, clientKeyBytes, err := GenerateClientSelfSignedCert(opts) + if err != nil { + return err + } + + return writeCertAndKeyToFile(certOutputPath, keyOutputPath, clientCertBytes, clientKeyBytes) +} diff --git a/pkg/test/tls/tls_test.go b/pkg/test/tls/tls_test.go new file mode 100644 index 00000000..e5bceba7 --- /dev/null +++ b/pkg/test/tls/tls_test.go @@ -0,0 +1,438 @@ +package tls_test + +import ( + "crypto/x509" + "encoding/pem" + "net" + "path" + "testing" + "time" + + . "github.com/smartystreets/goconvey/convey" + + "zotregistry.dev/zot/v2/pkg/test/tls" +) + +func TestGenerateCACert(t *testing.T) { + Convey("Generate CA certificate", t, func() { + certPEM, keyPEM, err := tls.GenerateCACert() + So(err, ShouldBeNil) + + Convey("Certificate should be valid PEM", func() { + certBlock, _ := pem.Decode(certPEM) + So(certBlock, ShouldNotBeNil) + So(certBlock.Type, ShouldEqual, "CERTIFICATE") + + cert, err := x509.ParseCertificate(certBlock.Bytes) + So(err, ShouldBeNil) + So(cert.IsCA, ShouldBeTrue) + So(cert.Subject.Organization[0], ShouldEqual, "Test CA") + }) + + Convey("Private key should be valid PEM", func() { + keyBlock, _ := pem.Decode(keyPEM) + So(keyBlock, ShouldNotBeNil) + So(keyBlock.Type, ShouldEqual, "RSA PRIVATE KEY") + + _, err := x509.ParsePKCS1PrivateKey(keyBlock.Bytes) + So(err, ShouldBeNil) + }) + }) +} + +func TestGenerateServerCert(t *testing.T) { + Convey("Generate server certificate", t, func() { + caCertPEM, caKeyPEM, err := tls.GenerateCACert() + So(err, ShouldBeNil) + + Convey("With hostname", func() { + hostname := "localhost" + opts := &tls.CertificateOptions{ + Hostname: hostname, + } + certPEM, keyPEM, err := tls.GenerateServerCert(caCertPEM, caKeyPEM, opts) + So(err, ShouldBeNil) + + certBlock, _ := pem.Decode(certPEM) + So(certBlock, ShouldNotBeNil) + + cert, err := x509.ParseCertificate(certBlock.Bytes) + So(err, ShouldBeNil) + So(cert.DNSNames, ShouldContain, hostname) + So(cert.ExtKeyUsage, ShouldContain, x509.ExtKeyUsageServerAuth) + + keyBlock, _ := pem.Decode(keyPEM) + So(keyBlock, ShouldNotBeNil) + }) + + Convey("With IP address", func() { + ipaddr := "127.0.0.1" + opts := &tls.CertificateOptions{ + Hostname: ipaddr, + } + certPEM, _, err := tls.GenerateServerCert(caCertPEM, caKeyPEM, opts) + So(err, ShouldBeNil) + + certBlock, _ := pem.Decode(certPEM) + cert, err := x509.ParseCertificate(certBlock.Bytes) + So(err, ShouldBeNil) + So(len(cert.IPAddresses), ShouldBeGreaterThan, 0) + So(cert.IPAddresses[0].String(), ShouldEqual, ipaddr) + }) + + Convey("With invalid CA PEM", func() { + invalidPEM := []byte("invalid pem") + opts := &tls.CertificateOptions{ + Hostname: "localhost", + } + _, _, err := tls.GenerateServerCert(invalidPEM, invalidPEM, opts) + So(err, ShouldEqual, tls.ErrDecodeCAPEM) + }) + }) +} + +func TestGenerateCertWithCN(t *testing.T) { + Convey("Generate client certificate with CN", t, func() { + caCertPEM, caKeyPEM, err := tls.GenerateCACert() + So(err, ShouldBeNil) + + commonName := "test-client" + opts := &tls.CertificateOptions{ + CommonName: commonName, + } + certPEM, keyPEM, err := tls.GenerateClientCert(caCertPEM, caKeyPEM, opts) + So(err, ShouldBeNil) + + Convey("Certificate should have correct properties", func() { + certBlock, _ := pem.Decode(certPEM) + So(certBlock, ShouldNotBeNil) + + cert, err := x509.ParseCertificate(certBlock.Bytes) + So(err, ShouldBeNil) + So(cert.Subject.CommonName, ShouldEqual, commonName) + So(cert.ExtKeyUsage, ShouldContain, x509.ExtKeyUsageClientAuth) + }) + + Convey("Private key should be valid", func() { + keyBlock, _ := pem.Decode(keyPEM) + So(keyBlock, ShouldNotBeNil) + }) + }) +} + +func TestGenerateSelfSignedCertWithCN(t *testing.T) { + Convey("Generate self-signed certificate with CN", t, func() { + commonName := "self-signed-client" + opts := &tls.CertificateOptions{ + CommonName: commonName, + } + certPEM, keyPEM, err := tls.GenerateClientSelfSignedCert(opts) + So(err, ShouldBeNil) + + Convey("Certificate should be self-signed", func() { + certBlock, _ := pem.Decode(certPEM) + So(certBlock, ShouldNotBeNil) + + cert, err := x509.ParseCertificate(certBlock.Bytes) + So(err, ShouldBeNil) + So(cert.Subject.CommonName, ShouldEqual, commonName) + So(cert.Subject.String(), ShouldEqual, cert.Issuer.String()) + }) + + Convey("Certificate should have correct validity period", func() { + certBlock, _ := pem.Decode(certPEM) + cert, err := x509.ParseCertificate(certBlock.Bytes) + So(err, ShouldBeNil) + So(cert.NotAfter.After(time.Now().AddDate(0, 11, 0)), ShouldBeTrue) + }) + + Convey("Private key should be valid", func() { + keyBlock, _ := pem.Decode(keyPEM) + So(keyBlock, ShouldNotBeNil) + + _, err := x509.ParsePKCS1PrivateKey(keyBlock.Bytes) + So(err, ShouldBeNil) + }) + }) +} + +func TestApplyOptionsCoverage(t *testing.T) { + Convey("Test applyOptions with various options", t, func() { + caCertPEM, caKeyPEM, err := tls.GenerateCACert() + So(err, ShouldBeNil) + + Convey("Test with custom NotBefore and NotAfter", func() { + customNotBefore := time.Now().Add(-24 * time.Hour) + customNotAfter := time.Now().Add(2 * 365 * 24 * time.Hour) + + opts := &tls.CertificateOptions{ + Hostname: "localhost", + NotBefore: customNotBefore, + NotAfter: customNotAfter, + } + certPEM, _, err := tls.GenerateServerCert(caCertPEM, caKeyPEM, opts) + So(err, ShouldBeNil) + + certBlock, _ := pem.Decode(certPEM) + cert, err := x509.ParseCertificate(certBlock.Bytes) + So(err, ShouldBeNil) + So(cert.NotBefore.Unix(), ShouldEqual, customNotBefore.Unix()) + So(cert.NotAfter.Unix(), ShouldEqual, customNotAfter.Unix()) + // Verify Hostname is encoded in DNSNames (since "localhost" is a DNS name) + So(cert.DNSNames, ShouldContain, "localhost") + }) + + Convey("Test with explicit IPAddresses", func() { + customIPs := []net.IP{net.ParseIP("192.168.1.1"), net.ParseIP("10.0.0.1")} + opts := &tls.CertificateOptions{ + Hostname: "localhost", + IPAddresses: customIPs, + } + certPEM, _, err := tls.GenerateServerCert(caCertPEM, caKeyPEM, opts) + So(err, ShouldBeNil) + + certBlock, _ := pem.Decode(certPEM) + cert, err := x509.ParseCertificate(certBlock.Bytes) + So(err, ShouldBeNil) + So(len(cert.IPAddresses), ShouldEqual, 2) + So(cert.IPAddresses[0].String(), ShouldEqual, "192.168.1.1") + So(cert.IPAddresses[1].String(), ShouldEqual, "10.0.0.1") + // Verify explicit IPAddresses are used (not the Hostname IP) + So(cert.IPAddresses, ShouldNotContain, net.ParseIP("127.0.0.1")) + // Verify Hostname DNS name is still added to DNSNames when no explicit DNSNames provided + So(cert.DNSNames, ShouldContain, "localhost") + }) + + Convey("Test with explicit DNSNames", func() { + customDNS := []string{"example.com", "test.example.com"} + opts := &tls.CertificateOptions{ + Hostname: "localhost", + DNSNames: customDNS, + } + certPEM, _, err := tls.GenerateServerCert(caCertPEM, caKeyPEM, opts) + So(err, ShouldBeNil) + + certBlock, _ := pem.Decode(certPEM) + cert, err := x509.ParseCertificate(certBlock.Bytes) + So(err, ShouldBeNil) + So(len(cert.DNSNames), ShouldEqual, 2) + So(cert.DNSNames, ShouldContain, "example.com") + So(cert.DNSNames, ShouldContain, "test.example.com") + // Verify explicit DNSNames take precedence - Hostname should NOT be added + So(cert.DNSNames, ShouldNotContain, "localhost") + }) + + Convey("Test with EmailAddresses", func() { + customEmails := []string{"user@example.com", "admin@example.com"} + opts := &tls.CertificateOptions{ + Hostname: "localhost", + EmailAddresses: customEmails, + } + certPEM, _, err := tls.GenerateServerCert(caCertPEM, caKeyPEM, opts) + So(err, ShouldBeNil) + + certBlock, _ := pem.Decode(certPEM) + cert, err := x509.ParseCertificate(certBlock.Bytes) + So(err, ShouldBeNil) + So(len(cert.EmailAddresses), ShouldEqual, 2) + So(cert.EmailAddresses, ShouldContain, "user@example.com") + So(cert.EmailAddresses, ShouldContain, "admin@example.com") + }) + + Convey("Test with all options combined", func() { + customNotBefore := time.Now().Add(-12 * time.Hour) + customNotAfter := time.Now().Add(365 * 24 * time.Hour) + customIPs := []net.IP{net.ParseIP("192.168.1.100")} + customDNS := []string{"combined.example.com"} + customEmails := []string{"combined@example.com"} + + opts := &tls.CertificateOptions{ + Hostname: "localhost", + NotBefore: customNotBefore, + NotAfter: customNotAfter, + IPAddresses: customIPs, + DNSNames: customDNS, + EmailAddresses: customEmails, + } + certPEM, _, err := tls.GenerateServerCert(caCertPEM, caKeyPEM, opts) + So(err, ShouldBeNil) + + certBlock, _ := pem.Decode(certPEM) + cert, err := x509.ParseCertificate(certBlock.Bytes) + So(err, ShouldBeNil) + So(cert.NotBefore.Unix(), ShouldEqual, customNotBefore.Unix()) + So(cert.NotAfter.Unix(), ShouldEqual, customNotAfter.Unix()) + So(len(cert.IPAddresses), ShouldEqual, 1) + So(cert.IPAddresses[0].String(), ShouldEqual, "192.168.1.100") + So(len(cert.DNSNames), ShouldEqual, 1) + So(cert.DNSNames[0], ShouldEqual, "combined.example.com") + So(len(cert.EmailAddresses), ShouldEqual, 1) + So(cert.EmailAddresses[0], ShouldEqual, "combined@example.com") + // Verify explicit DNSNames take precedence - Hostname should NOT be added + So(cert.DNSNames, ShouldNotContain, "localhost") + }) + + Convey("Test Hostname as IP address is encoded in IPAddresses", func() { + ipHostname := "192.168.2.50" + opts := &tls.CertificateOptions{ + Hostname: ipHostname, + } + certPEM, _, err := tls.GenerateServerCert(caCertPEM, caKeyPEM, opts) + So(err, ShouldBeNil) + + certBlock, _ := pem.Decode(certPEM) + cert, err := x509.ParseCertificate(certBlock.Bytes) + So(err, ShouldBeNil) + // Verify Hostname IP is in IPAddresses + So(len(cert.IPAddresses), ShouldBeGreaterThan, 0) + So(cert.IPAddresses[0].String(), ShouldEqual, ipHostname) + // Verify it's NOT in DNSNames + So(cert.DNSNames, ShouldNotContain, ipHostname) + }) + + Convey("Test Hostname as DNS name is encoded in DNSNames", func() { + dnsHostname := "example.test" + opts := &tls.CertificateOptions{ + Hostname: dnsHostname, + } + certPEM, _, err := tls.GenerateServerCert(caCertPEM, caKeyPEM, opts) + So(err, ShouldBeNil) + + certBlock, _ := pem.Decode(certPEM) + cert, err := x509.ParseCertificate(certBlock.Bytes) + So(err, ShouldBeNil) + // Verify Hostname DNS is in DNSNames + So(cert.DNSNames, ShouldContain, dnsHostname) + }) + + Convey("Test with nil options (CA certificate)", func() { + // This tests the nil check in applyOptions + certPEM, _, err := tls.GenerateCACert() + So(err, ShouldBeNil) + + certBlock, _ := pem.Decode(certPEM) + cert, err := x509.ParseCertificate(certBlock.Bytes) + So(err, ShouldBeNil) + So(cert.IsCA, ShouldBeTrue) + }) + }) +} + +func TestErrorPaths(t *testing.T) { + Convey("Test error paths", t, func() { + caCertPEM, caKeyPEM, err := tls.GenerateCACert() + So(err, ShouldBeNil) + + Convey("Test parseCA with invalid cert PEM", func() { + invalidCertPEM := []byte("not a valid PEM") + _, _, err := tls.GenerateServerCert(invalidCertPEM, caKeyPEM, &tls.CertificateOptions{ + Hostname: "localhost", + }) + So(err, ShouldEqual, tls.ErrDecodeCAPEM) + }) + + Convey("Test parseCA with invalid key PEM", func() { + invalidKeyPEM := []byte("not a valid PEM") + _, _, err := tls.GenerateServerCert(caCertPEM, invalidKeyPEM, &tls.CertificateOptions{ + Hostname: "localhost", + }) + So(err, ShouldEqual, tls.ErrDecodeCAPEM) + }) + + Convey("Test GenerateServerCertToFile with nil opts", func() { + tempDir := t.TempDir() + certPath := path.Join(tempDir, "server.crt") + keyPath := path.Join(tempDir, "server.key") + + err := tls.GenerateServerCertToFile(caCertPEM, caKeyPEM, certPath, keyPath, nil) + So(err, ShouldEqual, tls.ErrHostnameRequired) + }) + + Convey("Test GenerateCACert with nil option", func() { + // Test when opts[0] == nil - should still work (uses default options) + certPEM, keyPEM, err := tls.GenerateCACert(nil) + So(err, ShouldBeNil) + So(certPEM, ShouldNotBeNil) + So(keyPEM, ShouldNotBeNil) + }) + + Convey("Test writeCertAndKeyToFile error when cert file write fails", func() { + tempDir := t.TempDir() + // Create a directory path instead of a file path to cause write error + certPath := tempDir // This is a directory, not a file + keyPath := path.Join(tempDir, "server.key") + + opts := &tls.CertificateOptions{ + Hostname: "localhost", + } + err := tls.GenerateServerCertToFile(caCertPEM, caKeyPEM, certPath, keyPath, opts) + So(err, ShouldNotBeNil) + }) + + Convey("Test writeCertAndKeyToFile error when key file write fails", func() { + tempDir := t.TempDir() + certPath := path.Join(tempDir, "server.crt") + // Create a directory path instead of a file path to cause write error + keyPath := tempDir // This is a directory, not a file + + opts := &tls.CertificateOptions{ + Hostname: "localhost", + } + err := tls.GenerateServerCertToFile(caCertPEM, caKeyPEM, certPath, keyPath, opts) + So(err, ShouldNotBeNil) + }) + + Convey("Test GenerateServerCertToFile error propagation", func() { + // Test that error from GenerateServerCert is propagated + tempDir := t.TempDir() + certPath := path.Join(tempDir, "server.crt") + keyPath := path.Join(tempDir, "server.key") + + // Use invalid CA to trigger error in GenerateServerCert + invalidPEM := []byte("invalid") + err := tls.GenerateServerCertToFile(invalidPEM, invalidPEM, certPath, keyPath, &tls.CertificateOptions{ + Hostname: "localhost", + }) + So(err, ShouldNotBeNil) + So(err, ShouldEqual, tls.ErrDecodeCAPEM) + }) + + Convey("Test GenerateClientCert with invalid PEM", func() { + // Test that parseCA error is propagated from GenerateClientCert + invalidCertPEM := []byte("not a valid PEM") + _, _, err := tls.GenerateClientCert(invalidCertPEM, caKeyPEM, nil) + So(err, ShouldEqual, tls.ErrDecodeCAPEM) + }) + + Convey("Test GenerateClientCertToFile error propagation", func() { + // Test that error from GenerateClientCert is propagated + tempDir := t.TempDir() + certPath := path.Join(tempDir, "client.crt") + keyPath := path.Join(tempDir, "client.key") + + // Use invalid CA to trigger error in GenerateClientCert + invalidPEM := []byte("invalid") + err := tls.GenerateClientCertToFile(invalidPEM, invalidPEM, certPath, keyPath, nil) + So(err, ShouldNotBeNil) + So(err, ShouldEqual, tls.ErrDecodeCAPEM) + }) + + Convey("Test GenerateIntermediateCACert with invalid PEM", func() { + // Test that parseCA error is propagated from GenerateIntermediateCACert + invalidCertPEM := []byte("not a valid PEM") + _, _, err := tls.GenerateIntermediateCACert(invalidCertPEM, caKeyPEM) + So(err, ShouldEqual, tls.ErrDecodeCAPEM) + }) + + Convey("Test GenerateClientSelfSignedCertToFile error propagation", func() { + // Test writeCertAndKeyToFile error path + tempDir := t.TempDir() + // Create a directory path instead of a file path to cause write error + certPath := tempDir // This is a directory, not a file + keyPath := path.Join(tempDir, "client.key") + + err := tls.GenerateClientSelfSignedCertToFile(certPath, keyPath, nil) + So(err, ShouldNotBeNil) + }) + }) +}