From cf8b0bdbf98a3c65ad675f15e477cfe81a00a166 Mon Sep 17 00:00:00 2001 From: Andrei Aaron Date: Sat, 13 Dec 2025 09:47:32 +0200 Subject: [PATCH] refactor: enhance TLS cert generation and refactor HTTP client architecture (#3638) - Refactored HTTP client from global cache to struct-based approach (global state was shared between tests, including what certificates to use) - Enhanced pkg/test/tls to support ECDSA and ED25519 key types - Replaced static certificate files with dynamic generation in golang tests - Fixed test cleanup issues and improved resource management This eliminates dependency on external cert generation scripts and improves test maintainability. Signed-off-by: Andrei Aaron --- Makefile | 8 +- pkg/api/authn_test.go | 25 +- pkg/api/controller_test.go | 440 +++++++++++++----- pkg/cli/client/client.go | 63 ++- pkg/cli/client/client_test.go | 170 ++++++- pkg/cli/client/cve_cmd_internal_test.go | 85 ++-- pkg/cli/client/discover.go | 6 +- pkg/cli/client/elevated_test.go | 116 +++-- pkg/cli/client/gql_queries_test.go | 11 +- pkg/cli/client/image_cmd_internal_test.go | 102 ++-- pkg/cli/client/search_cmd_internal_test.go | 12 +- .../client/search_functions_internal_test.go | 105 +++-- pkg/cli/client/server_info_cmd.go | 6 +- pkg/cli/client/server_info_cmd_test.go | 12 +- pkg/cli/client/service.go | 67 +-- pkg/cli/client/utils.go | 3 +- pkg/cli/client/utils_internal_test.go | 28 +- pkg/common/http_client_test.go | 69 ++- pkg/extensions/extensions_test.go | 47 +- pkg/extensions/sync/sync_test.go | 276 +++++++---- pkg/test/tls/tls.go | 227 +++++++-- pkg/test/tls/tls_test.go | 266 +++++++++++ 22 files changed, 1590 insertions(+), 554 deletions(-) diff --git a/Makefile b/Makefile index d6ea38d7..17f4fe86 100644 --- a/Makefile +++ b/Makefile @@ -211,18 +211,18 @@ test-prereq: check-skopeo $(TESTDATA) $(ORAS) .PHONY: test-extended test-extended: $(if $(findstring ui,$(BUILD_LABELS)), ui) -test-extended: test-prereq +test-extended: testdata-images env GOEXPERIMENT=jsonv2 go test -failfast $(GO_CMD_TAGS) -trimpath -race -timeout 20m -cover -coverpkg ./... -coverprofile=coverage-extended.txt -covermode=atomic ./... rm -rf /tmp/getter*; rm -rf /tmp/trivy* .PHONY: test-minimal -test-minimal: test-prereq +test-minimal: testdata-images env GOEXPERIMENT=jsonv2 go test -failfast -trimpath -race -cover -coverpkg ./... -coverprofile=coverage-minimal.txt -covermode=atomic ./... rm -rf /tmp/getter*; rm -rf /tmp/trivy* .PHONY: test-devmode test-devmode: $(if $(findstring ui,$(BUILD_LABELS)), ui) -test-devmode: testdata-certs +test-devmode: env GOEXPERIMENT=jsonv2 go test -failfast -tags dev,$(BUILD_LABELS) -trimpath -race -timeout 15m -cover -coverpkg ./... -coverprofile=coverage-dev-extended.txt -covermode=atomic ./pkg/test/... ./pkg/api/... ./pkg/storage/... ./pkg/extensions/sync/... -run ^TestInject rm -rf /tmp/getter*; rm -rf /tmp/trivy* env GOEXPERIMENT=jsonv2 go test -failfast -tags dev -trimpath -race -cover -coverpkg ./... -coverprofile=coverage-dev-minimal.txt -covermode=atomic ./pkg/test/... ./pkg/storage/... ./pkg/extensions/sync/... -run ^TestInject @@ -235,7 +235,7 @@ test: test-extended test-minimal test-devmode .PHONY: privileged-test privileged-test: $(if $(findstring ui,$(BUILD_LABELS)), ui) -privileged-test: testdata-certs +privileged-test: env GOEXPERIMENT=jsonv2 go test -failfast -tags needprivileges,$(BUILD_LABELS) -trimpath -race -timeout 15m -cover -coverpkg ./... -coverprofile=coverage-needprivileges.txt -covermode=atomic ./pkg/storage/local/... ./pkg/cli/client/... -run ^TestElevatedPrivileges .PHONY: testdata-certs diff --git a/pkg/api/authn_test.go b/pkg/api/authn_test.go index 20e8c093..c58ac723 100644 --- a/pkg/api/authn_test.go +++ b/pkg/api/authn_test.go @@ -2102,11 +2102,22 @@ func TestCookiestoreCleanup(t *testing.T) { func TestCookieSecureFlag(t *testing.T) { Convey("Test cookie Secure flag based on configuration", t, func() { - const ( - serverCertPath = "../../test/data/server.cert" - serverKeyPath = "../../test/data/server.key" - caCertPath = "../../test/data/ca.crt" - ) + // Generate certificates dynamically for the test + tempDir := t.TempDir() + caCert, caKey, err := tlsutils.GenerateCACert() + So(err, ShouldBeNil) + + caCertPath := path.Join(tempDir, "ca.crt") + err = os.WriteFile(caCertPath, caCert, 0o600) + So(err, ShouldBeNil) + + serverCertPath := path.Join(tempDir, "server.crt") + serverKeyPath := path.Join(tempDir, "server.key") + opts := &tlsutils.CertificateOptions{ + Hostname: "127.0.0.1", + } + err = tlsutils.GenerateServerCertToFile(caCert, caKey, serverCertPath, serverKeyPath, opts) + So(err, ShouldBeNil) mockOIDCServer, err := authutils.MockOIDCRun() So(err, ShouldBeNil) @@ -2116,11 +2127,12 @@ func TestCookieSecureFlag(t *testing.T) { So(err, ShouldBeNil) }() + mockOIDCConfig := mockOIDCServer.Config() + username, _ := test.GenerateRandomString() password, _ := test.GenerateRandomString() htpasswdPath := test.MakeHtpasswdFileFromString(t, test.GetBcryptCredString(username, password)) - mockOIDCConfig := mockOIDCServer.Config() defaultVal := true Convey("Test with TLS configured - cookies should be Secure=true", func() { @@ -2155,7 +2167,6 @@ func TestCookieSecureFlag(t *testing.T) { ctlr.Config.Storage.RootDirectory = t.TempDir() cm := test.NewControllerManager(ctlr) - cm.StartServer() defer cm.StopServer() diff --git a/pkg/api/controller_test.go b/pkg/api/controller_test.go index 7d233caf..b0475c82 100644 --- a/pkg/api/controller_test.go +++ b/pkg/api/controller_test.go @@ -71,20 +71,10 @@ import ( ociutils "zotregistry.dev/zot/v2/pkg/test/oci-utils" "zotregistry.dev/zot/v2/pkg/test/signature" tskip "zotregistry.dev/zot/v2/pkg/test/skip" + tlsutils "zotregistry.dev/zot/v2/pkg/test/tls" ) const ( - ServerCert = "../../test/data/server.cert" - ServerKey = "../../test/data/server.key" - ServerPublicKey = "../../test/data/server-public.key" - ServerPublicKeyPKCS1 = "../../test/data/server-public-pkcs1.key" - CACert = "../../test/data/ca.crt" - ServerCertECDSA = "../../test/data/server-ecdsa.cert" - ServerKeyECDSA = "../../test/data/server-ecdsa.key" - ServerPublicKeyECDSA = "../../test/data/server-public-ecdsa.key" - ServerCertED25519 = "../../test/data/server-ed25519.cert" - ServerKeyED25519 = "../../test/data/server-ed25519.key" - ServerPublicKeyED25519 = "../../test/data/server-public-ed25519.key" UnauthorizedNamespace = "fortknox/notallowed" AuthorizationNamespace = "authz/image" LDAPAddress = "127.0.0.1" @@ -100,6 +90,149 @@ var ( LDAPUserAttr = "uid" //nolint: gochecknoglobals ) +// setupTestCerts generates CA, server, and client certificates for testing. +// Returns paths to certificate files and PEM data for CA cert. +func setupTestCerts(t *testing.T) ( + string, string, string, string, string, []byte, +) { + t.Helper() + tempDir := t.TempDir() + + // Generate CA certificate (10 years validity, matching gen_certs.sh) + caOpts := &tlsutils.CertificateOptions{ + CommonName: "*", + NotAfter: time.Now().AddDate(10, 0, 0), + } + caCertPEM, caKeyPEM, err := tlsutils.GenerateCACert(caOpts) + if err != nil { + t.Fatalf("Failed to generate CA cert: %v", err) + } + + caCertPath := path.Join(tempDir, "ca.crt") + caKeyPath := path.Join(tempDir, "ca.key") + err = os.WriteFile(caCertPath, caCertPEM, 0o600) + if err != nil { + t.Fatalf("Failed to write CA cert: %v", err) + } + _ = os.WriteFile(caKeyPath, caKeyPEM, 0o600) + + // Generate server certificate + serverCertPath := path.Join(tempDir, "server.cert") + serverKeyPath := path.Join(tempDir, "server.key") + serverOpts := &tlsutils.CertificateOptions{ + Hostname: "127.0.0.1", + CommonName: "*", + OrganizationalUnit: "TestServer", + NotAfter: time.Now().AddDate(10, 0, 0), + } + err = tlsutils.GenerateServerCertToFile(caCertPEM, caKeyPEM, serverCertPath, serverKeyPath, serverOpts) + if err != nil { + t.Fatalf("Failed to generate server cert: %v", err) + } + + // Generate client certificate (10 years validity, matching gen_certs.sh) + clientCertPath := path.Join(tempDir, "client.cert") + clientKeyPath := path.Join(tempDir, "client.key") + clientOpts := &tlsutils.CertificateOptions{ + CommonName: "testclient", + OrganizationalUnit: "TestClient", + NotAfter: time.Now().AddDate(10, 0, 0), + } + err = tlsutils.GenerateClientCertToFile(caCertPEM, caKeyPEM, clientCertPath, clientKeyPath, clientOpts) + if err != nil { + t.Fatalf("Failed to generate client cert: %v", err) + } + + return caCertPath, serverCertPath, serverKeyPath, clientCertPath, clientKeyPath, caCertPEM +} + +// setupBearerAuthServerCerts generates CA and server certificates for bearer auth server testing +// with a specific key type. Returns paths to server certificate, key, and public key files. +func setupBearerAuthServerCerts(t *testing.T, keyType tlsutils.KeyType) ( + string, string, string, string, +) { + t.Helper() + tempDir := t.TempDir() + + // Generate CA certificate with specified key type + caOpts := &tlsutils.CertificateOptions{ + CommonName: "*", + NotAfter: time.Now().AddDate(10, 0, 0), + KeyType: keyType, + } + caCertPEM, caKeyPEM, err := tlsutils.GenerateCACert(caOpts) + if err != nil { + t.Fatalf("Failed to generate CA cert: %v", err) + } + + // Generate server certificate with specified key type + serverCertPath := path.Join(tempDir, "server.cert") + serverKeyPath := path.Join(tempDir, "server.key") + serverOpts := &tlsutils.CertificateOptions{ + Hostname: "127.0.0.1", + CommonName: "*", + OrganizationalUnit: "TestServer", + NotAfter: time.Now().AddDate(10, 0, 0), + KeyType: keyType, + } + err = tlsutils.GenerateServerCertToFile(caCertPEM, caKeyPEM, serverCertPath, serverKeyPath, serverOpts) + if err != nil { + t.Fatalf("Failed to generate server cert: %v", err) + } + + // Extract public keys from server certificate/key + serverCertBytes, err := os.ReadFile(serverCertPath) + if err != nil { + t.Fatalf("Failed to read server cert: %v", err) + } + + serverPublicKeyPath := path.Join(tempDir, "server-public.key") + + var serverPublicKeyPKCS1Path string + + if keyType == tlsutils.KeyTypeRSA { + // For RSA, also generate PKCS1 format public key + serverKeyBytes, err := os.ReadFile(serverKeyPath) + if err != nil { + t.Fatalf("Failed to read server key: %v", err) + } + + // Extract PKIX format public key (from cert) + publicKeyPKIX, err := tlsutils.ExtractPublicKeyFromCert(serverCertBytes) + if err != nil { + t.Fatalf("Failed to extract public key from cert: %v", err) + } + err = os.WriteFile(serverPublicKeyPath, publicKeyPKIX, 0o600) + if err != nil { + t.Fatalf("Failed to write server public key: %v", err) + } + + // Extract PKCS1 format public key (from private key) + serverPublicKeyPKCS1Path = path.Join(tempDir, "server-public-pkcs1.key") + publicKeyPKCS1, err := tlsutils.ExtractRSAPublicKeyPKCS1(serverKeyBytes) + if err != nil { + t.Fatalf("Failed to extract PKCS1 public key: %v", err) + } + err = os.WriteFile(serverPublicKeyPKCS1Path, publicKeyPKCS1, 0o600) + if err != nil { + t.Fatalf("Failed to write server PKCS1 public key: %v", err) + } + } else { + // For ECDSA and ED25519, extract PKIX format public key + publicKeyPKIX, err := tlsutils.ExtractPublicKeyFromCert(serverCertBytes) + if err != nil { + t.Fatalf("Failed to extract public key from cert: %v", err) + } + err = os.WriteFile(serverPublicKeyPath, publicKeyPKIX, 0o600) + if err != nil { + t.Fatalf("Failed to write server public key: %v", err) + } + serverPublicKeyPKCS1Path = "" // Not applicable for non-RSA keys + } + + return serverCertPath, serverKeyPath, serverPublicKeyPath, serverPublicKeyPKCS1Path +} + func TestNew(t *testing.T) { Convey("Make a new controller", t, func() { conf := config.New() @@ -1347,11 +1480,10 @@ func TestScaleOutRequestProxy(t *testing.T) { clusterMembers[idx] = "127.0.0.1:" + port } - caCert, err := os.ReadFile(CACert) - So(err, ShouldBeNil) + caCertPath, serverCertPath, serverKeyPath, _, _, caCertPEM := setupTestCerts(t) caCertPool := x509.NewCertPool() - caCertPool.AppendCertsFromPEM(caCert) + caCertPool.AppendCertsFromPEM(caCertPEM) username, _ := test.GenerateRandomString() password, _ := test.GenerateRandomString() @@ -1365,8 +1497,8 @@ func TestScaleOutRequestProxy(t *testing.T) { conf := config.New() conf.HTTP.Port = port conf.HTTP.TLS = &config.TLSConfig{ - Cert: ServerCert, - Key: ServerKey, + Cert: serverCertPath, + Key: serverKeyPath, } conf.HTTP.Auth = &config.AuthConfig{ HTPasswd: config.AuthHTPasswd{ @@ -1377,7 +1509,7 @@ func TestScaleOutRequestProxy(t *testing.T) { Members: clusterMembers, HashKey: "loremipsumdolors", TLS: &config.TLSConfig{ - CACert: CACert, + CACert: caCertPath, }, } @@ -1445,12 +1577,14 @@ func TestScaleOutRequestProxy(t *testing.T) { clusterMembers[idx] = "127.0.0.1:" + port } + _, serverCertPath, serverKeyPath, _, _, caCertPEM := setupTestCerts(t) + for _, port := range ports { conf := config.New() conf.HTTP.Port = port conf.HTTP.TLS = &config.TLSConfig{ - Cert: ServerCert, - Key: ServerKey, + Cert: serverCertPath, + Key: serverKeyPath, } conf.Cluster = &config.ClusterConfig{ Members: clusterMembers, @@ -1469,11 +1603,8 @@ func TestScaleOutRequestProxy(t *testing.T) { }(cm) } - caCert, err := os.ReadFile(CACert) - So(err, ShouldBeNil) - caCertPool := x509.NewCertPool() - caCertPool.AppendCertsFromPEM(caCert) + caCertPool.AppendCertsFromPEM(caCertPEM) resty.SetTLSClientConfig(&tls.Config{RootCAs: caCertPool, MinVersion: tls.VersionTLS12}) defer func() { resty.SetTLSClientConfig(nil) }() @@ -1509,18 +1640,21 @@ func TestScaleOutRequestProxy(t *testing.T) { clusterMembers[idx] = "127.0.0.1:" + port } + // Generate certificates dynamically for the test + caCertPath, serverCertPath, serverKeyPath, _, _, caCertPEM := setupTestCerts(t) + for _, port := range ports { conf := config.New() conf.HTTP.Port = port conf.HTTP.TLS = &config.TLSConfig{ - Cert: ServerCert, - Key: ServerKey, + Cert: serverCertPath, + Key: serverKeyPath, } conf.Cluster = &config.ClusterConfig{ Members: clusterMembers, HashKey: "loremipsumdolors", TLS: &config.TLSConfig{ - CACert: CACert, + CACert: caCertPath, Cert: "/tmp/does-not-exist.crt", }, } @@ -1535,11 +1669,8 @@ func TestScaleOutRequestProxy(t *testing.T) { }(cm) } - caCert, err := os.ReadFile(CACert) - So(err, ShouldBeNil) - caCertPool := x509.NewCertPool() - caCertPool.AppendCertsFromPEM(caCert) + caCertPool.AppendCertsFromPEM(caCertPEM) resty.SetTLSClientConfig(&tls.Config{RootCAs: caCertPool, MinVersion: tls.VersionTLS12}) defer func() { resty.SetTLSClientConfig(nil) }() @@ -1575,19 +1706,21 @@ func TestScaleOutRequestProxy(t *testing.T) { clusterMembers[idx] = "127.0.0.1:" + port } + caCertPath, serverCertPath, serverKeyPath, _, _, caCertPEM := setupTestCerts(t) + for _, port := range ports { conf := config.New() conf.HTTP.Port = port conf.HTTP.TLS = &config.TLSConfig{ - Cert: ServerCert, - Key: ServerKey, + Cert: serverCertPath, + Key: serverKeyPath, } conf.Cluster = &config.ClusterConfig{ Members: clusterMembers, HashKey: "loremipsumdolors", TLS: &config.TLSConfig{ - CACert: CACert, - Cert: ServerCert, + CACert: caCertPath, + Cert: serverCertPath, Key: "/tmp/does-not-exist.crt", }, } @@ -1602,11 +1735,8 @@ func TestScaleOutRequestProxy(t *testing.T) { }(cm) } - caCert, err := os.ReadFile(CACert) - So(err, ShouldBeNil) - caCertPool := x509.NewCertPool() - caCertPool.AppendCertsFromPEM(caCert) + caCertPool.AppendCertsFromPEM(caCertPEM) resty.SetTLSClientConfig(&tls.Config{RootCAs: caCertPool, MinVersion: tls.VersionTLS12}) defer func() { resty.SetTLSClientConfig(nil) }() @@ -2056,11 +2186,10 @@ func TestMultipleInstance(t *testing.T) { func TestTLSWithBasicAuth(t *testing.T) { Convey("Make a new controller", t, func() { - caCert, err := os.ReadFile(CACert) - So(err, ShouldBeNil) + _, serverCertPath, serverKeyPath, _, _, caCertPEM := setupTestCerts(t) caCertPool := x509.NewCertPool() - caCertPool.AppendCertsFromPEM(caCert) + caCertPool.AppendCertsFromPEM(caCertPEM) username, seedUser := test.GenerateRandomString() password, seedPass := test.GenerateRandomString() @@ -2078,8 +2207,8 @@ func TestTLSWithBasicAuth(t *testing.T) { conf := config.New() conf.HTTP.Port = port conf.HTTP.TLS = &config.TLSConfig{ - Cert: ServerCert, - Key: ServerKey, + Cert: serverCertPath, + Key: serverKeyPath, } conf.HTTP.Auth = &config.AuthConfig{ HTPasswd: config.AuthHTPasswd{ @@ -2125,11 +2254,11 @@ func TestTLSWithBasicAuth(t *testing.T) { func TestTLSWithBasicAuthAllowReadAccess(t *testing.T) { Convey("Make a new controller", t, func() { - caCert, err := os.ReadFile(CACert) - So(err, ShouldBeNil) + // Generate certificates dynamically for the test + _, serverCertPath, serverKeyPath, _, _, caCertPEM := setupTestCerts(t) caCertPool := x509.NewCertPool() - caCertPool.AppendCertsFromPEM(caCert) + caCertPool.AppendCertsFromPEM(caCertPEM) username, seedUser := test.GenerateRandomString() password, seedPass := test.GenerateRandomString() @@ -2152,8 +2281,8 @@ func TestTLSWithBasicAuthAllowReadAccess(t *testing.T) { }, } conf.HTTP.TLS = &config.TLSConfig{ - Cert: ServerCert, - Key: ServerKey, + Cert: serverCertPath, + Key: serverKeyPath, } conf.HTTP.AccessControl = &config.AccessControlConfig{ @@ -2201,11 +2330,10 @@ func TestTLSWithBasicAuthAllowReadAccess(t *testing.T) { func TestMutualTLSAuthWithUserPermissions(t *testing.T) { Convey("Make a new controller", t, func() { - caCert, err := os.ReadFile(CACert) - So(err, ShouldBeNil) + caCertPath, serverCertPath, serverKeyPath, clientCertPath, clientKeyPath, caCertPEM := setupTestCerts(t) caCertPool := x509.NewCertPool() - caCertPool.AppendCertsFromPEM(caCert) + caCertPool.AppendCertsFromPEM(caCertPEM) port := test.GetFreePort() baseURL := test.GetBaseURL(port) @@ -2219,9 +2347,9 @@ func TestMutualTLSAuthWithUserPermissions(t *testing.T) { conf.HTTP.Port = port conf.HTTP.TLS = &config.TLSConfig{ - Cert: ServerCert, - Key: ServerKey, - CACert: CACert, + Cert: serverCertPath, + Key: serverKeyPath, + CACert: caCertPath, } conf.HTTP.AccessControl = &config.AccessControlConfig{ @@ -2229,7 +2357,7 @@ func TestMutualTLSAuthWithUserPermissions(t *testing.T) { test.AuthorizationAllRepos: config.PolicyGroup{ Policies: []config.Policy{ { - Users: []string{"*"}, + Users: []string{"testclient"}, Actions: []string{"read"}, }, }, @@ -2252,7 +2380,7 @@ func TestMutualTLSAuthWithUserPermissions(t *testing.T) { repoPolicy := conf.HTTP.AccessControl.Repositories[test.AuthorizationAllRepos] // setup TLS mutual auth - cert, err := tls.LoadX509KeyPair("../../test/data/client.cert", "../../test/data/client.key") + cert, err := tls.LoadX509KeyPair(clientCertPath, clientKeyPath) So(err, ShouldBeNil) // Use separate resty client with certificates, because we cannot perform cleanup with resty.SetCertificates() @@ -2361,6 +2489,8 @@ func TestAuthnErrors(t *testing.T) { port := test.GetFreePort() conf := config.New() conf.HTTP.Port = port + // Generate certificates dynamically for the test + caCertPath, _, _, _, _, _ := setupTestCerts(t) conf.HTTP.Auth.LDAP = (&config.LDAPConfig{ Insecure: true, @@ -2368,7 +2498,7 @@ func TestAuthnErrors(t *testing.T) { Port: 9000, BaseDN: LDAPBaseDN, UserAttribute: "uid", - CACert: CACert, + CACert: caCertPath, }).SetBindDN(LDAPBindDN).SetBindPassword(LDAPBindPassword) ctlr := makeController(conf, t.TempDir()) @@ -2488,11 +2618,42 @@ func TestAuthnErrors(t *testing.T) { func TestMutualTLSAuthWithoutCN(t *testing.T) { Convey("Make a new controller", t, func() { - caCert, err := os.ReadFile("../../test/data/noidentity/ca.crt") + // Generate certificates without CommonName for client + tempDir := t.TempDir() + caOpts := &tlsutils.CertificateOptions{ + CommonName: "*", + NotAfter: time.Now().AddDate(10, 0, 0), + } + caCertPEM, caKeyPEM, err := tlsutils.GenerateCACert(caOpts) + So(err, ShouldBeNil) + + caCertPath := path.Join(tempDir, "ca.crt") + err = os.WriteFile(caCertPath, caCertPEM, 0o600) + So(err, ShouldBeNil) + + serverCertPath := path.Join(tempDir, "server.cert") + serverKeyPath := path.Join(tempDir, "server.key") + serverOpts := &tlsutils.CertificateOptions{ + Hostname: "127.0.0.1", + CommonName: "*", + OrganizationalUnit: "TestServer", + NotAfter: time.Now().AddDate(10, 0, 0), + } + err = tlsutils.GenerateServerCertToFile(caCertPEM, caKeyPEM, serverCertPath, serverKeyPath, serverOpts) + So(err, ShouldBeNil) + + // Generate client certificate without CommonName (10 years validity, matching gen_certs.sh) + clientCertPath := path.Join(tempDir, "client.cert") + clientKeyPath := path.Join(tempDir, "client.key") + clientOpts := &tlsutils.CertificateOptions{ + // CommonName intentionally not set + NotAfter: time.Now().AddDate(10, 0, 0), + } + err = tlsutils.GenerateClientCertToFile(caCertPEM, caKeyPEM, clientCertPath, clientKeyPath, clientOpts) So(err, ShouldBeNil) caCertPool := x509.NewCertPool() - caCertPool.AppendCertsFromPEM(caCert) + caCertPool.AppendCertsFromPEM(caCertPEM) port := test.GetFreePort() secureBaseURL := test.GetSecureBaseURL(port) @@ -2505,9 +2666,9 @@ func TestMutualTLSAuthWithoutCN(t *testing.T) { conf.HTTP.Port = port conf.HTTP.TLS = &config.TLSConfig{ - Cert: "../../test/data/noidentity/server.cert", - Key: "../../test/data/noidentity/server.key", - CACert: "../../test/data/noidentity/ca.crt", + Cert: serverCertPath, + Key: serverKeyPath, + CACert: caCertPath, } conf.HTTP.AccessControl = &config.AccessControlConfig{ @@ -2531,7 +2692,7 @@ func TestMutualTLSAuthWithoutCN(t *testing.T) { defer cm.StopServer() // setup TLS mutual auth - cert, err := tls.LoadX509KeyPair("../../test/data/noidentity/client.cert", "../../test/data/noidentity/client.key") + cert, err := tls.LoadX509KeyPair(clientCertPath, clientKeyPath) So(err, ShouldBeNil) // Use separate resty client with certificates, because we cannot perform cleanup with resty.SetCertificates() @@ -2549,11 +2710,10 @@ func TestMutualTLSAuthWithoutCN(t *testing.T) { func TestTLSMutualAuth(t *testing.T) { Convey("Make a new controller", t, func() { - caCert, err := os.ReadFile(CACert) - So(err, ShouldBeNil) + caCertPath, serverCertPath, serverKeyPath, clientCertPath, clientKeyPath, caCertPEM := setupTestCerts(t) caCertPool := x509.NewCertPool() - caCertPool.AppendCertsFromPEM(caCert) + caCertPool.AppendCertsFromPEM(caCertPEM) port := test.GetFreePort() baseURL := test.GetBaseURL(port) @@ -2562,9 +2722,9 @@ func TestTLSMutualAuth(t *testing.T) { conf := config.New() conf.HTTP.Port = port conf.HTTP.TLS = &config.TLSConfig{ - Cert: ServerCert, - Key: ServerKey, - CACert: CACert, + Cert: serverCertPath, + Key: serverKeyPath, + CACert: caCertPath, } ctlr := makeController(conf, t.TempDir()) @@ -2622,7 +2782,7 @@ func TestTLSMutualAuth(t *testing.T) { So(resp.StatusCode(), ShouldEqual, http.StatusUnauthorized) // setup TLS mutual auth - cert, err := tls.LoadX509KeyPair("../../test/data/client.cert", "../../test/data/client.key") + cert, err := tls.LoadX509KeyPair(clientCertPath, clientKeyPath) So(err, ShouldBeNil) client = resty.New().SetTLSClientConfig(&tls.Config{ @@ -2651,18 +2811,19 @@ func TestTLSMutualAuth(t *testing.T) { func TestTSLFailedReadingOfCACert(t *testing.T) { Convey("no permissions", t, func() { + caCertPath, serverCertPath, serverKeyPath, _, _, _ := setupTestCerts(t) port := test.GetFreePort() conf := config.New() conf.HTTP.Port = port conf.HTTP.TLS = &config.TLSConfig{ - Cert: ServerCert, - Key: ServerKey, - CACert: CACert, + Cert: serverCertPath, + Key: serverKeyPath, + CACert: caCertPath, } - err := os.Chmod(CACert, 0o000) + err := os.Chmod(caCertPath, 0o000) defer func() { - err := os.Chmod(CACert, 0o644) + err := os.Chmod(caCertPath, 0o644) So(err, ShouldBeNil) }() So(err, ShouldBeNil) @@ -2701,12 +2862,13 @@ func TestTSLFailedReadingOfCACert(t *testing.T) { err := os.WriteFile(badCACert, []byte(""), 0o600) So(err, ShouldBeNil) + _, serverCertPath, serverKeyPath, _, _, _ := setupTestCerts(t) port := test.GetFreePort() conf := config.New() conf.HTTP.Port = port conf.HTTP.TLS = &config.TLSConfig{ - Cert: ServerCert, - Key: ServerKey, + Cert: serverCertPath, + Key: serverKeyPath, CACert: badCACert, } @@ -2742,11 +2904,10 @@ func TestTSLFailedReadingOfCACert(t *testing.T) { func TestTLSMutualAuthAllowReadAccess(t *testing.T) { Convey("Make a new controller", t, func() { - caCert, err := os.ReadFile(CACert) - So(err, ShouldBeNil) + caCertPath, serverCertPath, serverKeyPath, clientCertPath, clientKeyPath, caCertPEM := setupTestCerts(t) caCertPool := x509.NewCertPool() - caCertPool.AppendCertsFromPEM(caCert) + caCertPool.AppendCertsFromPEM(caCertPEM) port := test.GetFreePort() baseURL := test.GetBaseURL(port) @@ -2761,9 +2922,9 @@ func TestTLSMutualAuthAllowReadAccess(t *testing.T) { conf := config.New() conf.HTTP.Port = port conf.HTTP.TLS = &config.TLSConfig{ - Cert: ServerCert, - Key: ServerKey, - CACert: CACert, + Cert: serverCertPath, + Key: serverKeyPath, + CACert: caCertPath, } conf.HTTP.AccessControl = &config.AccessControlConfig{ @@ -2808,7 +2969,7 @@ func TestTLSMutualAuthAllowReadAccess(t *testing.T) { So(resp.StatusCode(), ShouldEqual, http.StatusUnauthorized) // setup TLS mutual auth - cert, err := tls.LoadX509KeyPair("../../test/data/client.cert", "../../test/data/client.key") + cert, err := tls.LoadX509KeyPair(clientCertPath, clientKeyPath) So(err, ShouldBeNil) // Use separate resty client with certificates, because we cannot perform cleanup with resty.SetCertificates() @@ -2844,11 +3005,10 @@ func TestTLSMutualAuthAllowReadAccess(t *testing.T) { func TestTLSMutualAndBasicAuth(t *testing.T) { Convey("Make a new controller", t, func() { - caCert, err := os.ReadFile(CACert) - So(err, ShouldBeNil) + caCertPath, serverCertPath, serverKeyPath, clientCertPath, clientKeyPath, caCertPEM := setupTestCerts(t) caCertPool := x509.NewCertPool() - caCertPool.AppendCertsFromPEM(caCert) + caCertPool.AppendCertsFromPEM(caCertPEM) username, seedUser := test.GenerateRandomString() password, seedPass := test.GenerateRandomString() @@ -2871,9 +3031,9 @@ func TestTLSMutualAndBasicAuth(t *testing.T) { }, } conf.HTTP.TLS = &config.TLSConfig{ - Cert: ServerCert, - Key: ServerKey, - CACert: CACert, + Cert: serverCertPath, + Key: serverKeyPath, + CACert: caCertPath, } ctlr := makeController(conf, t.TempDir()) @@ -2903,7 +3063,7 @@ func TestTLSMutualAndBasicAuth(t *testing.T) { So(resp.StatusCode(), ShouldEqual, http.StatusBadRequest) // setup TLS mutual auth - cert, err := tls.LoadX509KeyPair("../../test/data/client.cert", "../../test/data/client.key") + cert, err := tls.LoadX509KeyPair(clientCertPath, clientKeyPath) So(err, ShouldBeNil) // Use separate resty client with certificates, because we cannot perform cleanup with resty.SetCertificates() @@ -2932,11 +3092,10 @@ func TestTLSMutualAndBasicAuth(t *testing.T) { func TestTLSMutualAndBasicAuthAllowReadAccess(t *testing.T) { Convey("Make a new controller", t, func() { - caCert, err := os.ReadFile(CACert) - So(err, ShouldBeNil) + caCertPath, serverCertPath, serverKeyPath, clientCertPath, clientKeyPath, caCertPEM := setupTestCerts(t) caCertPool := x509.NewCertPool() - caCertPool.AppendCertsFromPEM(caCert) + caCertPool.AppendCertsFromPEM(caCertPEM) username, seedUser := test.GenerateRandomString() password, seedPass := test.GenerateRandomString() @@ -2959,9 +3118,9 @@ func TestTLSMutualAndBasicAuthAllowReadAccess(t *testing.T) { }, } conf.HTTP.TLS = &config.TLSConfig{ - Cert: ServerCert, - Key: ServerKey, - CACert: CACert, + Cert: serverCertPath, + Key: serverKeyPath, + CACert: caCertPath, } conf.HTTP.AccessControl = &config.AccessControlConfig{ @@ -2999,7 +3158,7 @@ func TestTLSMutualAndBasicAuthAllowReadAccess(t *testing.T) { So(resp.StatusCode(), ShouldEqual, http.StatusBadRequest) // setup TLS mutual auth - cert, err := tls.LoadX509KeyPair("../../test/data/client.cert", "../../test/data/client.key") + cert, err := tls.LoadX509KeyPair(clientCertPath, clientKeyPath) So(err, ShouldBeNil) // Use separate resty client with certificates, because we cannot perform cleanup with resty.SetCertificates() @@ -3965,58 +4124,79 @@ func TestLDAPClient(t *testing.T) { func TestBearerAuthMultipleAlgorithms(t *testing.T) { testCases := []struct { - name string - key string - cert string - alg string + name string + keyType tlsutils.KeyType + certUse string // "certificate" or "publickey" or "publickey-pkcs1" (RSA only) + alg string }{ { "RSA signing key using certificate", - ServerKey, - ServerCert, + tlsutils.KeyTypeRSA, + "certificate", "RS256", }, { "RSA signing key using public key", - ServerKey, - ServerPublicKey, + tlsutils.KeyTypeRSA, + "publickey", "RS256", }, { "RSA signing key using public key in PKCS1 format", - ServerKey, - ServerPublicKeyPKCS1, + tlsutils.KeyTypeRSA, + "publickey-pkcs1", "RS256", }, { "ECDSA signing key using certificate", - ServerKeyECDSA, - ServerCertECDSA, + tlsutils.KeyTypeECDSA, + "certificate", "ES256", }, { "ECDSA signing key using public key", - ServerKeyECDSA, - ServerPublicKeyECDSA, + tlsutils.KeyTypeECDSA, + "publickey", "ES256", }, { "ED25519 signing key using certificate", - ServerKeyED25519, - ServerCertED25519, + tlsutils.KeyTypeED25519, + "certificate", "EdDSA", }, { "ED25519 signing key using public key", - ServerKeyED25519, - ServerPublicKeyED25519, + tlsutils.KeyTypeED25519, + "publickey", "EdDSA", }, } for _, testCase := range testCases { Convey("Make a new controller with "+testCase.name, t, func() { - authTestServer := authutils.MakeAuthTestServer(testCase.key, testCase.alg, UnauthorizedNamespace) + // Generate certificates dynamically for the test + serverCertPath, serverKeyPath, serverPublicKeyPath, serverPublicKeyPKCS1Path := setupBearerAuthServerCerts( + t, testCase.keyType) + + // Determine which cert/key to use based on test case + var keyPath, certPath string + switch testCase.certUse { + case "certificate": + certPath = serverCertPath + case "publickey": + certPath = serverPublicKeyPath + case "publickey-pkcs1": + if testCase.keyType != tlsutils.KeyTypeRSA { + t.Fatalf("PKCS1 format only supported for RSA keys") + } + certPath = serverPublicKeyPKCS1Path + default: + t.Fatalf("Unknown cert use: %s", testCase.certUse) + } + keyPath = serverKeyPath + + authTestServer := authutils.MakeAuthTestServer(keyPath, testCase.alg, UnauthorizedNamespace) defer authTestServer.Close() port := test.GetFreePort() @@ -4030,7 +4210,7 @@ func TestBearerAuthMultipleAlgorithms(t *testing.T) { conf.HTTP.Auth = &config.AuthConfig{ Bearer: &config.BearerConfig{ - Cert: testCase.cert, + Cert: certPath, Realm: authTestServer.URL + "/auth/token", Service: aurl.Host, }, @@ -4087,11 +4267,14 @@ func TestBearerAuth(t *testing.T) { for _, testCase := range testCases { Convey("Make a new controller with "+testCase.name, t, func() { + // Generate certificates dynamically for the test + serverCertPath, serverKeyPath, _, _ := setupBearerAuthServerCerts(t, tlsutils.KeyTypeRSA) + var authTestServer *httptest.Server if testCase.useLegacyAuthTestServer { - authTestServer = authutils.MakeAuthTestServerLegacy(ServerKey, UnauthorizedNamespace) + authTestServer = authutils.MakeAuthTestServerLegacy(serverKeyPath, UnauthorizedNamespace) } else { - authTestServer = authutils.MakeAuthTestServer(ServerKey, "RS256", UnauthorizedNamespace) + authTestServer = authutils.MakeAuthTestServer(serverKeyPath, "RS256", UnauthorizedNamespace) } defer authTestServer.Close() @@ -4106,7 +4289,7 @@ func TestBearerAuth(t *testing.T) { conf.HTTP.Auth = &config.AuthConfig{ Bearer: &config.BearerConfig{ - Cert: ServerCert, + Cert: serverCertPath, Realm: authTestServer.URL + "/auth/token", Service: aurl.Host, }, @@ -4315,11 +4498,14 @@ func TestBearerAuthWithAllowReadAccess(t *testing.T) { for _, testCase := range testCases { Convey("Make a new controller with"+testCase.name, t, func() { + // Generate certificates dynamically for the test + serverCertPath, serverKeyPath, _, _ := setupBearerAuthServerCerts(t, tlsutils.KeyTypeRSA) + var authTestServer *httptest.Server if testCase.useLegacyAuthTestServer { - authTestServer = authutils.MakeAuthTestServerLegacy(ServerKey, UnauthorizedNamespace) + authTestServer = authutils.MakeAuthTestServerLegacy(serverKeyPath, UnauthorizedNamespace) } else { - authTestServer = authutils.MakeAuthTestServer(ServerKey, "RS256", UnauthorizedNamespace) + authTestServer = authutils.MakeAuthTestServer(serverKeyPath, "RS256", UnauthorizedNamespace) } defer authTestServer.Close() @@ -4334,7 +4520,7 @@ func TestBearerAuthWithAllowReadAccess(t *testing.T) { conf.HTTP.Auth = &config.AuthConfig{ Bearer: &config.BearerConfig{ - Cert: ServerCert, + Cert: serverCertPath, Realm: authTestServer.URL + "/auth/token", Service: aurl.Host, }, @@ -4545,9 +4731,11 @@ func TestNewRelyingPartyOIDC(t *testing.T) { }) Convey("https callback", func() { + // Generate certificates dynamically for the test + _, serverCertPath, serverKeyPath, _, _, _ := setupTestCerts(t) conf.HTTP.TLS = &config.TLSConfig{ - Cert: ServerCert, - Key: ServerKey, + Cert: serverCertPath, + Key: serverKeyPath, } rp := api.NewRelyingPartyOIDC(ctx, conf, "oidc", nil, nil, log.NewTestLogger()) diff --git a/pkg/cli/client/client.go b/pkg/cli/client/client.go index e5528f83..df317f56 100644 --- a/pkg/cli/client/client.go +++ b/pkg/cli/client/client.go @@ -21,12 +21,20 @@ import ( "zotregistry.dev/zot/v2/pkg/common" ) -var ( - httpClientsMap = make(map[string]*http.Client) //nolint: gochecknoglobals - httpClientLock sync.Mutex //nolint: gochecknoglobals -) +// HTTPClient manages HTTP clients with TLS support and caching. +type HTTPClient struct { + clients map[string]*http.Client + mu sync.Mutex +} -func makeGETRequest(ctx context.Context, url, username, password string, +// NewHTTPClient creates a new HTTPClient instance. +func NewHTTPClient() *HTTPClient { + return &HTTPClient{ + clients: make(map[string]*http.Client), + } +} + +func (c *HTTPClient) makeGETRequest(ctx context.Context, url, username, password string, verifyTLS bool, debug bool, resultsPtr any, configWriter io.Writer, ) (http.Header, error) { req, err := http.NewRequestWithContext(ctx, http.MethodGet, url, nil) @@ -36,10 +44,10 @@ func makeGETRequest(ctx context.Context, url, username, password string, req.SetBasicAuth(username, password) - return doHTTPRequest(req, verifyTLS, debug, resultsPtr, configWriter) + return c.doHTTPRequest(req, verifyTLS, debug, resultsPtr, configWriter) } -func makeHEADRequest(ctx context.Context, url, username, password string, verifyTLS bool, +func (c *HTTPClient) makeHEADRequest(ctx context.Context, url, username, password string, verifyTLS bool, debug bool, ) (http.Header, error) { req, err := http.NewRequestWithContext(ctx, http.MethodHead, url, nil) @@ -49,10 +57,10 @@ func makeHEADRequest(ctx context.Context, url, username, password string, verify req.SetBasicAuth(username, password) - return doHTTPRequest(req, verifyTLS, debug, nil, io.Discard) + return c.doHTTPRequest(req, verifyTLS, debug, nil, io.Discard) } -func makeGraphQLRequest(ctx context.Context, url, query, username, +func (c *HTTPClient) makeGraphQLRequest(ctx context.Context, url, query, username, password string, verifyTLS bool, debug bool, resultsPtr any, configWriter io.Writer, ) error { req, err := http.NewRequestWithContext(ctx, http.MethodGet, url, bytes.NewBufferString(query)) @@ -68,7 +76,7 @@ func makeGraphQLRequest(ctx context.Context, url, query, username, req.SetBasicAuth(username, password) req.Header.Add("Content-Type", "application/json") - _, err = doHTTPRequest(req, verifyTLS, debug, resultsPtr, configWriter) + _, err = c.doHTTPRequest(req, verifyTLS, debug, resultsPtr, configWriter) if err != nil { return err } @@ -76,7 +84,7 @@ func makeGraphQLRequest(ctx context.Context, url, query, username, return nil } -func doHTTPRequest(req *http.Request, verifyTLS bool, debug bool, +func (c *HTTPClient) doHTTPRequest(req *http.Request, verifyTLS bool, debug bool, resultsPtr any, configWriter io.Writer, ) (http.Header, error) { var httpClient *http.Client @@ -91,9 +99,9 @@ func doHTTPRequest(req *http.Request, verifyTLS bool, debug bool, enableTLS = true } - httpClientLock.Lock() + c.mu.Lock() - if httpClientsMap[host] == nil { + if c.clients[host] == nil { httpClient, err = common.CreateHTTPClient(&common.HTTPClientOptions{ TLSEnabled: enableTLS, VerifyTLS: verifyTLS, @@ -101,15 +109,17 @@ func doHTTPRequest(req *http.Request, verifyTLS bool, debug bool, CertOptions: common.HTTPClientCertOptions{}, }) if err != nil { + c.mu.Unlock() + return nil, err } - httpClientsMap[host] = httpClient + c.clients[host] = httpClient } else { - httpClient = httpClientsMap[host] + httpClient = c.clients[host] } - httpClientLock.Unlock() + c.mu.Unlock() if debug { fmt.Fprintln(configWriter, "[debug] ", req.Method, " ", req.URL, "[request header] ", req.Header) @@ -225,7 +235,8 @@ func (p *requestsPool) doJob(ctx context.Context, job *httpJob) { defer p.wtgrp.Done() // Check manifest media type - header, err := makeHEADRequest(ctx, job.url, job.username, job.password, job.config.VerifyTLS, + httpClient := job.config.SearchService.getHTTPClient() + header, err := httpClient.makeHEADRequest(ctx, job.url, job.username, job.password, job.config.VerifyTLS, job.config.Debug) if err != nil { if common.IsContextDone(ctx) { @@ -300,7 +311,8 @@ func (p *requestsPool) doJob(ctx context.Context, job *httpJob) { func fetchImageIndexStruct(ctx context.Context, job *httpJob) (*imageStruct, error) { var indexContent ispec.Index - header, err := makeGETRequest(ctx, job.url, job.username, job.password, + httpClient := job.config.SearchService.getHTTPClient() + header, err := httpClient.makeGETRequest(ctx, job.url, job.username, job.password, job.config.VerifyTLS, job.config.Debug, &indexContent, job.config.ResultWriter) if err != nil { if common.IsContextDone(ctx) { @@ -388,10 +400,11 @@ func fetchManifestStruct(ctx context.Context, repo, manifestReference string, se ) (common.ManifestSummary, error) { manifestResp := ispec.Manifest{} + httpClient := searchConf.SearchService.getHTTPClient() URL := fmt.Sprintf("%s/v2/%s/manifests/%s", searchConf.ServURL, repo, manifestReference) - header, err := makeGETRequest(ctx, URL, username, password, + header, err := httpClient.makeGETRequest(ctx, URL, username, password, searchConf.VerifyTLS, searchConf.Debug, &manifestResp, searchConf.ResultWriter) if err != nil { if common.IsContextDone(ctx) { @@ -478,10 +491,11 @@ func fetchConfig(ctx context.Context, repo, configDigest string, searchConf Sear ) (ispec.Image, error) { configContent := ispec.Image{} + httpClient := searchConf.SearchService.getHTTPClient() URL := fmt.Sprintf("%s/v2/%s/blobs/%s", searchConf.ServURL, repo, configDigest) - _, err := makeGETRequest(ctx, URL, username, password, + _, err := httpClient.makeGETRequest(ctx, URL, username, password, searchConf.VerifyTLS, searchConf.Debug, &configContent, searchConf.ResultWriter) if err != nil { if common.IsContextDone(ctx) { @@ -499,10 +513,11 @@ func isNotationSigned(ctx context.Context, repo, digestStr string, searchConf Se ) bool { var referrers ispec.Index + httpClient := searchConf.SearchService.getHTTPClient() URL := fmt.Sprintf("%s/v2/%s/referrers/%s?artifactType=%s", searchConf.ServURL, repo, digestStr, common.ArtifactTypeNotation) - _, err := makeGETRequest(ctx, URL, username, password, + _, err := httpClient.makeGETRequest(ctx, URL, username, password, searchConf.VerifyTLS, searchConf.Debug, &referrers, searchConf.ResultWriter) if err != nil { return false @@ -518,12 +533,14 @@ func isNotationSigned(ctx context.Context, repo, digestStr string, searchConf Se func isCosignSigned(ctx context.Context, repo, digestStr string, searchConf SearchConfig, username, password string, ) bool { + httpClient := searchConf.SearchService.getHTTPClient() + var result any cosignTag := strings.Replace(digestStr, ":", "-", 1) + "." + common.CosignSignatureTagSuffix URL := fmt.Sprintf("%s/v2/%s/manifests/%s", searchConf.ServURL, repo, cosignTag) - _, err := makeGETRequest(ctx, URL, username, password, searchConf.VerifyTLS, + _, err := httpClient.makeGETRequest(ctx, URL, username, password, searchConf.VerifyTLS, searchConf.Debug, &result, searchConf.ResultWriter) if err == nil { return true @@ -535,7 +552,7 @@ func isCosignSigned(ctx context.Context, repo, digestStr string, searchConf Sear URL = fmt.Sprintf("%s/v2/%s/referrers/%s?artifactType=%s", searchConf.ServURL, repo, digestStr, artifactType) - _, err = makeGETRequest(ctx, URL, username, password, searchConf.VerifyTLS, + _, err = httpClient.makeGETRequest(ctx, URL, username, password, searchConf.VerifyTLS, searchConf.Debug, &referrers, searchConf.ResultWriter) if err != nil { return false diff --git a/pkg/cli/client/client_test.go b/pkg/cli/client/client_test.go index c97ad439..7bd42613 100644 --- a/pkg/cli/client/client_test.go +++ b/pkg/cli/client/client_test.go @@ -11,6 +11,7 @@ import ( "path" "path/filepath" "testing" + "time" . "github.com/smartystreets/goconvey/convey" "gopkg.in/resty.v1" @@ -21,6 +22,7 @@ import ( "zotregistry.dev/zot/v2/pkg/cli/client" extConf "zotregistry.dev/zot/v2/pkg/extensions/config" test "zotregistry.dev/zot/v2/pkg/test/common" + tlsutils "zotregistry.dev/zot/v2/pkg/test/tls" ) const ( @@ -31,19 +33,40 @@ const ( SecurePort2 = "8089" BaseSecureURL3 = "https://127.0.0.1:8090" SecurePort3 = "8090" - ServerCert = "../../../test/data/server.cert" - ServerKey = "../../../test/data/server.key" - CACert = "../../../test/data/ca.crt" - sourceCertsDir = "../../../test/data" certsDir1 = ".config/containers/certs.d/127.0.0.1:8088" ) func TestTLSWithAuth(t *testing.T) { Convey("Make a new controller", t, func() { - caCert, err := os.ReadFile(CACert) + // Generate certificates using tls library + tempDir := t.TempDir() + caOpts := &tlsutils.CertificateOptions{ + CommonName: "*", + NotAfter: time.Now().AddDate(10, 0, 0), + } + caCertPEM, caKeyPEM, err := tlsutils.GenerateCACert(caOpts) So(err, ShouldBeNil) + + caCertPath := path.Join(tempDir, "ca.crt") + caKeyPath := path.Join(tempDir, "ca.key") + err = os.WriteFile(caCertPath, caCertPEM, 0o600) + So(err, ShouldBeNil) + err = os.WriteFile(caKeyPath, caKeyPEM, 0o600) + So(err, ShouldBeNil) + + serverCertPath := path.Join(tempDir, "server.cert") + serverKeyPath := path.Join(tempDir, "server.key") + serverOpts := &tlsutils.CertificateOptions{ + Hostname: "127.0.0.1", + CommonName: "*", + OrganizationalUnit: "TestServer", + NotAfter: time.Now().AddDate(10, 0, 0), + } + err = tlsutils.GenerateServerCertToFile(caCertPEM, caKeyPEM, serverCertPath, serverKeyPath, serverOpts) + So(err, ShouldBeNil) + caCertPool := x509.NewCertPool() - caCertPool.AppendCertsFromPEM(caCert) + caCertPool.AppendCertsFromPEM(caCertPEM) resty.SetTLSClientConfig(&tls.Config{RootCAs: caCertPool, MinVersion: tls.VersionTLS12}) @@ -63,9 +86,9 @@ func TestTLSWithAuth(t *testing.T) { } conf.HTTP.TLS = &config.TLSConfig{ - Cert: ServerCert, - Key: ServerKey, - CACert: CACert, + Cert: serverCertPath, + Key: serverKeyPath, + CACert: caCertPath, } enable := true @@ -87,7 +110,22 @@ func TestTLSWithAuth(t *testing.T) { // Use the HOME that makeConfigFile set (temp directory) for certificates home := os.Getenv("HOME") destCertsDir := filepath.Join(home, certsDir1) - err := test.CopyTestKeysAndCerts(destCertsDir) + err := os.MkdirAll(destCertsDir, 0o755) + So(err, ShouldBeNil) + + // Write CA certificate to client certs directory (needed for server verification) + err = os.WriteFile(filepath.Join(destCertsDir, "ca.crt"), caCertPEM, 0o600) + So(err, ShouldBeNil) + + // Generate and write client certificate and key (needed for mTLS client authentication) + clientCertPath := filepath.Join(destCertsDir, "client.cert") + clientKeyPath := filepath.Join(destCertsDir, "client.key") + clientOpts := &tlsutils.CertificateOptions{ + CommonName: "testclient", + OrganizationalUnit: "TestClient", + NotAfter: time.Now().AddDate(10, 0, 0), + } + err = tlsutils.GenerateClientCertToFile(caCertPEM, caKeyPEM, clientCertPath, clientKeyPath, clientOpts) So(err, ShouldBeNil) defer os.RemoveAll(destCertsDir) @@ -112,7 +150,20 @@ func TestTLSWithAuth(t *testing.T) { // Ensure certificates are in the HOME directory that makeConfigFile set home = os.Getenv("HOME") destCertsDir = filepath.Join(home, certsDir1) - err = test.CopyTestKeysAndCerts(destCertsDir) + err = os.MkdirAll(destCertsDir, 0o755) + So(err, ShouldBeNil) + + // Write CA certificate to client certs directory (needed for server verification) + err = os.WriteFile(filepath.Join(destCertsDir, "ca.crt"), caCertPEM, 0o600) + So(err, ShouldBeNil) + + // Generate and write client certificate and key (needed for mTLS client authentication) + clientCertPath = filepath.Join(destCertsDir, "client.cert") + clientKeyPath = filepath.Join(destCertsDir, "client.key") + clientOpts = &tlsutils.CertificateOptions{ + CommonName: "testclient", + } + err = tlsutils.GenerateClientCertToFile(caCertPEM, caKeyPEM, clientCertPath, clientKeyPath, clientOpts) So(err, ShouldBeNil) imageCmd = client.NewImageCommand(client.NewSearchService()) @@ -144,10 +195,35 @@ func TestTLSWithAuth(t *testing.T) { func TestTLSWithoutAuth(t *testing.T) { Convey("Home certs - Make a new controller", t, func() { - caCert, err := os.ReadFile(CACert) + // Generate certificates using tls library + tempDir := t.TempDir() + caOpts := &tlsutils.CertificateOptions{ + CommonName: "*", + NotAfter: time.Now().AddDate(10, 0, 0), + } + caCertPEM, caKeyPEM, err := tlsutils.GenerateCACert(caOpts) So(err, ShouldBeNil) + + caCertPath := path.Join(tempDir, "ca.crt") + caKeyPath := path.Join(tempDir, "ca.key") + err = os.WriteFile(caCertPath, caCertPEM, 0o600) + So(err, ShouldBeNil) + err = os.WriteFile(caKeyPath, caKeyPEM, 0o600) + So(err, ShouldBeNil) + + serverCertPath := path.Join(tempDir, "server.cert") + serverKeyPath := path.Join(tempDir, "server.key") + serverOpts := &tlsutils.CertificateOptions{ + Hostname: "127.0.0.1", + CommonName: "*", + OrganizationalUnit: "TestServer", + NotAfter: time.Now().AddDate(10, 0, 0), + } + err = tlsutils.GenerateServerCertToFile(caCertPEM, caKeyPEM, serverCertPath, serverKeyPath, serverOpts) + So(err, ShouldBeNil) + caCertPool := x509.NewCertPool() - caCertPool.AppendCertsFromPEM(caCert) + caCertPool.AppendCertsFromPEM(caCertPEM) resty.SetTLSClientConfig(&tls.Config{RootCAs: caCertPool, MinVersion: tls.VersionTLS12}) @@ -156,9 +232,9 @@ func TestTLSWithoutAuth(t *testing.T) { conf := config.New() conf.HTTP.Port = SecurePort1 conf.HTTP.TLS = &config.TLSConfig{ - Cert: ServerCert, - Key: ServerKey, - CACert: CACert, + Cert: serverCertPath, + Key: serverKeyPath, + CACert: caCertPath, } enable := true @@ -181,7 +257,22 @@ func TestTLSWithoutAuth(t *testing.T) { home := os.Getenv("HOME") destCertsDir := filepath.Join(home, certsDir1) - err := test.CopyFiles(sourceCertsDir, destCertsDir) + err := os.MkdirAll(destCertsDir, 0o755) + So(err, ShouldBeNil) + + // Write CA certificate to client certs directory (needed for server verification) + err = os.WriteFile(filepath.Join(destCertsDir, "ca.crt"), caCertPEM, 0o600) + So(err, ShouldBeNil) + + // Generate and write client certificate and key (needed for mTLS client authentication) + clientCertPath := filepath.Join(destCertsDir, "client.cert") + clientKeyPath := filepath.Join(destCertsDir, "client.key") + clientOpts := &tlsutils.CertificateOptions{ + CommonName: "testclient", + OrganizationalUnit: "TestClient", + NotAfter: time.Now().AddDate(10, 0, 0), + } + err = tlsutils.GenerateClientCertToFile(caCertPEM, caKeyPEM, clientCertPath, clientKeyPath, clientOpts) So(err, ShouldBeNil) defer os.RemoveAll(destCertsDir) @@ -200,22 +291,53 @@ func TestTLSWithoutAuth(t *testing.T) { func TestTLSBadCerts(t *testing.T) { Convey("Make a new controller", t, func() { - caCert, err := os.ReadFile(CACert) + // Generate certificates using tls library + tempDir := t.TempDir() + caOpts := &tlsutils.CertificateOptions{ + CommonName: "*", + NotAfter: time.Now().AddDate(10, 0, 0), + } + caCertPEM, caKeyPEM, err := tlsutils.GenerateCACert(caOpts) So(err, ShouldBeNil) - caCertPool := x509.NewCertPool() - caCertPool.AppendCertsFromPEM(caCert) + caCertPath := path.Join(tempDir, "ca.crt") + caKeyPath := path.Join(tempDir, "ca.key") + err = os.WriteFile(caCertPath, caCertPEM, 0o600) + So(err, ShouldBeNil) + err = os.WriteFile(caKeyPath, caKeyPEM, 0o600) + So(err, ShouldBeNil) - resty.SetTLSClientConfig(&tls.Config{RootCAs: caCertPool, MinVersion: tls.VersionTLS12}) + serverCertPath := path.Join(tempDir, "server.cert") + serverKeyPath := path.Join(tempDir, "server.key") + serverOpts := &tlsutils.CertificateOptions{ + Hostname: "127.0.0.1", + CommonName: "*", + OrganizationalUnit: "TestServer", + NotAfter: time.Now().AddDate(10, 0, 0), + } + err = tlsutils.GenerateServerCertToFile(caCertPEM, caKeyPEM, serverCertPath, serverKeyPath, serverOpts) + So(err, ShouldBeNil) + + // Use a different CA for the client to simulate bad certs + badCAOpts := &tlsutils.CertificateOptions{ + CommonName: "*", + } + badCACertPEM, _, err := tlsutils.GenerateCACert(badCAOpts) + So(err, ShouldBeNil) + + badCACertPool := x509.NewCertPool() + badCACertPool.AppendCertsFromPEM(badCACertPEM) + + resty.SetTLSClientConfig(&tls.Config{RootCAs: badCACertPool, MinVersion: tls.VersionTLS12}) defer func() { resty.SetTLSClientConfig(nil) }() conf := config.New() conf.HTTP.Port = SecurePort3 conf.HTTP.TLS = &config.TLSConfig{ - Cert: ServerCert, - Key: ServerKey, - CACert: CACert, + Cert: serverCertPath, + Key: serverKeyPath, + CACert: caCertPath, } ctlr := api.NewController(conf) diff --git a/pkg/cli/client/cve_cmd_internal_test.go b/pkg/cli/client/cve_cmd_internal_test.go index 6cd12fb2..7f562fe0 100644 --- a/pkg/cli/client/cve_cmd_internal_test.go +++ b/pkg/cli/client/cve_cmd_internal_test.go @@ -48,7 +48,7 @@ func TestSearchCVECmd(t *testing.T) { _ = makeConfigFile(t, "") - cmd := NewCVECommand(new(mockService)) + cmd := NewCVECommand(newMockService()) buff := bytes.NewBufferString("") cmd.SetOut(buff) cmd.SetErr(buff) @@ -64,7 +64,7 @@ func TestSearchCVECmd(t *testing.T) { _ = makeConfigFile(t, "") - cmd := NewCVECommand(new(mockService)) + cmd := NewCVECommand(newMockService()) buff := bytes.NewBufferString("") cmd.SetOut(buff) cmd.SetErr(buff) @@ -80,7 +80,7 @@ func TestSearchCVECmd(t *testing.T) { _ = makeConfigFile(t, `{"configs":[{"_name":"cvetest","showspinner":false}]}`) - cmd := NewCVECommand(new(mockService)) + cmd := NewCVECommand(newMockService()) buff := bytes.NewBufferString("") cmd.SetOut(buff) cmd.SetErr(buff) @@ -95,7 +95,7 @@ func TestSearchCVECmd(t *testing.T) { _ = makeConfigFile(t, `{"configs":[{"_name":"cvetest","showspinner":false}]}`) - cmd := NewCVECommand(new(searchService)) + cmd := NewCVECommand(NewSearchService()) buff := bytes.NewBufferString("") cmd.SetOut(buff) cmd.SetErr(buff) @@ -111,7 +111,7 @@ func TestSearchCVECmd(t *testing.T) { _ = makeConfigFile(t, `{"configs":[{"_name":"cvetest","showspinner":false}]}`) - cmd := NewCVECommand(new(searchService)) + cmd := NewCVECommand(NewSearchService()) buff := bytes.NewBufferString("") cmd.SetOut(buff) cmd.SetErr(buff) @@ -126,7 +126,7 @@ func TestSearchCVECmd(t *testing.T) { _ = makeConfigFile(t, `{"configs":[{"_name":"cvetest","showspinner":false}]}`) - cmd := NewCVECommand(new(searchService)) + cmd := NewCVECommand(NewSearchService()) buff := bytes.NewBufferString("") cmd.SetOut(buff) cmd.SetErr(buff) @@ -140,7 +140,7 @@ func TestSearchCVECmd(t *testing.T) { _ = makeConfigFile(t, fmt.Sprintf(`{"configs":[{"_name":"cvetest","url":"%s","showspinner":false}]}`, baseURL)) - cmd := NewCVECommand(new(mockService)) + cmd := NewCVECommand(newMockService()) buff := bytes.NewBufferString("") cmd.SetOut(buff) cmd.SetErr(buff) @@ -171,7 +171,7 @@ func TestSearchCVECmd(t *testing.T) { _ = makeConfigFile(t, fmt.Sprintf(`{"configs":[{"_name":"cvetest","url":"%s","showspinner":false}]}`, baseURL)) - cmd := NewCVECommand(new(searchService)) + cmd := NewCVECommand(NewSearchService()) buff := bytes.NewBufferString("") cmd.SetOut(buff) cmd.SetErr(buff) @@ -188,7 +188,7 @@ func TestSearchCVECmd(t *testing.T) { _ = makeConfigFile(t, `{"configs":[{"_name":"cvetest","showspinner":false}]}`) - cveCmd := NewCVECommand(new(mockService)) + cveCmd := NewCVECommand(newMockService()) buff := bytes.NewBufferString("") cveCmd.SetOut(buff) cveCmd.SetErr(buff) @@ -207,7 +207,7 @@ func TestSearchCVECmd(t *testing.T) { _ = makeConfigFile(t, `{"configs":[{"_name":"cvetest","showspinner":false}]}`) - cveCmd := NewCVECommand(new(mockService)) + cveCmd := NewCVECommand(newMockService()) cveCmd.SetOut(buff) cveCmd.SetErr(buff) cveCmd.SetArgs(args) @@ -224,7 +224,7 @@ func TestSearchCVECmd(t *testing.T) { _ = makeConfigFile(t, `{"configs":[{"_name":"cvetest","showspinner":false}]}`) - cveCmd := NewCVECommand(new(mockService)) + cveCmd := NewCVECommand(newMockService()) buff := bytes.NewBufferString("") cveCmd.SetOut(buff) cveCmd.SetErr(buff) @@ -255,7 +255,7 @@ func TestSearchCVECmd(t *testing.T) { _ = makeConfigFile(t, `{"configs":[{"_name":"cvetest","showspinner":false}]}`) - cveCmd := NewCVECommand(new(mockService)) + cveCmd := NewCVECommand(newMockService()) buff := bytes.NewBufferString("") cveCmd.SetOut(buff) cveCmd.SetErr(buff) @@ -293,7 +293,7 @@ func TestSearchCVECmd(t *testing.T) { _ = makeConfigFile(t, `{"configs":[{"_name":"cvetest","showspinner":false}]}`) - cveCmd := NewCVECommand(new(mockService)) + cveCmd := NewCVECommand(newMockService()) buff := bytes.NewBufferString("") cveCmd.SetOut(buff) cveCmd.SetErr(buff) @@ -314,7 +314,7 @@ func TestSearchCVECmd(t *testing.T) { _ = makeConfigFile(t, `{"configs":[{"_name":"cvetest","showspinner":false}]}`) - cveCmd := NewCVECommand(new(mockService)) + cveCmd := NewCVECommand(newMockService()) buff := bytes.NewBufferString("") cveCmd.SetOut(buff) cveCmd.SetErr(buff) @@ -333,7 +333,7 @@ func TestSearchCVECmd(t *testing.T) { _ = makeConfigFile(t, `{"configs":[{"_name":"cvetest","showspinner":false}]}`) - cveCmd := NewCVECommand(new(mockService)) + cveCmd := NewCVECommand(newMockService()) buff := bytes.NewBufferString("") cveCmd.SetOut(buff) cveCmd.SetErr(buff) @@ -352,7 +352,7 @@ func TestSearchCVECmd(t *testing.T) { _ = makeConfigFile(t, `{"configs":[{"_name":"cvetest","showspinner":false}]}`) - cveCmd := NewCVECommand(new(mockService)) + cveCmd := NewCVECommand(newMockService()) buff := bytes.NewBufferString("") cveCmd.SetOut(buff) cveCmd.SetErr(buff) @@ -370,7 +370,10 @@ func TestSearchCVECmd(t *testing.T) { _ = makeConfigFile(t, `{"configs":[{"_name":"cvetest","showspinner":false}]}`) - mockService := mockServiceForRetry{succeedOn: 2} // CVE info will be provided in 2nd attempt + mockService := mockServiceForRetry{ + mockService: *newMockService(), + succeedOn: 2, // CVE info will be provided in 2nd attempt + } cveCmd := NewCVECommand(&mockService) buff := bytes.NewBufferString("") cveCmd.SetOut(buff) @@ -392,7 +395,10 @@ func TestSearchCVECmd(t *testing.T) { _ = makeConfigFile(t, `{"configs":[{"_name":"cvetest","showspinner":false}]}`) - mockService := mockServiceForRetry{succeedOn: -1} // CVE info will be unavailable on all retries + mockService := mockServiceForRetry{ + mockService: *newMockService(), + succeedOn: -1, // CVE info will be unavailable on all retries + } cveCmd := NewCVECommand(&mockService) buff := bytes.NewBufferString("") cveCmd.SetOut(buff) @@ -414,7 +420,7 @@ func TestSearchCVECmd(t *testing.T) { _ = makeConfigFile(t, `{"configs":[{"_name":"cvetest","showspinner":false}]}`) - cveCmd := NewCVECommand(new(mockService)) + cveCmd := NewCVECommand(newMockService()) buff := bytes.NewBufferString("") cveCmd.SetOut(buff) cveCmd.SetErr(buff) @@ -444,7 +450,7 @@ func TestSearchCVECmd(t *testing.T) { _ = makeConfigFile(t, `{"configs":[{"_name":"cvetest","showspinner":false}]}`) - cveCmd := NewCVECommand(new(mockService)) + cveCmd := NewCVECommand(newMockService()) buff := bytes.NewBufferString("") cveCmd.SetOut(buff) cveCmd.SetErr(buff) @@ -497,7 +503,7 @@ func TestCVECommandGQL(t *testing.T) { Convey("cveid", func() { args := []string{"affected", "CVE-1942", "--config", "cvetest"} - cmd := NewCVECommand(mockService{}) + cmd := NewCVECommand(newMockService()) buff := bytes.NewBufferString("") cmd.SetOut(buff) cmd.SetErr(buff) @@ -517,7 +523,8 @@ func TestCVECommandGQL(t *testing.T) { _ = makeConfigFile(t, fmt.Sprintf(`{"configs":[{"_name":"cvetest","url":"%s","showspinner":false}]}`, baseURL)) - cmd := NewCVECommand(mockService{ + cmd := NewCVECommand(&mockService{ + httpClient: NewHTTPClient(), getTagsForCVEGQLFn: func(ctx context.Context, config SearchConfig, username, password, imageName, cveID string) (*zcommon.ImagesForCve, error, ) { @@ -545,7 +552,7 @@ func TestCVECommandGQL(t *testing.T) { Convey("fixed", func() { args := []string{"fixed", "image-name", "CVE-123", "--config", "cvetest"} - cmd := NewCVECommand(mockService{}) + cmd := NewCVECommand(newMockService()) buff := bytes.NewBufferString("") cmd.SetOut(buff) cmd.SetErr(buff) @@ -565,7 +572,8 @@ func TestCVECommandGQL(t *testing.T) { _ = makeConfigFile(t, fmt.Sprintf(`{"configs":[{"_name":"cvetest","url":"%s","showspinner":false}]}`, baseURL)) - cmd := NewCVECommand(mockService{ + cmd := NewCVECommand(&mockService{ + httpClient: NewHTTPClient(), getFixedTagsForCVEGQLFn: func(ctx context.Context, config SearchConfig, username, password, imageName, cveID string) (*zcommon.ImageListWithCVEFixedResponse, error, ) { @@ -593,7 +601,7 @@ func TestCVECommandGQL(t *testing.T) { Convey("image", func() { args := []string{"list", "repo:tag", "--config", "cvetest"} - cmd := NewCVECommand(mockService{}) + cmd := NewCVECommand(newMockService()) buff := bytes.NewBufferString("") cmd.SetOut(buff) cmd.SetErr(buff) @@ -616,7 +624,8 @@ func TestCVECommandGQL(t *testing.T) { _ = makeConfigFile(t, fmt.Sprintf(`{"configs":[{"_name":"cvetest","url":"%s","showspinner":false}]}`, baseURL)) - cmd := NewCVECommand(mockService{ + cmd := NewCVECommand(&mockService{ + httpClient: NewHTTPClient(), getCveByImageGQLFn: func(ctx context.Context, config SearchConfig, username, password, imageName, searchedCVE string) (*cveResult, error, ) { @@ -668,7 +677,7 @@ func TestCVECommandErrors(t *testing.T) { Convey("cveid", func() { args := []string{"affected", "CVE-1942"} - cmd := NewCVECommand(mockService{}) + cmd := NewCVECommand(newMockService()) buff := bytes.NewBufferString("") cmd.SetOut(buff) cmd.SetErr(buff) @@ -680,7 +689,7 @@ func TestCVECommandErrors(t *testing.T) { Convey("cveid error", func() { // too many args args := []string{"too", "many", "args"} - cmd := NewImagesByCVEIDCommand(mockService{}) + cmd := NewImagesByCVEIDCommand(newMockService()) buff := bytes.NewBufferString("") cmd.SetOut(buff) cmd.SetErr(buff) @@ -690,7 +699,7 @@ func TestCVECommandErrors(t *testing.T) { // bad args args = []string{"not-a-cve-id"} - cmd = NewImagesByCVEIDCommand(mockService{}) + cmd = NewImagesByCVEIDCommand(newMockService()) buff = bytes.NewBufferString("") cmd.SetOut(buff) cmd.SetErr(buff) @@ -700,7 +709,7 @@ func TestCVECommandErrors(t *testing.T) { // no URL args = []string{"CVE-1942"} - cmd = NewImagesByCVEIDCommand(mockService{}) + cmd = NewImagesByCVEIDCommand(newMockService()) buff = bytes.NewBufferString("") cmd.SetOut(buff) cmd.SetErr(buff) @@ -711,7 +720,7 @@ func TestCVECommandErrors(t *testing.T) { Convey("fixed command", func() { args := []string{"fixed", "image-name", "CVE-123"} - cmd := NewCVECommand(mockService{}) + cmd := NewCVECommand(newMockService()) buff := bytes.NewBufferString("") cmd.SetOut(buff) cmd.SetErr(buff) @@ -723,7 +732,7 @@ func TestCVECommandErrors(t *testing.T) { Convey("fixed command error", func() { // too many args args := []string{"too", "many", "args", "args"} - cmd := NewFixedTagsCommand(mockService{}) + cmd := NewFixedTagsCommand(newMockService()) buff := bytes.NewBufferString("") cmd.SetOut(buff) cmd.SetErr(buff) @@ -733,7 +742,7 @@ func TestCVECommandErrors(t *testing.T) { // bad args args = []string{"repo-tag-instead-of-just-repo:fail-here", "CVE-123"} - cmd = NewFixedTagsCommand(mockService{}) + cmd = NewFixedTagsCommand(newMockService()) buff = bytes.NewBufferString("") cmd.SetOut(buff) cmd.SetErr(buff) @@ -743,7 +752,7 @@ func TestCVECommandErrors(t *testing.T) { // no URL args = []string{"CVE-1942"} - cmd = NewFixedTagsCommand(mockService{}) + cmd = NewFixedTagsCommand(newMockService()) buff = bytes.NewBufferString("") cmd.SetOut(buff) cmd.SetErr(buff) @@ -754,7 +763,7 @@ func TestCVECommandErrors(t *testing.T) { Convey("image", func() { args := []string{"list", "repo:tag"} - cmd := NewCVECommand(mockService{}) + cmd := NewCVECommand(newMockService()) buff := bytes.NewBufferString("") cmd.SetOut(buff) cmd.SetErr(buff) @@ -766,7 +775,7 @@ func TestCVECommandErrors(t *testing.T) { Convey("image command error", func() { // too many args args := []string{"too", "many", "args", "args"} - cmd := NewCveForImageCommand(mockService{}) + cmd := NewCveForImageCommand(newMockService()) buff := bytes.NewBufferString("") cmd.SetOut(buff) cmd.SetErr(buff) @@ -776,7 +785,7 @@ func TestCVECommandErrors(t *testing.T) { // bad args args = []string{"repo-tag-instead-of-just-repo:fail-here", "CVE-123"} - cmd = NewCveForImageCommand(mockService{}) + cmd = NewCveForImageCommand(newMockService()) buff = bytes.NewBufferString("") cmd.SetOut(buff) cmd.SetErr(buff) @@ -786,7 +795,7 @@ func TestCVECommandErrors(t *testing.T) { // no URL args = []string{"CVE-1942"} - cmd = NewCveForImageCommand(mockService{}) + cmd = NewCveForImageCommand(newMockService()) buff = bytes.NewBufferString("") cmd.SetOut(buff) cmd.SetErr(buff) diff --git a/pkg/cli/client/discover.go b/pkg/cli/client/discover.go index 072ad92b..d12b62cd 100644 --- a/pkg/cli/client/discover.go +++ b/pkg/cli/client/discover.go @@ -101,7 +101,8 @@ func CheckExtEndPointQuery(config SearchConfig, requiredQueries ...GQLQuery) err discoverResponse := &distext.ExtensionList{} - _, err = makeGETRequest(ctx, discoverEndPoint, username, password, config.VerifyTLS, + _, err = config.SearchService.getHTTPClient().makeGETRequest( + ctx, discoverEndPoint, username, password, config.VerifyTLS, config.Debug, &discoverResponse, config.ResultWriter) if err != nil { return err @@ -152,7 +153,8 @@ func CheckExtEndPointQuery(config SearchConfig, requiredQueries ...GQLQuery) err queryResponse := &schemaList{} - err = makeGraphQLRequest(ctx, searchEndPoint, schemaQuery, username, password, config.VerifyTLS, + err = config.SearchService.getHTTPClient().makeGraphQLRequest( + ctx, searchEndPoint, schemaQuery, username, password, config.VerifyTLS, config.Debug, queryResponse, config.ResultWriter) if err != nil { return fmt.Errorf("gql query failed: %w", err) diff --git a/pkg/cli/client/elevated_test.go b/pkg/cli/client/elevated_test.go index 67f5e297..d2f1e46c 100644 --- a/pkg/cli/client/elevated_test.go +++ b/pkg/cli/client/elevated_test.go @@ -9,8 +9,10 @@ import ( "fmt" "os" "os/exec" + "path" "path/filepath" "testing" + "time" . "github.com/smartystreets/goconvey/convey" "gopkg.in/resty.v1" @@ -20,65 +22,107 @@ import ( "zotregistry.dev/zot/v2/pkg/api/constants" "zotregistry.dev/zot/v2/pkg/cli/client" test "zotregistry.dev/zot/v2/pkg/test/common" + tlsutils "zotregistry.dev/zot/v2/pkg/test/tls" +) + +const ( + privilegedCertsDir = "/etc/containers/certs.d/127.0.0.1:8089" ) func TestElevatedPrivilegesTLSNewControllerPrivilegedCert(t *testing.T) { Convey("Privileged certs - Make a new controller", t, func() { - //nolint: noctx // old code, no context available - cmd := exec.Command("mkdir", "-p", "/etc/containers/certs.d/127.0.0.1:8089/") //nolint: gosec + // Generate certificates using tls library + tempDir := t.TempDir() + caOpts := &tlsutils.CertificateOptions{ + CommonName: "*", + NotAfter: time.Now().AddDate(10, 0, 0), + } + caCertPEM, caKeyPEM, err := tlsutils.GenerateCACert(caOpts) + So(err, ShouldBeNil) - _, err := cmd.Output() + caCertPath := path.Join(tempDir, "ca.crt") + caKeyPath := path.Join(tempDir, "ca.key") + err = os.WriteFile(caCertPath, caCertPEM, 0o600) + So(err, ShouldBeNil) + err = os.WriteFile(caKeyPath, caKeyPEM, 0o600) + So(err, ShouldBeNil) + + serverCertPath := path.Join(tempDir, "server.cert") + serverKeyPath := path.Join(tempDir, "server.key") + serverOpts := &tlsutils.CertificateOptions{ + Hostname: "127.0.0.1", + CommonName: "*", + OrganizationalUnit: "TestServer", + NotAfter: time.Now().AddDate(10, 0, 0), + } + err = tlsutils.GenerateServerCertToFile(caCertPEM, caKeyPEM, serverCertPath, serverKeyPath, serverOpts) + So(err, ShouldBeNil) + + // Generate client certificate + clientCertPath := path.Join(tempDir, "client.cert") + clientKeyPath := path.Join(tempDir, "client.key") + clientOpts := &tlsutils.CertificateOptions{ + CommonName: "testclient", + OrganizationalUnit: "TestClient", + NotAfter: time.Now().AddDate(10, 0, 0), + } + err = tlsutils.GenerateClientCertToFile(caCertPEM, caKeyPEM, clientCertPath, clientKeyPath, clientOpts) + So(err, ShouldBeNil) + + //nolint: noctx // old code, no context available + cmd := exec.Command("mkdir", "-p", privilegedCertsDir+"/") //nolint: gosec + + _, err = cmd.Output() if err != nil { panic(err) } //nolint: noctx // old code, no context available - defer exec.Command("rm", "-rf", "/etc/containers/certs.d/127.0.0.1:8089/") + defer exec.Command("rm", "-rf", privilegedCertsDir+"/") - workDir, _ := os.Getwd() - _ = os.Chdir("../../../test/data") - - clientGlob, _ := filepath.Glob("client.*") - caGlob, _ := filepath.Glob("ca.*") - - for _, file := range clientGlob { - //nolint: noctx // old code, no context available - cmd = exec.Command("cp", file, "/etc/containers/certs.d/127.0.0.1:8089/") - - res, err := cmd.CombinedOutput() - if err != nil { - panic(string(res)) - } + // Copy generated certificates to privileged location + //nolint: noctx // old code, no context available + cmd = exec.Command("cp", clientCertPath, privilegedCertsDir+"/") + res, err := cmd.CombinedOutput() + if err != nil { + panic(string(res)) } - for _, file := range caGlob { - //nolint: noctx // old code, no context available - cmd = exec.Command("cp", file, "/etc/containers/certs.d/127.0.0.1:8089/") - - res, err := cmd.CombinedOutput() - if err != nil { - panic(string(res)) - } + //nolint: noctx // old code, no context available + cmd = exec.Command("cp", clientKeyPath, privilegedCertsDir+"/") + res, err = cmd.CombinedOutput() + if err != nil { + panic(string(res)) } - allGlob, _ := filepath.Glob("/etc/containers/certs.d/127.0.0.1:8089/*.key") + //nolint: noctx // old code, no context available + cmd = exec.Command("cp", caCertPath, privilegedCertsDir+"/") + res, err = cmd.CombinedOutput() + if err != nil { + panic(string(res)) + } + + //nolint: noctx // old code, no context available + cmd = exec.Command("cp", caKeyPath, privilegedCertsDir+"/") + res, err = cmd.CombinedOutput() + if err != nil { + panic(string(res)) + } + + allGlob, _ := filepath.Glob(privilegedCertsDir + "/*.key") for _, file := range allGlob { //nolint: noctx // old code, no context available cmd = exec.Command("chmod", "a=rwx", file) - res, err := cmd.CombinedOutput() + res, err = cmd.CombinedOutput() if err != nil { panic(string(res)) } } - _ = os.Chdir(workDir) - - caCert, err := os.ReadFile(CACert) - So(err, ShouldBeNil) caCertPool := x509.NewCertPool() - caCertPool.AppendCertsFromPEM(caCert) + caCertPool.AppendCertsFromPEM(caCertPEM) resty.SetTLSClientConfig(&tls.Config{RootCAs: caCertPool, MinVersion: tls.VersionTLS12}) @@ -87,9 +131,9 @@ func TestElevatedPrivilegesTLSNewControllerPrivilegedCert(t *testing.T) { conf := config.New() conf.HTTP.Port = SecurePort2 conf.HTTP.TLS = &config.TLSConfig{ - Cert: ServerCert, - Key: ServerKey, - CACert: CACert, + Cert: serverCertPath, + Key: serverKeyPath, + CACert: caCertPath, } ctlr := api.NewController(conf) diff --git a/pkg/cli/client/gql_queries_test.go b/pkg/cli/client/gql_queries_test.go index 4f89c25f..4d52cdc4 100644 --- a/pkg/cli/client/gql_queries_test.go +++ b/pkg/cli/client/gql_queries_test.go @@ -37,11 +37,12 @@ func TestGQLQueries(t *testing.T) { defer cm.StopServer() searchConfig := client.SearchConfig{ - ServURL: baseURL, - User: "", - VerifyTLS: false, - Debug: false, - ResultWriter: io.Discard, + ServURL: baseURL, + User: "", + VerifyTLS: false, + Debug: false, + ResultWriter: io.Discard, + SearchService: client.NewSearchService(), } Convey("Make sure the current CLI used the right queries in case they change", t, func() { diff --git a/pkg/cli/client/image_cmd_internal_test.go b/pkg/cli/client/image_cmd_internal_test.go index f2d7edfa..4b79119d 100644 --- a/pkg/cli/client/image_cmd_internal_test.go +++ b/pkg/cli/client/image_cmd_internal_test.go @@ -37,7 +37,7 @@ func TestSearchImageCmd(t *testing.T) { _ = makeConfigFile(t, "") - cmd := NewImageCommand(new(mockService)) + cmd := NewImageCommand(newMockService()) buff := bytes.NewBufferString("") cmd.SetOut(buff) cmd.SetErr(buff) @@ -52,7 +52,7 @@ func TestSearchImageCmd(t *testing.T) { _ = makeConfigFile(t, "") - cmd := NewImageCommand(new(mockService)) + cmd := NewImageCommand(newMockService()) buff := bytes.NewBufferString("") cmd.SetOut(buff) cmd.SetErr(buff) @@ -69,7 +69,7 @@ func TestSearchImageCmd(t *testing.T) { _ = makeConfigFile(t, `{"configs":[{"_name":"imagetest","showspinner":false}]}`) - cmd := NewImageCommand(new(mockService)) + cmd := NewImageCommand(newMockService()) buff := bytes.NewBufferString("") cmd.SetOut(buff) cmd.SetErr(buff) @@ -89,7 +89,7 @@ func TestSearchImageCmd(t *testing.T) { panic(err) } - cmd := NewImageCommand(new(mockService)) + cmd := NewImageCommand(newMockService()) buff := bytes.NewBufferString("") cmd.SetOut(buff) cmd.SetErr(buff) @@ -113,7 +113,7 @@ func TestSearchImageCmd(t *testing.T) { _ = makeConfigFile(t, `{"configs":[{"_name":"imagetest","showspinner":false}]}`) - cmd := NewImageCommand(new(mockService)) + cmd := NewImageCommand(newMockService()) buff := bytes.NewBufferString("") cmd.SetOut(buff) cmd.SetErr(buff) @@ -127,7 +127,7 @@ func TestSearchImageCmd(t *testing.T) { _ = makeConfigFile(t, `{"configs":[{"_name":"imagetest","showspinner":false}]}`) - cmd := NewImageCommand(new(searchService)) + cmd := NewImageCommand(NewSearchService()) buff := bytes.NewBufferString("") cmd.SetOut(buff) cmd.SetErr(buff) @@ -143,7 +143,7 @@ func TestSearchImageCmd(t *testing.T) { _ = makeConfigFile(t, `{"configs":[{"_name":"imagetest","showspinner":false}]}`) - cmd := NewImageCommand(new(searchService)) + cmd := NewImageCommand(NewSearchService()) buff := bytes.NewBufferString("") cmd.SetOut(buff) cmd.SetErr(buff) @@ -157,7 +157,7 @@ func TestSearchImageCmd(t *testing.T) { _ = makeConfigFile(t, `{"configs":[{"_name":"imagetest","showspinner":false}]}`) - cmd := NewImageCommand(new(searchService)) + cmd := NewImageCommand(NewSearchService()) buff := bytes.NewBufferString("") cmd.SetOut(buff) cmd.SetErr(buff) @@ -173,7 +173,7 @@ func TestSearchImageCmd(t *testing.T) { _ = makeConfigFile(t, `{"configs":[{"_name":"imagetest","showspinner":false}]}`) - cmd := NewImageCommand(new(searchService)) + cmd := NewImageCommand(NewSearchService()) buff := bytes.NewBufferString("") cmd.SetOut(buff) cmd.SetErr(buff) @@ -187,7 +187,7 @@ func TestSearchImageCmd(t *testing.T) { _ = makeConfigFile(t, `{"configs":[{"_name":"imagetest","url":"https://test-url.com","showspinner":false}]}`) - cmd := NewImageCommand(new(mockService)) + cmd := NewImageCommand(newMockService()) buff := bytes.NewBufferString("") cmd.SetOut(buff) cmd.SetErr(buff) @@ -205,7 +205,7 @@ func TestSearchImageCmd(t *testing.T) { _ = makeConfigFile(t, `{"configs":[{"_name":"imagetest","showspinner":false}]}`) - imageCmd := NewImageCommand(new(mockService)) + imageCmd := NewImageCommand(newMockService()) buff := &bytes.Buffer{} imageCmd.SetOut(buff) imageCmd.SetErr(buff) @@ -219,7 +219,7 @@ func TestSearchImageCmd(t *testing.T) { }) Convey("Test image by digest", t, func() { - searchConfig := getTestSearchConfig("http://127.0.0.1:8080", new(mockService)) + searchConfig := getTestSearchConfig("http://127.0.0.1:8080", newMockService()) buff := &bytes.Buffer{} searchConfig.ResultWriter = buff err := SearchImagesByDigest(searchConfig, "6e2f80bf") @@ -234,7 +234,7 @@ func TestSearchImageCmd(t *testing.T) { } func TestListRepos(t *testing.T) { - searchConfig := getTestSearchConfig("https://test-url.com", new(mockService)) + searchConfig := getTestSearchConfig("https://test-url.com", newMockService()) Convey("Test listing repositories", t, func() { buff := &bytes.Buffer{} @@ -248,7 +248,7 @@ func TestListRepos(t *testing.T) { _ = makeConfigFile(t, `{"configs":[{"_name":"config-test","url":"https://test-url.com","showspinner":false}]}`) - cmd := NewRepoCommand(new(searchService)) + cmd := NewRepoCommand(NewSearchService()) buff := bytes.NewBufferString("") cmd.SetOut(buff) @@ -272,7 +272,7 @@ func TestListRepos(t *testing.T) { panic(err) } - cmd := NewRepoCommand(new(mockService)) + cmd := NewRepoCommand(newMockService()) buff := bytes.NewBufferString("") cmd.SetOut(buff) cmd.SetErr(buff) @@ -297,7 +297,7 @@ func TestListRepos(t *testing.T) { _ = makeConfigFile(t, `{"configs":[{"_name":"config-test", "url":"https://invalid.invalid","showspinner":false}]}`) - cmd := NewRepoCommand(new(searchService)) + cmd := NewRepoCommand(NewSearchService()) buff := bytes.NewBufferString("") cmd.SetOut(buff) cmd.SetErr(buff) @@ -311,7 +311,7 @@ func TestListRepos(t *testing.T) { _ = makeConfigFile(t, `{"configs":[{"_name":"config-test","url":"https://test-url.com","showspinner":false}]}`) - cmd := NewRepoCommand(new(mockService)) + cmd := NewRepoCommand(newMockService()) buff := bytes.NewBufferString("") cmd.SetOut(buff) cmd.SetErr(buff) @@ -325,7 +325,7 @@ func TestListRepos(t *testing.T) { _ = makeConfigFile(t, `{"configs":[{"_name":"config-test","url":"","showspinner":false}]}`) - cmd := NewRepoCommand(new(mockService)) + cmd := NewRepoCommand(newMockService()) buff := bytes.NewBufferString("") cmd.SetOut(buff) cmd.SetErr(buff) @@ -340,7 +340,7 @@ func TestListRepos(t *testing.T) { _ = makeConfigFile(t, `{"configs":[{"_name":"config-test", "url":"https://test-url.com","showspinner":invalid}]}`) - cmd := NewRepoCommand(new(mockService)) + cmd := NewRepoCommand(newMockService()) buff := bytes.NewBufferString("") cmd.SetOut(buff) cmd.SetErr(buff) @@ -355,7 +355,7 @@ func TestListRepos(t *testing.T) { _ = makeConfigFile(t, `{"configs":[{"_name":"config-test", "verify-tls":"invalid", "url":"https://test-url.com","showspinner":false}]}`) - cmd := NewRepoCommand(new(mockService)) + cmd := NewRepoCommand(newMockService()) buff := bytes.NewBufferString("") cmd.SetOut(buff) cmd.SetErr(buff) @@ -371,7 +371,7 @@ func TestOutputFormat(t *testing.T) { _ = makeConfigFile(t, `{"configs":[{"_name":"imagetest","url":"https://test-url.com","showspinner":false}]}`) - cmd := NewImageCommand(new(mockService)) + cmd := NewImageCommand(newMockService()) buff := bytes.NewBufferString("") cmd.SetOut(buff) cmd.SetErr(buff) @@ -389,7 +389,7 @@ func TestOutputFormat(t *testing.T) { _ = makeConfigFile(t, `{"configs":[{"_name":"imagetest","url":"https://test-url.com","showspinner":false}]}`) - cmd := NewImageCommand(new(mockService)) + cmd := NewImageCommand(newMockService()) buff := bytes.NewBufferString("") cmd.SetOut(buff) cmd.SetErr(buff) @@ -419,7 +419,7 @@ func TestOutputFormat(t *testing.T) { _ = makeConfigFile(t, `{"configs":[{"_name":"imagetest","url":"https://test-url.com","showspinner":false}]}`) - cmd := NewImageCommand(new(mockService)) + cmd := NewImageCommand(newMockService()) buff := bytes.NewBufferString("") cmd.SetOut(buff) cmd.SetErr(buff) @@ -457,7 +457,7 @@ func TestOutputFormat(t *testing.T) { `"url":"https://test-url.com","showspinner":false}]}`, ) - cmd := NewImageCommand(new(mockService)) + cmd := NewImageCommand(newMockService()) buff := bytes.NewBufferString("") cmd.SetOut(buff) cmd.SetErr(buff) @@ -493,7 +493,7 @@ func TestOutputFormat(t *testing.T) { _ = makeConfigFile(t, `{"configs":[{"_name":"imagetest","url":"https://test-url.com","showspinner":false}]}`) - cmd := NewImageCommand(new(mockService)) + cmd := NewImageCommand(newMockService()) buff := bytes.NewBufferString("") cmd.SetOut(buff) cmd.SetErr(buff) @@ -797,7 +797,7 @@ func TestImagesCommandGQL(t *testing.T) { _ = makeConfigFile(t, fmt.Sprintf(`{"configs":[{"_name":"imagetest","url":"%s","showspinner":false}]}`, baseURL)) args := []string{"cve", "repo:vuln", "--config", "imagetest"} - cmd := NewImageCommand(mockService{}) + cmd := NewImageCommand(newMockService()) buff := bytes.NewBufferString("") cmd.SetOut(buff) cmd.SetErr(buff) @@ -818,7 +818,8 @@ func TestImagesCommandGQL(t *testing.T) { baseURL)) args := []string{"cve", "repo:vuln", "--config", "imagetest"} - cmd := NewImageCommand(mockService{ + cmd := NewImageCommand(&mockService{ + httpClient: NewHTTPClient(), getCveByImageGQLFn: func(ctx context.Context, config SearchConfig, username, password, imageName, searchedCVE string) (*cveResult, error, ) { @@ -895,7 +896,7 @@ func TestImagesCommandGQL(t *testing.T) { So(err, ShouldNotBeNil) args = []string{"cve", "repo:vuln"} - cmd = NewImageCommand(mockService{}) + cmd = NewImageCommand(newMockService()) buff = bytes.NewBufferString("") cmd.SetOut(buff) cmd.SetErr(buff) @@ -1028,7 +1029,7 @@ func TestImageCommandREST(t *testing.T) { _ = makeConfigFile(t, fmt.Sprintf(`{"configs":[{"_name":"imagetest","url":"%s","showspinner":false}]}`, baseURL)) - cmd := NewImageCommand(mockService{}) + cmd := NewImageCommand(newMockService()) buff := bytes.NewBufferString("") cmd.SetOut(buff) cmd.SetErr(buff) @@ -1093,9 +1094,22 @@ type mockService struct { getCVEDiffListGQLFn func(ctx context.Context, config SearchConfig, username, password string, minuend, subtrahend ImageIdentifier, ) (*cveDiffListResp, error) + + httpClient *HTTPClient } -func (service mockService) getCVEDiffListGQL(ctx context.Context, config SearchConfig, username, password string, +// newMockService creates a new mockService with httpClient initialized. +func newMockService() *mockService { + return &mockService{ + httpClient: NewHTTPClient(), + } +} + +func (service *mockService) getHTTPClient() *HTTPClient { + return service.httpClient +} + +func (service *mockService) getCVEDiffListGQL(ctx context.Context, config SearchConfig, username, password string, minuend, subtrahend ImageIdentifier, ) (*cveDiffListResp, error) { if service.getCVEDiffListGQLFn != nil { @@ -1105,7 +1119,7 @@ func (service mockService) getCVEDiffListGQL(ctx context.Context, config SearchC return &cveDiffListResp{}, nil } -func (service mockService) getRepos(ctx context.Context, config SearchConfig, username, +func (service *mockService) getRepos(ctx context.Context, config SearchConfig, username, password string, channel chan stringResult, wtgrp *sync.WaitGroup, ) { defer wtgrp.Done() @@ -1117,7 +1131,7 @@ func (service mockService) getRepos(ctx context.Context, config SearchConfig, us fmt.Fprintln(config.ResultWriter, "repo2") } -func (service mockService) getReferrers(ctx context.Context, config SearchConfig, username, password string, +func (service *mockService) getReferrers(ctx context.Context, config SearchConfig, username, password string, repo, digest string, ) (referrersResult, error) { if service.getReferrersFn != nil { @@ -1134,7 +1148,7 @@ func (service mockService) getReferrers(ctx context.Context, config SearchConfig }, nil } -func (service mockService) globalSearchGQL(ctx context.Context, config SearchConfig, username, password string, +func (service *mockService) globalSearchGQL(ctx context.Context, config SearchConfig, username, password string, query string, ) (*common.GlobalSearch, error) { if service.globalSearchGQLFn != nil { @@ -1166,7 +1180,7 @@ func (service mockService) globalSearchGQL(ctx context.Context, config SearchCon }, nil } -func (service mockService) getReferrersGQL(ctx context.Context, config SearchConfig, username, password string, +func (service *mockService) getReferrersGQL(ctx context.Context, config SearchConfig, username, password string, repo, digest string, ) (*common.ReferrersResp, error) { if service.getReferrersGQLFn != nil { @@ -1187,7 +1201,7 @@ func (service mockService) getReferrersGQL(ctx context.Context, config SearchCon }, nil } -func (service mockService) getDerivedImageListGQL(ctx context.Context, config SearchConfig, username, password string, +func (service *mockService) getDerivedImageListGQL(ctx context.Context, config SearchConfig, username, password string, derivedImage string, ) (*common.DerivedImageListResponse, error) { if service.getDerivedImageListGQLFn != nil { @@ -1215,7 +1229,7 @@ func (service mockService) getDerivedImageListGQL(ctx context.Context, config Se return imageListGQLResponse, nil } -func (service mockService) getBaseImageListGQL(ctx context.Context, config SearchConfig, username, password string, +func (service *mockService) getBaseImageListGQL(ctx context.Context, config SearchConfig, username, password string, baseImage string, ) (*common.BaseImageListResponse, error) { if service.getBaseImageListGQLFn != nil { @@ -1243,7 +1257,7 @@ func (service mockService) getBaseImageListGQL(ctx context.Context, config Searc return imageListGQLResponse, nil } -func (service mockService) getImagesGQL(ctx context.Context, config SearchConfig, username, password string, +func (service *mockService) getImagesGQL(ctx context.Context, config SearchConfig, username, password string, imageName string, ) (*common.ImageListResponse, error) { if service.getImagesGQLFn != nil { @@ -1273,7 +1287,7 @@ func (service mockService) getImagesGQL(ctx context.Context, config SearchConfig return imageListGQLResponse, nil } -func (service mockService) getImagesForDigestGQL(ctx context.Context, config SearchConfig, username, password string, +func (service *mockService) getImagesForDigestGQL(ctx context.Context, config SearchConfig, username, password string, digest string, ) (*common.ImagesForDigest, error) { if service.getImagesForDigestGQLFn != nil { @@ -1303,7 +1317,7 @@ func (service mockService) getImagesForDigestGQL(ctx context.Context, config Sea return imageListGQLResponse, nil } -func (service mockService) getTagsForCVEGQL(ctx context.Context, config SearchConfig, username, password, +func (service *mockService) getTagsForCVEGQL(ctx context.Context, config SearchConfig, username, password, imageName, cveID string, ) (*common.ImagesForCve, error) { if service.getTagsForCVEGQLFn != nil { @@ -1329,7 +1343,7 @@ func (service mockService) getTagsForCVEGQL(ctx context.Context, config SearchCo return images, nil } -func (service mockService) getFixedTagsForCVEGQL(ctx context.Context, config SearchConfig, username, password, +func (service *mockService) getFixedTagsForCVEGQL(ctx context.Context, config SearchConfig, username, password, imageName, cveID string, ) (*common.ImageListWithCVEFixedResponse, error) { if service.getFixedTagsForCVEGQLFn != nil { @@ -1351,7 +1365,7 @@ func (service mockService) getFixedTagsForCVEGQL(ctx context.Context, config Sea return fixedTags, nil } -func (service mockService) getCveByImageGQL(ctx context.Context, config SearchConfig, username, password, +func (service *mockService) getCveByImageGQL(ctx context.Context, config SearchConfig, username, password, imageName, searchedCVE string, ) (*cveResult, error) { if service.getCveByImageGQLFn != nil { @@ -1411,7 +1425,7 @@ func (service mockService) getMockedImageByName(imageName string) imageStruct { return image } -func (service mockService) getAllImages(ctx context.Context, config SearchConfig, username, password string, +func (service *mockService) getAllImages(ctx context.Context, config SearchConfig, username, password string, channel chan stringResult, wtgrp *sync.WaitGroup, ) { defer wtgrp.Done() @@ -1449,7 +1463,7 @@ func (service mockService) getAllImages(ctx context.Context, config SearchConfig channel <- stringResult{str, nil} } -func (service mockService) getImageByName(ctx context.Context, config SearchConfig, +func (service *mockService) getImageByName(ctx context.Context, config SearchConfig, username, password, imageName string, channel chan stringResult, wtgrp *sync.WaitGroup, ) { defer wtgrp.Done() @@ -1487,7 +1501,7 @@ func (service mockService) getImageByName(ctx context.Context, config SearchConf channel <- stringResult{str, nil} } -func (service mockService) getImagesByDigest(ctx context.Context, config SearchConfig, username, +func (service *mockService) getImagesByDigest(ctx context.Context, config SearchConfig, username, password, digest string, rch chan stringResult, wtgrp *sync.WaitGroup, ) { if service.getImagesByDigestFn != nil { diff --git a/pkg/cli/client/search_cmd_internal_test.go b/pkg/cli/client/search_cmd_internal_test.go index 102ef294..7cb56066 100644 --- a/pkg/cli/client/search_cmd_internal_test.go +++ b/pkg/cli/client/search_cmd_internal_test.go @@ -44,7 +44,7 @@ func TestSearchCommandGQL(t *testing.T) { Convey("query", func() { args := []string{"query", "repo/al", "--config", "searchtest"} - cmd := NewSearchCommand(mockService{}) + cmd := NewSearchCommand(newMockService()) buff := bytes.NewBufferString("") cmd.SetOut(buff) @@ -62,7 +62,7 @@ func TestSearchCommandGQL(t *testing.T) { Convey("query command errors", func() { // no url args := []string{"repo/al", "--config", "searchtest"} - cmd := NewSearchQueryCommand(mockService{}) + cmd := NewSearchQueryCommand(newMockService()) buff := bytes.NewBufferString("") cmd.SetOut(buff) cmd.SetErr(buff) @@ -76,7 +76,7 @@ func TestSearchCommandGQL(t *testing.T) { So(err, ShouldBeNil) args := []string{"subject", "repo:tag", "--config", "searchtest"} - cmd := NewSearchCommand(mockService{}) + cmd := NewSearchCommand(newMockService()) buff := bytes.NewBufferString("") cmd.SetOut(buff) @@ -93,7 +93,7 @@ func TestSearchCommandGQL(t *testing.T) { Convey("subject command errors", func() { // no url args := []string{"repo:tag", "--config", "searchtest"} - cmd := NewSearchSubjectCommand(mockService{}) + cmd := NewSearchSubjectCommand(newMockService()) buff := bytes.NewBufferString("") cmd.SetOut(buff) cmd.SetErr(buff) @@ -123,7 +123,7 @@ func TestSearchCommandREST(t *testing.T) { Convey("query", func() { args := []string{"query", "repo/al", "--config", "searchtest"} - cmd := NewSearchCommand(mockService{}) + cmd := NewSearchCommand(newMockService()) buff := bytes.NewBufferString("") cmd.SetOut(buff) @@ -138,7 +138,7 @@ func TestSearchCommandREST(t *testing.T) { So(err, ShouldBeNil) args := []string{"subject", "repo:tag", "--config", "searchtest"} - cmd := NewSearchCommand(mockService{}) + cmd := NewSearchCommand(newMockService()) buff := bytes.NewBufferString("") cmd.SetOut(buff) diff --git a/pkg/cli/client/search_functions_internal_test.go b/pkg/cli/client/search_functions_internal_test.go index a2560e84..414853a6 100644 --- a/pkg/cli/client/search_functions_internal_test.go +++ b/pkg/cli/client/search_functions_internal_test.go @@ -26,7 +26,8 @@ import ( func TestSearchAllImages(t *testing.T) { Convey("SearchAllImages", t, func() { buff := bytes.NewBufferString("") - searchConfig := getMockSearchConfig(buff, mockService{ + searchConfig := getMockSearchConfig(buff, &mockService{ + httpClient: NewHTTPClient(), getAllImagesFn: func(ctx context.Context, config SearchConfig, username, password string, channel chan stringResult, wtgrp *sync.WaitGroup, ) { @@ -48,7 +49,8 @@ func TestSearchAllImages(t *testing.T) { func TestSearchAllImagesGQL(t *testing.T) { Convey("SearchAllImagesGQL", t, func() { buff := bytes.NewBufferString("") - searchConfig := getMockSearchConfig(buff, mockService{ + searchConfig := getMockSearchConfig(buff, &mockService{ + httpClient: NewHTTPClient(), getImagesGQLFn: func(ctx context.Context, config SearchConfig, username, password, imageName string, ) (*common.ImageListResponse, error) { return &common.ImageListResponse{ImageList: common.ImageList{ @@ -69,7 +71,8 @@ func TestSearchAllImagesGQL(t *testing.T) { Convey("SearchAllImagesGQL error", t, func() { buff := bytes.NewBufferString("") - searchConfig := getMockSearchConfig(buff, mockService{ + searchConfig := getMockSearchConfig(buff, &mockService{ + httpClient: NewHTTPClient(), getImagesGQLFn: func(ctx context.Context, config SearchConfig, username, password, imageName string, ) (*common.ImageListResponse, error) { return &common.ImageListResponse{ImageList: common.ImageList{ @@ -88,7 +91,8 @@ func TestSearchAllImagesGQL(t *testing.T) { func TestSearchImageByName(t *testing.T) { Convey("SearchImageByName", t, func() { buff := bytes.NewBufferString("") - searchConfig := getMockSearchConfig(buff, mockService{ + searchConfig := getMockSearchConfig(buff, &mockService{ + httpClient: NewHTTPClient(), getImageByNameFn: func(ctx context.Context, config SearchConfig, username string, password string, imageName string, channel chan stringResult, wtgrp *sync.WaitGroup, ) { @@ -108,7 +112,8 @@ func TestSearchImageByName(t *testing.T) { Convey("SearchImageByName error", t, func() { buff := bytes.NewBufferString("") - searchConfig := getMockSearchConfig(buff, mockService{ + searchConfig := getMockSearchConfig(buff, &mockService{ + httpClient: NewHTTPClient(), getImageByNameFn: func(ctx context.Context, config SearchConfig, username string, password string, imageName string, channel chan stringResult, wtgrp *sync.WaitGroup, ) { @@ -124,7 +129,8 @@ func TestSearchImageByName(t *testing.T) { func TestSearchImageByNameGQL(t *testing.T) { Convey("SearchImageByNameGQL", t, func() { buff := bytes.NewBufferString("") - searchConfig := getMockSearchConfig(buff, mockService{ + searchConfig := getMockSearchConfig(buff, &mockService{ + httpClient: NewHTTPClient(), getImagesGQLFn: func(ctx context.Context, config SearchConfig, username, password, imageName string, ) (*common.ImageListResponse, error) { return &common.ImageListResponse{ImageList: common.ImageList{ @@ -145,7 +151,8 @@ func TestSearchImageByNameGQL(t *testing.T) { Convey("SearchImageByNameGQL error", t, func() { buff := bytes.NewBufferString("") - searchConfig := getMockSearchConfig(buff, mockService{ + searchConfig := getMockSearchConfig(buff, &mockService{ + httpClient: NewHTTPClient(), getImagesGQLFn: func(ctx context.Context, config SearchConfig, username, password, imageName string, ) (*common.ImageListResponse, error) { return &common.ImageListResponse{ImageList: common.ImageList{ @@ -164,7 +171,8 @@ func TestSearchImageByNameGQL(t *testing.T) { func TestSearchImagesByDigest(t *testing.T) { Convey("SearchImagesByDigest", t, func() { buff := bytes.NewBufferString("") - searchConfig := getMockSearchConfig(buff, mockService{ + searchConfig := getMockSearchConfig(buff, &mockService{ + httpClient: NewHTTPClient(), getImagesByDigestFn: func(ctx context.Context, config SearchConfig, username string, password string, digest string, rch chan stringResult, wtgrp *sync.WaitGroup, ) { @@ -184,7 +192,8 @@ func TestSearchImagesByDigest(t *testing.T) { Convey("SearchImagesByDigest error", t, func() { buff := bytes.NewBufferString("") - searchConfig := getMockSearchConfig(buff, mockService{ + searchConfig := getMockSearchConfig(buff, &mockService{ + httpClient: NewHTTPClient(), getImagesByDigestFn: func(ctx context.Context, config SearchConfig, username string, password string, digest string, rch chan stringResult, wtgrp *sync.WaitGroup, ) { @@ -200,7 +209,8 @@ func TestSearchImagesByDigest(t *testing.T) { func TestSearchDerivedImageListGQL(t *testing.T) { Convey("SearchDerivedImageListGQL", t, func() { buff := bytes.NewBufferString("") - searchConfig := getMockSearchConfig(buff, mockService{ + searchConfig := getMockSearchConfig(buff, &mockService{ + httpClient: NewHTTPClient(), getDerivedImageListGQLFn: func(ctx context.Context, config SearchConfig, username string, password string, derivedImage string) (*common.DerivedImageListResponse, error, ) { @@ -224,7 +234,8 @@ func TestSearchDerivedImageListGQL(t *testing.T) { Convey("SearchDerivedImageListGQL error", t, func() { buff := bytes.NewBufferString("") - searchConfig := getMockSearchConfig(buff, mockService{ + searchConfig := getMockSearchConfig(buff, &mockService{ + httpClient: NewHTTPClient(), getDerivedImageListGQLFn: func(ctx context.Context, config SearchConfig, username string, password string, derivedImage string) (*common.DerivedImageListResponse, error, ) { @@ -242,7 +253,8 @@ func TestSearchDerivedImageListGQL(t *testing.T) { func TestSearchBaseImageListGQL(t *testing.T) { Convey("SearchBaseImageListGQL", t, func() { buff := bytes.NewBufferString("") - searchConfig := getMockSearchConfig(buff, mockService{ + searchConfig := getMockSearchConfig(buff, &mockService{ + httpClient: NewHTTPClient(), getBaseImageListGQLFn: func(ctx context.Context, config SearchConfig, username string, password string, derivedImage string) (*common.BaseImageListResponse, error, ) { @@ -264,7 +276,8 @@ func TestSearchBaseImageListGQL(t *testing.T) { Convey("SearchBaseImageListGQL error", t, func() { buff := bytes.NewBufferString("") - searchConfig := getMockSearchConfig(buff, mockService{ + searchConfig := getMockSearchConfig(buff, &mockService{ + httpClient: NewHTTPClient(), getBaseImageListGQLFn: func(ctx context.Context, config SearchConfig, username string, password string, derivedImage string) (*common.BaseImageListResponse, error, ) { @@ -282,7 +295,8 @@ func TestSearchBaseImageListGQL(t *testing.T) { func TestSearchImagesForDigestGQL(t *testing.T) { Convey("SearchImagesForDigestGQL", t, func() { buff := bytes.NewBufferString("") - searchConfig := getMockSearchConfig(buff, mockService{ + searchConfig := getMockSearchConfig(buff, &mockService{ + httpClient: NewHTTPClient(), getImagesForDigestGQLFn: func(ctx context.Context, config SearchConfig, username string, password string, digest string) (*common.ImagesForDigest, error, ) { @@ -304,7 +318,8 @@ func TestSearchImagesForDigestGQL(t *testing.T) { Convey("SearchImagesForDigestGQL error", t, func() { buff := bytes.NewBufferString("") - searchConfig := getMockSearchConfig(buff, mockService{ + searchConfig := getMockSearchConfig(buff, &mockService{ + httpClient: NewHTTPClient(), getImagesForDigestGQLFn: func(ctx context.Context, config SearchConfig, username string, password string, digest string) (*common.ImagesForDigest, error, ) { @@ -322,7 +337,8 @@ func TestSearchImagesForDigestGQL(t *testing.T) { func TestSearchCVEForImageGQL(t *testing.T) { Convey("SearchCVEForImageGQL normal mode", t, func() { buff := bytes.NewBufferString("") - searchConfig := getMockSearchConfig(buff, mockService{ + searchConfig := getMockSearchConfig(buff, &mockService{ + httpClient: NewHTTPClient(), getCveByImageGQLFn: func(ctx context.Context, config SearchConfig, username string, password string, imageName string, searchedCVE string) (*cveResult, error, ) { @@ -406,7 +422,8 @@ func TestSearchCVEForImageGQL(t *testing.T) { Convey("SearchCVEForImageGQL verbose mode", t, func() { buff := bytes.NewBufferString("") - searchConfig := getMockSearchConfig(buff, mockService{ + searchConfig := getMockSearchConfig(buff, &mockService{ + httpClient: NewHTTPClient(), getCveByImageGQLFn: func(ctx context.Context, config SearchConfig, username string, password string, imageName string, searchedCVE string) (*cveResult, error, ) { @@ -530,7 +547,8 @@ func TestSearchCVEForImageGQL(t *testing.T) { Convey("SearchCVEForImageGQL with injected error", t, func() { buff := bytes.NewBufferString("") - searchConfig := getMockSearchConfig(buff, mockService{ + searchConfig := getMockSearchConfig(buff, &mockService{ + httpClient: NewHTTPClient(), getCveByImageGQLFn: func(ctx context.Context, config SearchConfig, username string, password string, imageName string, searchedCVE string) (*cveResult, error, ) { @@ -546,7 +564,8 @@ func TestSearchCVEForImageGQL(t *testing.T) { func TestSearchImagesByCVEIDGQL(t *testing.T) { Convey("SearchImagesByCVEIDGQL", t, func() { buff := bytes.NewBufferString("") - searchConfig := getMockSearchConfig(buff, mockService{ + searchConfig := getMockSearchConfig(buff, &mockService{ + httpClient: NewHTTPClient(), getTagsForCVEGQLFn: func(ctx context.Context, config SearchConfig, username, password, imageName, cveID string) (*common.ImagesForCve, error, ) { @@ -572,7 +591,8 @@ func TestSearchImagesByCVEIDGQL(t *testing.T) { Convey("SearchImagesByCVEIDGQL error", t, func() { buff := bytes.NewBufferString("") - searchConfig := getMockSearchConfig(buff, mockService{ + searchConfig := getMockSearchConfig(buff, &mockService{ + httpClient: NewHTTPClient(), getTagsForCVEGQLFn: func(ctx context.Context, config SearchConfig, username, password, imageName, cveID string) (*common.ImagesForCve, error, ) { @@ -592,7 +612,8 @@ func TestSearchImagesByCVEIDGQL(t *testing.T) { func TestSearchFixedTagsGQL(t *testing.T) { Convey("SearchFixedTagsGQL", t, func() { buff := bytes.NewBufferString("") - searchConfig := getMockSearchConfig(buff, mockService{ + searchConfig := getMockSearchConfig(buff, &mockService{ + httpClient: NewHTTPClient(), getFixedTagsForCVEGQLFn: func(ctx context.Context, config SearchConfig, username, password, imageName, cveID string) (*common.ImageListWithCVEFixedResponse, error, ) { @@ -616,7 +637,8 @@ func TestSearchFixedTagsGQL(t *testing.T) { Convey("SearchFixedTagsGQL error", t, func() { buff := bytes.NewBufferString("") - searchConfig := getMockSearchConfig(buff, mockService{ + searchConfig := getMockSearchConfig(buff, &mockService{ + httpClient: NewHTTPClient(), getFixedTagsForCVEGQLFn: func(ctx context.Context, config SearchConfig, username, password, imageName, cveID string) (*common.ImageListWithCVEFixedResponse, error, ) { @@ -636,7 +658,8 @@ func TestSearchFixedTagsGQL(t *testing.T) { func TestSearchReferrersGQL(t *testing.T) { Convey("SearchReferrersGQL", t, func() { buff := bytes.NewBufferString("") - searchConfig := getMockSearchConfig(buff, mockService{ + searchConfig := getMockSearchConfig(buff, &mockService{ + httpClient: NewHTTPClient(), getReferrersGQLFn: func(ctx context.Context, config SearchConfig, username, password, repo, digest string) (*common.ReferrersResp, error, ) { @@ -664,7 +687,8 @@ func TestSearchReferrersGQL(t *testing.T) { Convey("SearchReferrersGQL error", t, func() { buff := bytes.NewBufferString("") - searchConfig := getMockSearchConfig(buff, mockService{ + searchConfig := getMockSearchConfig(buff, &mockService{ + httpClient: NewHTTPClient(), getReferrersGQLFn: func(ctx context.Context, config SearchConfig, username, password, repo, digest string) (*common.ReferrersResp, error, ) { @@ -680,7 +704,8 @@ func TestSearchReferrersGQL(t *testing.T) { func TestGlobalSearchGQL(t *testing.T) { Convey("GlobalSearchGQL", t, func() { buff := bytes.NewBufferString("") - searchConfig := getMockSearchConfig(buff, mockService{ + searchConfig := getMockSearchConfig(buff, &mockService{ + httpClient: NewHTTPClient(), globalSearchGQLFn: func(ctx context.Context, config SearchConfig, username, password, query string) (*common.GlobalSearch, error, ) { @@ -705,7 +730,8 @@ func TestGlobalSearchGQL(t *testing.T) { Convey("GlobalSearchGQL error", t, func() { buff := bytes.NewBufferString("") - searchConfig := getMockSearchConfig(buff, mockService{ + searchConfig := getMockSearchConfig(buff, &mockService{ + httpClient: NewHTTPClient(), globalSearchGQLFn: func(ctx context.Context, config SearchConfig, username, password, query string) (*common.GlobalSearch, error, ) { @@ -721,7 +747,8 @@ func TestGlobalSearchGQL(t *testing.T) { func TestSearchReferrers(t *testing.T) { Convey("SearchReferrers", t, func() { buff := bytes.NewBufferString("") - searchConfig := getMockSearchConfig(buff, mockService{ + searchConfig := getMockSearchConfig(buff, &mockService{ + httpClient: NewHTTPClient(), getReferrersFn: func(ctx context.Context, config SearchConfig, username string, password string, repo string, digest string) (referrersResult, error, ) { @@ -747,7 +774,8 @@ func TestSearchReferrers(t *testing.T) { Convey("SearchReferrers error", t, func() { buff := bytes.NewBufferString("") - searchConfig := getMockSearchConfig(buff, mockService{ + searchConfig := getMockSearchConfig(buff, &mockService{ + httpClient: NewHTTPClient(), getReferrersFn: func(ctx context.Context, config SearchConfig, username string, password string, repo string, digest string) (referrersResult, error, ) { @@ -763,7 +791,8 @@ func TestSearchReferrers(t *testing.T) { func TestSearchRepos(t *testing.T) { Convey("SearchRepos", t, func() { buff := bytes.NewBufferString("") - searchConfig := getMockSearchConfig(buff, mockService{}) + searchConfig := getMockSearchConfig(buff, &mockService{ + httpClient: NewHTTPClient()}) err := SearchRepos(searchConfig) So(err, ShouldBeNil) @@ -775,7 +804,7 @@ func TestSearchRepos(t *testing.T) { }) } -func getMockSearchConfig(buff *bytes.Buffer, mockService mockService) SearchConfig { +func getMockSearchConfig(buff *bytes.Buffer, mockService *mockService) SearchConfig { return SearchConfig{ ResultWriter: buff, User: "", @@ -909,18 +938,20 @@ func TestUtils(t *testing.T) { Convey("CheckExtEndPointQuery", t, func() { // invalid url err := CheckExtEndPointQuery(SearchConfig{ - User: "", - ServURL: "bad-url", + User: "", + ServURL: "bad-url", + SearchService: NewSearchService(), }) So(err, ShouldNotBeNil) // good url but no connection err = CheckExtEndPointQuery(SearchConfig{ - User: "", - ServURL: "http://127.0.0.1:5000", - VerifyTLS: false, - Debug: false, - ResultWriter: io.Discard, + User: "", + ServURL: "http://127.0.0.1:5000", + VerifyTLS: false, + Debug: false, + ResultWriter: io.Discard, + SearchService: NewSearchService(), }) So(err, ShouldNotBeNil) }) diff --git a/pkg/cli/client/server_info_cmd.go b/pkg/cli/client/server_info_cmd.go index 041c0af6..eafda476 100644 --- a/pkg/cli/client/server_info_cmd.go +++ b/pkg/cli/client/server_info_cmd.go @@ -58,7 +58,8 @@ func GetServerStatus(config SearchConfig) error { return err } - _, err = makeGETRequest(ctx, checkAPISupportEndpoint, username, password, config.VerifyTLS, config.Debug, + _, err = config.SearchService.getHTTPClient().makeGETRequest( + ctx, checkAPISupportEndpoint, username, password, config.VerifyTLS, config.Debug, nil, config.ResultWriter) if err != nil { serverInfo := ServerInfo{} @@ -87,7 +88,8 @@ func GetServerStatus(config SearchConfig) error { serverInfo := ServerInfo{} - _, err = makeGETRequest(ctx, mgmtEndpoint, username, password, config.VerifyTLS, config.Debug, + _, err = config.SearchService.getHTTPClient().makeGETRequest( + ctx, mgmtEndpoint, username, password, config.VerifyTLS, config.Debug, &serverInfo, config.ResultWriter) switch { diff --git a/pkg/cli/client/server_info_cmd_test.go b/pkg/cli/client/server_info_cmd_test.go index f87ce46e..370165b2 100644 --- a/pkg/cli/client/server_info_cmd_test.go +++ b/pkg/cli/client/server_info_cmd_test.go @@ -112,15 +112,17 @@ func TestServerStatusCommandErrors(t *testing.T) { // invalid URL err = GetServerStatus(SearchConfig{ - ServURL: "a: ds", - ResultWriter: os.Stdout, + ServURL: "a: ds", + ResultWriter: os.Stdout, + SearchService: NewSearchService(), }) So(err, ShouldNotBeNil) // fail Get request err = GetServerStatus(SearchConfig{ - ServURL: "http://127.0.0.1:8000", - ResultWriter: os.Stdout, + ServURL: "http://127.0.0.1:8000", + ResultWriter: os.Stdout, + SearchService: NewSearchService(), }) So(err, ShouldBeNil) }) @@ -129,7 +131,7 @@ func TestServerStatusCommandErrors(t *testing.T) { port := test.GetFreePort() result := bytes.NewBuffer([]byte{}) searchConfig := SearchConfig{ - SearchService: mockService{}, + SearchService: newMockService(), ServURL: fmt.Sprintf("http://127.0.0.1:%v", port), User: "", OutputFormat: "text", diff --git a/pkg/cli/client/service.go b/pkg/cli/client/service.go index 02d38d1a..2ab766c1 100644 --- a/pkg/cli/client/service.go +++ b/pkg/cli/client/service.go @@ -64,6 +64,7 @@ type SearchService interface { //nolint:interfacebloat channel chan stringResult, wtgrp *sync.WaitGroup) getReferrers(ctx context.Context, config SearchConfig, username, password string, repo, digest string, ) (referrersResult, error) + getHTTPClient() *HTTPClient } type SearchConfig struct { @@ -80,13 +81,23 @@ type SearchConfig struct { Spinner spinnerState } -type searchService struct{} - -func NewSearchService() SearchService { - return searchService{} +type searchService struct { + httpClient *HTTPClient } -func (service searchService) getDerivedImageListGQL(ctx context.Context, config SearchConfig, username, password string, +func NewSearchService() SearchService { + return &searchService{ + httpClient: NewHTTPClient(), + } +} + +// getHTTPClient returns the HTTP client manager for this service instance. +func (service *searchService) getHTTPClient() *HTTPClient { + return service.httpClient +} + +func (service *searchService) getDerivedImageListGQL( + ctx context.Context, config SearchConfig, username, password string, derivedImage string, ) (*common.DerivedImageListResponse, error) { query := fmt.Sprintf(` @@ -122,7 +133,7 @@ func (service searchService) getDerivedImageListGQL(ctx context.Context, config return result, nil } -func (service searchService) getReferrersGQL(ctx context.Context, config SearchConfig, username, password string, +func (service *searchService) getReferrersGQL(ctx context.Context, config SearchConfig, username, password string, repo, digest string, ) (*common.ReferrersResp, error) { query := fmt.Sprintf(` @@ -149,7 +160,7 @@ func (service searchService) getReferrersGQL(ctx context.Context, config SearchC return result, nil } -func (service searchService) getCVEDiffListGQL(ctx context.Context, config SearchConfig, username, password string, +func (service *searchService) getCVEDiffListGQL(ctx context.Context, config SearchConfig, username, password string, minuend, subtrahend ImageIdentifier, ) (*cveDiffListResp, error) { minuendInput := getImageInput(minuend) @@ -189,7 +200,7 @@ func getImageInput(img ImageIdentifier) string { return fmt.Sprintf(`{Repo: "%s", Tag: "%s", Digest: "%s"%s}`, img.Repo, img.Tag, img.Digest, platform) } -func (service searchService) globalSearchGQL(ctx context.Context, config SearchConfig, username, password string, +func (service *searchService) globalSearchGQL(ctx context.Context, config SearchConfig, username, password string, query string, ) (*common.GlobalSearch, error) { GQLQuery := fmt.Sprintf(` @@ -234,7 +245,7 @@ func (service searchService) globalSearchGQL(ctx context.Context, config SearchC return &result.GlobalSearch, nil } -func (service searchService) getBaseImageListGQL(ctx context.Context, config SearchConfig, username, password string, +func (service *searchService) getBaseImageListGQL(ctx context.Context, config SearchConfig, username, password string, baseImage string, ) (*common.BaseImageListResponse, error) { query := fmt.Sprintf(` @@ -270,7 +281,7 @@ func (service searchService) getBaseImageListGQL(ctx context.Context, config Sea return result, nil } -func (service searchService) getImagesGQL(ctx context.Context, config SearchConfig, username, password string, +func (service *searchService) getImagesGQL(ctx context.Context, config SearchConfig, username, password string, imageName string, ) (*common.ImageListResponse, error) { query := fmt.Sprintf(` @@ -305,7 +316,7 @@ func (service searchService) getImagesGQL(ctx context.Context, config SearchConf return result, nil } -func (service searchService) getImagesForDigestGQL(ctx context.Context, config SearchConfig, username, password string, +func (service *searchService) getImagesForDigestGQL(ctx context.Context, config SearchConfig, username, password string, digest string, ) (*common.ImagesForDigest, error) { query := fmt.Sprintf(` @@ -340,7 +351,7 @@ func (service searchService) getImagesForDigestGQL(ctx context.Context, config S return result, nil } -func (service searchService) getCveByImageGQL(ctx context.Context, config SearchConfig, username, password, +func (service *searchService) getCveByImageGQL(ctx context.Context, config SearchConfig, username, password, imageName, searchedCVE string, ) (*cveResult, error) { query := fmt.Sprintf(` @@ -366,7 +377,7 @@ func (service searchService) getCveByImageGQL(ctx context.Context, config Search return result, nil } -func (service searchService) getTagsForCVEGQL(ctx context.Context, config SearchConfig, +func (service *searchService) getTagsForCVEGQL(ctx context.Context, config SearchConfig, username, password, repo, cveID string, ) (*common.ImagesForCve, error) { query := fmt.Sprintf(` @@ -416,7 +427,7 @@ func (service searchService) getTagsForCVEGQL(ctx context.Context, config Search return filteredResults, nil } -func (service searchService) getFixedTagsForCVEGQL(ctx context.Context, config SearchConfig, +func (service *searchService) getFixedTagsForCVEGQL(ctx context.Context, config SearchConfig, username, password, imageName, cveID string, ) (*common.ImageListWithCVEFixedResponse, error) { query := fmt.Sprintf(` @@ -453,7 +464,7 @@ func (service searchService) getFixedTagsForCVEGQL(ctx context.Context, config S return result, nil } -func (service searchService) getReferrers(ctx context.Context, config SearchConfig, username, password string, +func (service *searchService) getReferrers(ctx context.Context, config SearchConfig, username, password string, repo, digest string, ) (referrersResult, error) { referrersEndpoint, err := combineServerAndEndpointURL(config.ServURL, @@ -468,7 +479,7 @@ func (service searchService) getReferrers(ctx context.Context, config SearchConf referrerResp := &ispec.Index{} - _, err = makeGETRequest(ctx, referrersEndpoint, username, password, config.VerifyTLS, + _, err = service.httpClient.makeGETRequest(ctx, referrersEndpoint, username, password, config.VerifyTLS, config.Debug, &referrerResp, config.ResultWriter) if err != nil { if common.IsContextDone(ctx) { @@ -492,7 +503,7 @@ func (service searchService) getReferrers(ctx context.Context, config SearchConf return referrersList, nil } -func (service searchService) getImageByName(ctx context.Context, config SearchConfig, +func (service *searchService) getImageByName(ctx context.Context, config SearchConfig, username, password, imageName string, rch chan stringResult, wtgrp *sync.WaitGroup, ) { defer wtgrp.Done() @@ -506,12 +517,12 @@ func (service searchService) getImageByName(ctx context.Context, config SearchCo go rlim.startRateLimiter(ctx) localWg.Add(1) - go getImage(ctx, config, username, password, imageName, rch, &localWg, rlim) + go service.getImage(ctx, config, username, password, imageName, rch, &localWg, rlim) localWg.Wait() } -func (service searchService) getAllImages(ctx context.Context, config SearchConfig, username, password string, +func (service *searchService) getAllImages(ctx context.Context, config SearchConfig, username, password string, rch chan stringResult, wtgrp *sync.WaitGroup, ) { defer wtgrp.Done() @@ -530,7 +541,7 @@ func (service searchService) getAllImages(ctx context.Context, config SearchConf return } - _, err = makeGETRequest(ctx, catalogEndPoint, username, password, config.VerifyTLS, + _, err = service.httpClient.makeGETRequest(ctx, catalogEndPoint, username, password, config.VerifyTLS, config.Debug, catalog, config.ResultWriter) if err != nil { if common.IsContextDone(ctx) { @@ -552,13 +563,13 @@ func (service searchService) getAllImages(ctx context.Context, config SearchConf for _, repo := range catalog.Repositories { localWg.Add(1) - go getImage(ctx, config, username, password, repo, rch, &localWg, rlim) + go service.getImage(ctx, config, username, password, repo, rch, &localWg, rlim) } localWg.Wait() } -func getImage(ctx context.Context, config SearchConfig, username, password, imageName string, +func (service *searchService) getImage(ctx context.Context, config SearchConfig, username, password, imageName string, rch chan stringResult, wtgrp *sync.WaitGroup, pool *requestsPool, ) { defer wtgrp.Done() @@ -577,7 +588,7 @@ func getImage(ctx context.Context, config SearchConfig, username, password, imag tagList := &tagListResp{} - _, err = makeGETRequest(ctx, tagListEndpoint, username, password, config.VerifyTLS, + _, err = service.httpClient.makeGETRequest(ctx, tagListEndpoint, username, password, config.VerifyTLS, config.Debug, &tagList, config.ResultWriter) if err != nil { if common.IsContextDone(ctx) { @@ -612,7 +623,7 @@ func getImage(ctx context.Context, config SearchConfig, username, password, imag } } -func (service searchService) getImagesByDigest(ctx context.Context, config SearchConfig, username, +func (service *searchService) getImagesByDigest(ctx context.Context, config SearchConfig, username, password string, digest string, rch chan stringResult, wtgrp *sync.WaitGroup, ) { defer wtgrp.Done() @@ -687,7 +698,7 @@ func (service searchService) getImagesByDigest(ctx context.Context, config Searc // Query using GQL, the query string is passed as a parameter // errors are returned in the stringResult channel, the unmarshalled payload is in resultPtr. -func (service searchService) makeGraphQLQuery(ctx context.Context, +func (service *searchService) makeGraphQLQuery(ctx context.Context, config SearchConfig, username, password, query string, resultPtr any, ) error { @@ -696,7 +707,7 @@ func (service searchService) makeGraphQLQuery(ctx context.Context, return err } - err = makeGraphQLRequest(ctx, endPoint, query, username, password, config.VerifyTLS, + err = service.httpClient.makeGraphQLRequest(ctx, endPoint, query, username, password, config.VerifyTLS, config.Debug, resultPtr, config.ResultWriter) if err != nil { return err @@ -1361,7 +1372,7 @@ func getCVETableWriter(writer io.Writer) *tablewriter.Table { return table } -func (service searchService) getRepos(ctx context.Context, config SearchConfig, username, password string, +func (service *searchService) getRepos(ctx context.Context, config SearchConfig, username, password string, rch chan stringResult, wtgrp *sync.WaitGroup, ) { defer wtgrp.Done() @@ -1380,7 +1391,7 @@ func (service searchService) getRepos(ctx context.Context, config SearchConfig, return } - _, err = makeGETRequest(ctx, catalogEndPoint, username, password, config.VerifyTLS, + _, err = service.httpClient.makeGETRequest(ctx, catalogEndPoint, username, password, config.VerifyTLS, config.Debug, catalog, config.ResultWriter) if err != nil { if common.IsContextDone(ctx) { diff --git a/pkg/cli/client/utils.go b/pkg/cli/client/utils.go index d264c70f..d1a6f73f 100644 --- a/pkg/cli/client/utils.go +++ b/pkg/cli/client/utils.go @@ -37,7 +37,8 @@ func fetchImageDigest(repo, ref, username, password string, config SearchConfig) return "", err } - res, err := makeHEADRequest(context.Background(), url, username, password, config.VerifyTLS, false) + res, err := config.SearchService.getHTTPClient().makeHEADRequest( + context.Background(), url, username, password, config.VerifyTLS, false) digestStr := res.Get(constants.DistContentDigestKey) diff --git a/pkg/cli/client/utils_internal_test.go b/pkg/cli/client/utils_internal_test.go index e7aec78e..8c91a978 100644 --- a/pkg/cli/client/utils_internal_test.go +++ b/pkg/cli/client/utils_internal_test.go @@ -26,12 +26,13 @@ func getDefaultSearchConf(baseURL string) SearchConfig { outputFormat := "text" return SearchConfig{ - ServURL: baseURL, - ResultWriter: io.Discard, - VerifyTLS: verifyTLS, - Debug: debug, - Verbose: verbose, - OutputFormat: outputFormat, + ServURL: baseURL, + ResultWriter: io.Discard, + VerifyTLS: verifyTLS, + Debug: debug, + Verbose: verbose, + OutputFormat: outputFormat, + SearchService: NewSearchService(), } } @@ -88,7 +89,8 @@ func TestDoHTTPRequest(t *testing.T) { req, err := http.NewRequestWithContext(context.Background(), http.MethodPost, url, nil) So(err, ShouldBeNil) - So(func() { _, _ = doHTTPRequest(req, false, false, nil, io.Discard) }, ShouldNotPanic) + httpClient := NewHTTPClient() + So(func() { _, _ = httpClient.doHTTPRequest(req, false, false, nil, io.Discard) }, ShouldNotPanic) }) Convey("doHTTPRequest bad return json", t, func() { @@ -112,21 +114,25 @@ func TestDoHTTPRequest(t *testing.T) { req, err := http.NewRequestWithContext(context.Background(), http.MethodGet, url, nil) So(err, ShouldBeNil) - So(func() { _, _ = doHTTPRequest(req, false, false, &ispec.Manifest{}, io.Discard) }, ShouldNotPanic) + httpClient := NewHTTPClient() + So(func() { _, _ = httpClient.doHTTPRequest(req, false, false, &ispec.Manifest{}, io.Discard) }, ShouldNotPanic) }) Convey("makeGraphQLRequest bad request context", t, func() { - err := makeGraphQLRequest(nil, "", "", "", "", false, false, nil, io.Discard) //nolint:staticcheck + httpClient := NewHTTPClient() + err := httpClient.makeGraphQLRequest(nil, "", "", "", "", false, false, nil, io.Discard) //nolint:staticcheck So(err, ShouldNotBeNil) }) Convey("makeHEADRequest bad request context", t, func() { - _, err := makeHEADRequest(nil, "", "", "", false, false) //nolint:staticcheck + httpClient := NewHTTPClient() + _, err := httpClient.makeHEADRequest(nil, "", "", "", false, false) //nolint:staticcheck So(err, ShouldNotBeNil) }) Convey("makeGETRequest bad request context", t, func() { - _, err := makeGETRequest(nil, "", "", "", false, false, nil, io.Discard) //nolint:staticcheck + httpClient := NewHTTPClient() + _, err := httpClient.makeGETRequest(nil, "", "", "", false, false, nil, io.Discard) //nolint:staticcheck So(err, ShouldNotBeNil) }) diff --git a/pkg/common/http_client_test.go b/pkg/common/http_client_test.go index 14c38891..f95d5963 100644 --- a/pkg/common/http_client_test.go +++ b/pkg/common/http_client_test.go @@ -6,13 +6,48 @@ import ( "os" "path" "testing" + "time" . "github.com/smartystreets/goconvey/convey" "zotregistry.dev/zot/v2/pkg/common" - test "zotregistry.dev/zot/v2/pkg/test/common" + tlsutils "zotregistry.dev/zot/v2/pkg/test/tls" ) +// setupTestCerts generates CA and client certificates for testing. +func setupTestCerts(t *testing.T, tempDir string) { + t.Helper() + + // Generate CA certificate + caOpts := &tlsutils.CertificateOptions{ + CommonName: "*", + NotAfter: time.Now().AddDate(10, 0, 0), + } + caCertPEM, caKeyPEM, err := tlsutils.GenerateCACert(caOpts) + if err != nil { + t.Fatalf("Failed to generate CA cert: %v", err) + } + + caCertPath := path.Join(tempDir, "ca.crt") + err = os.WriteFile(caCertPath, caCertPEM, 0o600) + if err != nil { + t.Fatalf("Failed to write CA cert: %v", err) + } + + // Generate client certificate + clientCertPath := path.Join(tempDir, "client.cert") + clientKeyPath := path.Join(tempDir, "client.key") + clientOpts := &tlsutils.CertificateOptions{ + CommonName: "testclient", + OrganizationalUnit: "TestClient", + NotAfter: time.Now().AddDate(10, 0, 0), + } + err = tlsutils.GenerateClientCertToFile(caCertPEM, caKeyPEM, clientCertPath, clientKeyPath, clientOpts) + if err != nil { + t.Fatalf("Failed to generate client cert: %v", err) + } +} + func TestHTTPClient(t *testing.T) { Convey("test getTLSConfig()", t, func() { caCertPool, _ := x509.SystemCertPool() @@ -21,8 +56,7 @@ func TestHTTPClient(t *testing.T) { So(err, ShouldNotBeNil) tempDir := t.TempDir() - err = test.CopyTestKeysAndCerts(tempDir) - So(err, ShouldBeNil) + setupTestCerts(t, tempDir) err = os.Chmod(path.Join(tempDir, "ca.crt"), 0o000) So(err, ShouldBeNil) _, err = common.GetTLSConfig(tempDir, caCertPool) @@ -31,9 +65,8 @@ func TestHTTPClient(t *testing.T) { Convey("test CreateHTTPClient() no permissions on certificate", t, func() { tempDir := t.TempDir() - err := test.CopyTestKeysAndCerts(tempDir) - So(err, ShouldBeNil) - err = os.Chmod(path.Join(tempDir, "ca.crt"), 0o000) + setupTestCerts(t, tempDir) + err := os.Chmod(path.Join(tempDir, "ca.crt"), 0o000) So(err, ShouldBeNil) _, err = common.CreateHTTPClient(&common.HTTPClientOptions{ @@ -51,9 +84,8 @@ func TestHTTPClient(t *testing.T) { Convey("test CreateHTTPClient() no permissions on key", t, func() { tempDir := t.TempDir() - err := test.CopyTestKeysAndCerts(tempDir) - So(err, ShouldBeNil) - err = os.Chmod(path.Join(tempDir, "client.key"), 0o000) + setupTestCerts(t, tempDir) + err := os.Chmod(path.Join(tempDir, "client.key"), 0o000) So(err, ShouldBeNil) _, err = common.CreateHTTPClient(&common.HTTPClientOptions{ @@ -76,10 +108,9 @@ func TestHTTPClient(t *testing.T) { Convey("test CreateHTTPClient() with only client cert configured", t, func() { tempDir := t.TempDir() - err := test.CopyTestKeysAndCerts(tempDir) - So(err, ShouldBeNil) + setupTestCerts(t, tempDir) - _, err = common.CreateHTTPClient(&common.HTTPClientOptions{ + _, err := common.CreateHTTPClient(&common.HTTPClientOptions{ TLSEnabled: true, VerifyTLS: true, Host: "localhost", @@ -92,10 +123,9 @@ func TestHTTPClient(t *testing.T) { Convey("test CreateHTTPClient() with only client key configured", t, func() { tempDir := t.TempDir() - err := test.CopyTestKeysAndCerts(tempDir) - So(err, ShouldBeNil) + setupTestCerts(t, tempDir) - _, err = common.CreateHTTPClient(&common.HTTPClientOptions{ + _, err := common.CreateHTTPClient(&common.HTTPClientOptions{ TLSEnabled: true, VerifyTLS: true, Host: "localhost", @@ -108,8 +138,7 @@ func TestHTTPClient(t *testing.T) { Convey("test CreateHTTPClient() with full certificate config", t, func() { tempDir := t.TempDir() - err := test.CopyTestKeysAndCerts(tempDir) - So(err, ShouldBeNil) + setupTestCerts(t, tempDir) client, err := common.CreateHTTPClient(&common.HTTPClientOptions{ TLSEnabled: true, @@ -131,8 +160,7 @@ func TestHTTPClient(t *testing.T) { Convey("test CreateHTTPClient() with no TLS verify", t, func() { tempDir := t.TempDir() - err := test.CopyTestKeysAndCerts(tempDir) - So(err, ShouldBeNil) + setupTestCerts(t, tempDir) client, err := common.CreateHTTPClient(&common.HTTPClientOptions{ TLSEnabled: true, @@ -155,8 +183,7 @@ func TestHTTPClient(t *testing.T) { Convey("test CreateHTTPClient() with no TLS, but TLS verify enabled", t, func() { tempDir := t.TempDir() - err := test.CopyTestKeysAndCerts(tempDir) - So(err, ShouldBeNil) + setupTestCerts(t, tempDir) client, err := common.CreateHTTPClient(&common.HTTPClientOptions{ TLSEnabled: false, diff --git a/pkg/extensions/extensions_test.go b/pkg/extensions/extensions_test.go index dd2df33d..31471cd4 100644 --- a/pkg/extensions/extensions_test.go +++ b/pkg/extensions/extensions_test.go @@ -8,6 +8,7 @@ import ( "net/http/httptest" "net/url" "os" + "path" "testing" "time" @@ -22,12 +23,41 @@ import ( syncconf "zotregistry.dev/zot/v2/pkg/extensions/config/sync" authutils "zotregistry.dev/zot/v2/pkg/test/auth" test "zotregistry.dev/zot/v2/pkg/test/common" + tlsutils "zotregistry.dev/zot/v2/pkg/test/tls" ) -const ( - ServerCert = "../../test/data/server.cert" - ServerKey = "../../test/data/server.key" -) +// setupTestServerCerts generates CA and server certificates for testing. +// Returns paths to server certificate and key files. +func setupTestServerCerts(t *testing.T) (string, string) { + t.Helper() + tempDir := t.TempDir() + + // Generate CA certificate + caOpts := &tlsutils.CertificateOptions{ + CommonName: "*", + NotAfter: time.Now().AddDate(10, 0, 0), + } + caCertPEM, caKeyPEM, err := tlsutils.GenerateCACert(caOpts) + if err != nil { + t.Fatalf("Failed to generate CA cert: %v", err) + } + + // Generate server certificate + serverCertPath := path.Join(tempDir, "server.cert") + serverKeyPath := path.Join(tempDir, "server.key") + serverOpts := &tlsutils.CertificateOptions{ + Hostname: "127.0.0.1", + CommonName: "*", + OrganizationalUnit: "TestServer", + NotAfter: time.Now().AddDate(10, 0, 0), + } + err = tlsutils.GenerateServerCertToFile(caCertPEM, caKeyPEM, serverCertPath, serverKeyPath, serverOpts) + if err != nil { + t.Fatalf("Failed to generate server cert: %v", err) + } + + return serverCertPath, serverKeyPath +} func TestEnableExtension(t *testing.T) { Convey("Verify log if sync disabled in config", t, func() { @@ -861,14 +891,17 @@ func TestMgmtWithBearer(t *testing.T) { for _, testCase := range testCases { Convey("Make a new controller with "+testCase.name, t, func() { + // Generate certificates dynamically for the test + serverCertPath, serverKeyPath := setupTestServerCerts(t) + authorizedNamespace := "allowedrepo" unauthorizedNamespace := "notallowedrepo" var authTestServer *httptest.Server if testCase.useLegacyAuthTestServer { - authTestServer = authutils.MakeAuthTestServerLegacy(ServerKey, unauthorizedNamespace) + authTestServer = authutils.MakeAuthTestServerLegacy(serverKeyPath, unauthorizedNamespace) } else { - authTestServer = authutils.MakeAuthTestServer(ServerKey, "RS256", unauthorizedNamespace) + authTestServer = authutils.MakeAuthTestServer(serverKeyPath, "RS256", unauthorizedNamespace) } defer authTestServer.Close() @@ -883,7 +916,7 @@ func TestMgmtWithBearer(t *testing.T) { conf.HTTP.Auth = &config.AuthConfig{ Bearer: &config.BearerConfig{ - Cert: ServerCert, + Cert: serverCertPath, Realm: authTestServer.URL + "/auth/token", Service: aurl.Host, }, diff --git a/pkg/extensions/sync/sync_test.go b/pkg/extensions/sync/sync_test.go index 66690865..a6816b3a 100644 --- a/pkg/extensions/sync/sync_test.go +++ b/pkg/extensions/sync/sync_test.go @@ -53,6 +53,7 @@ import ( "zotregistry.dev/zot/v2/pkg/test/mocks" ociutils "zotregistry.dev/zot/v2/pkg/test/oci-utils" "zotregistry.dev/zot/v2/pkg/test/signature" + tlsutils "zotregistry.dev/zot/v2/pkg/test/tls" ) const ( @@ -60,11 +61,6 @@ const ( dockerIndexManifestMediaType = "application/vnd.docker.distribution.manifest.list.v2+json" dockerManifestConfigMediaType = "application/vnd.docker.container.image.v1+json" dockerLayerMediaType = "application/vnd.docker.image.rootfs.diff.tar.gzip" - ServerCert = "../../../test/data/server.cert" - ServerKey = "../../../test/data/server.key" - CACert = "../../../test/data/ca.crt" - ClientCert = "../../../test/data/client.cert" - ClientKey = "../../../test/data/client.key" testImage = "zot-test" testImageTag = "0.0.1" @@ -96,8 +92,62 @@ type catalog struct { Repositories []string `json:"repositories"` } -func makeUpstreamServer( - t *testing.T, secure, basicAuth bool, +// setupTestCertsForSync generates certificates for sync tests that need file paths. +func setupTestCertsForSync(t *testing.T, tempDir string) ( + string, string, string, string, string, []byte, +) { + t.Helper() + + // Generate CA certificate + caOpts := &tlsutils.CertificateOptions{ + CommonName: "*", + } + caCertPEM, caKeyPEM, err := tlsutils.GenerateCACert(caOpts) + if err != nil { + t.Fatalf("Failed to generate CA cert: %v", err) + } + + caCertPath := path.Join(tempDir, "ca.crt") + caKeyPath := path.Join(tempDir, "ca.key") + err = os.WriteFile(caCertPath, caCertPEM, 0o600) + if err != nil { + t.Fatalf("Failed to write CA cert: %v", err) + } + _ = os.WriteFile(caKeyPath, caKeyPEM, 0o600) + + // Generate server certificate (10 years validity, matching gen_certs.sh) + serverCertPath := path.Join(tempDir, "server.cert") + serverKeyPath := path.Join(tempDir, "server.key") + serverOpts := &tlsutils.CertificateOptions{ + Hostname: "127.0.0.1", + CommonName: "*", + OrganizationalUnit: "TestServer", + NotAfter: time.Now().AddDate(10, 0, 0), + } + err = tlsutils.GenerateServerCertToFile(caCertPEM, caKeyPEM, serverCertPath, serverKeyPath, serverOpts) + if err != nil { + t.Fatalf("Failed to generate server cert: %v", err) + } + + // Generate client certificate (10 years validity, matching gen_certs.sh) + clientCertPath := path.Join(tempDir, "client.cert") + clientKeyPath := path.Join(tempDir, "client.key") + clientOpts := &tlsutils.CertificateOptions{ + CommonName: "testclient", + OrganizationalUnit: "TestClient", + NotAfter: time.Now().AddDate(10, 0, 0), + } + err = tlsutils.GenerateClientCertToFile(caCertPEM, caKeyPEM, clientCertPath, clientKeyPath, clientOpts) + if err != nil { + t.Fatalf("Failed to generate client cert: %v", err) + } + + return caCertPath, serverCertPath, serverKeyPath, clientCertPath, clientKeyPath, caCertPEM +} + +// makeUpstreamServerWithCerts creates an upstream server using shared certificates. +func makeUpstreamServerWithCerts( + t *testing.T, secure, basicAuth bool, certDir string, caCertPEM []byte, ) (*api.Controller, string, string, *resty.Client) { t.Helper() @@ -109,25 +159,27 @@ func makeUpstreamServer( if secure { srcBaseURL = test.GetSecureBaseURL(srcPort) - srcConfig.HTTP.TLS = &config.TLSConfig{ - Cert: ServerCert, - Key: ServerKey, - CACert: CACert, - } + // Use shared certificates + caCertPath := path.Join(certDir, "ca.crt") + serverCertPath := path.Join(certDir, "server.cert") + serverKeyPath := path.Join(certDir, "server.key") + clientCertPath := path.Join(certDir, "client.cert") + clientKeyPath := path.Join(certDir, "client.key") - caCert, err := os.ReadFile(CACert) - if err != nil { - panic(err) + srcConfig.HTTP.TLS = &config.TLSConfig{ + Cert: serverCertPath, + Key: serverKeyPath, + CACert: caCertPath, } caCertPool := x509.NewCertPool() - caCertPool.AppendCertsFromPEM(caCert) + caCertPool.AppendCertsFromPEM(caCertPEM) client.SetTLSClientConfig(&tls.Config{RootCAs: caCertPool, MinVersion: tls.VersionTLS12}) - cert, err := tls.LoadX509KeyPair(ClientCert, ClientKey) + cert, err := tls.LoadX509KeyPair(clientCertPath, clientKeyPath) if err != nil { - panic(err) + t.Fatalf("Failed to load client cert for upstream test client: %v", err) } client.SetCertificates(cert) @@ -174,8 +226,25 @@ func makeUpstreamServer( return sctlr, srcBaseURL, srcDir, client } -func makeDownstreamServer( - t *testing.T, secure bool, syncConfig *syncconf.Config, +func makeUpstreamServer( + t *testing.T, secure, basicAuth bool, +) (*api.Controller, string, string, *resty.Client) { + t.Helper() + + // Generate certificates and delegate to makeUpstreamServerWithCerts + if secure { + tempDir := t.TempDir() + _, _, _, _, _, caCertPEM := setupTestCertsForSync(t, tempDir) + + return makeUpstreamServerWithCerts(t, secure, basicAuth, tempDir, caCertPEM) + } + + return makeUpstreamServerWithCerts(t, secure, basicAuth, "", nil) +} + +// makeDownstreamServerWithCerts creates a downstream server using shared certificates. +func makeDownstreamServerWithCerts( + t *testing.T, secure bool, syncConfig *syncconf.Config, certDir string, caCertPEM []byte, ) (*api.Controller, string, string, *resty.Client) { t.Helper() @@ -187,25 +256,27 @@ func makeDownstreamServer( if secure { destBaseURL = test.GetSecureBaseURL(destPort) - destConfig.HTTP.TLS = &config.TLSConfig{ - Cert: ServerCert, - Key: ServerKey, - CACert: CACert, - } + // Use shared certificates (same CA as upstream) + caCertPath := path.Join(certDir, "ca.crt") + serverCertPath := path.Join(certDir, "server.cert") + serverKeyPath := path.Join(certDir, "server.key") + clientCertPath := path.Join(certDir, "client.cert") + clientKeyPath := path.Join(certDir, "client.key") - caCert, err := os.ReadFile(CACert) - if err != nil { - panic(err) + destConfig.HTTP.TLS = &config.TLSConfig{ + Cert: serverCertPath, + Key: serverKeyPath, + CACert: caCertPath, } caCertPool := x509.NewCertPool() - caCertPool.AppendCertsFromPEM(caCert) + caCertPool.AppendCertsFromPEM(caCertPEM) client.SetTLSClientConfig(&tls.Config{RootCAs: caCertPool, MinVersion: tls.VersionTLS12}) - cert, err := tls.LoadX509KeyPair(ClientCert, ClientKey) + cert, err := tls.LoadX509KeyPair(clientCertPath, clientKeyPath) if err != nil { - panic(err) + t.Fatalf("Failed to load client cert for downstream test client: %v", err) } client.SetCertificates(cert) @@ -235,6 +306,22 @@ func makeDownstreamServer( return dctlr, destBaseURL, destDir, client } +func makeDownstreamServer( + t *testing.T, secure bool, syncConfig *syncconf.Config, +) (*api.Controller, string, string, *resty.Client) { + t.Helper() + + // Generate certificates and delegate to makeDownstreamServerWithCerts + if secure { + tempDir := t.TempDir() + _, _, _, _, _, caCertPEM := setupTestCertsForSync(t, tempDir) + + return makeDownstreamServerWithCerts(t, secure, syncConfig, tempDir, caCertPEM) + } + + return makeDownstreamServerWithCerts(t, secure, syncConfig, "", nil) +} + func makeInsecureDownstreamServerFixedPort( t *testing.T, port string, syncConfig *syncconf.Config, clusterConfig *config.ClusterConfig, ) (*api.Controller, string, string, *resty.Client) { @@ -2027,6 +2114,13 @@ func TestPermsDenied(t *testing.T) { err = os.Chmod(syncSubDir, 0o000) So(err, ShouldBeNil) + // Ensure permissions are restored on cleanup to allow temp directory removal + defer func() { + _ = os.Chmod(syncSubDir, 0o755) + // Also restore permissions on parent directory in case it was affected + _ = os.Chmod(path.Join(destDir, testImage), 0o755) + }() + dcm.StartAndWait(destPort) found, err := test.ReadLogFileAndSearchString(dctlr.Config.Log.Output, @@ -2500,7 +2594,13 @@ func TestTLS(t *testing.T) { Convey("Verify sync TLS feature", t, func() { updateDuration, _ := time.ParseDuration("1h") - sctlr, srcBaseURL, srcDir, _ := makeUpstreamServer(t, true, false) + // Generate shared CA and certificates BEFORE creating servers + // This ensures all certificates are signed by the same CA + sharedCertDir := t.TempDir() + caCertPath, _, _, clientCertPath, clientKeyPath, caCertPEM := setupTestCertsForSync(t, sharedCertDir) + + // Create upstream server with shared certificates + sctlr, srcBaseURL, srcDir, _ := makeUpstreamServerWithCerts(t, true, false, sharedCertDir, caCertPEM) scm := test.NewControllerManager(sctlr) scm.StartAndWait(sctlr.Config.HTTP.Port) @@ -2521,28 +2621,37 @@ func TestTLS(t *testing.T) { panic(err) } - // copy upstream client certs, use them in sync config + // Use the same client certificates for sync (signed by the same CA as upstream server) destClientCertDir := t.TempDir() + // Copy client cert and key to sync config directory + destClientCertPath := path.Join(destClientCertDir, "client.cert") + destClientKeyPath := path.Join(destClientCertDir, "client.key") + destCACertPath := path.Join(destClientCertDir, "ca.crt") - destFilePath := path.Join(destClientCertDir, "ca.crt") - - err = test.CopyFile(CACert, destFilePath) + clientCertData, err := os.ReadFile(clientCertPath) if err != nil { - panic(err) + t.Fatalf("Failed to read client cert: %v", err) + } + clientKeyData, err := os.ReadFile(clientKeyPath) + if err != nil { + t.Fatalf("Failed to read client key: %v", err) + } + caCertData, err := os.ReadFile(caCertPath) + if err != nil { + t.Fatalf("Failed to read CA cert: %v", err) } - destFilePath = path.Join(destClientCertDir, "client.cert") - - err = test.CopyFile(ClientCert, destFilePath) + err = os.WriteFile(destClientCertPath, clientCertData, 0o600) if err != nil { - panic(err) + t.Fatalf("Failed to write client cert: %v", err) } - - destFilePath = path.Join(destClientCertDir, "client.key") - - err = test.CopyFile(ClientKey, destFilePath) + err = os.WriteFile(destClientKeyPath, clientKeyData, 0o600) if err != nil { - panic(err) + t.Fatalf("Failed to write client key: %v", err) + } + err = os.WriteFile(destCACertPath, caCertData, 0o600) + if err != nil { + t.Fatalf("Failed to write CA cert: %v", err) } regex := ".*" @@ -2574,7 +2683,9 @@ func TestTLS(t *testing.T) { Registries: []syncconf.RegistryConfig{syncRegistryConfig}, } - dctlr, destBaseURL, destDir, destClient := makeDownstreamServer(t, true, syncConfig) + // Create downstream server with shared certificates (same CA as upstream) + dctlr, destBaseURL, destDir, destClient := makeDownstreamServerWithCerts( + t, true, syncConfig, sharedCertDir, caCertPEM) dcm := test.NewControllerManager(dctlr) dcm.StartAndWait(dctlr.Config.HTTP.Port) @@ -2634,11 +2745,15 @@ func TestBearerAuth(t *testing.T) { // a repo for which clients do not have access, sync shouldn't be able to sync it unauthorizedNamespace := testCveImage + // Generate certificates for bearer auth + tempDir := t.TempDir() + _, serverCertPath, serverKeyPath, _, _, _ := setupTestCertsForSync(t, tempDir) + var authTestServer *httptest.Server if testCase.useLegacyAuthTestServer { - authTestServer = authutils.MakeAuthTestServerLegacy(ServerKey, unauthorizedNamespace) + authTestServer = authutils.MakeAuthTestServerLegacy(serverKeyPath, unauthorizedNamespace) } else { - authTestServer = authutils.MakeAuthTestServer(ServerKey, "RS256", unauthorizedNamespace) + authTestServer = authutils.MakeAuthTestServer(serverKeyPath, "RS256", unauthorizedNamespace) } defer authTestServer.Close() @@ -2649,7 +2764,7 @@ func TestBearerAuth(t *testing.T) { sctlr.Config.HTTP.Auth = &config.AuthConfig{ Bearer: &config.BearerConfig{ - Cert: ServerCert, + Cert: serverCertPath, Realm: authTestServer.URL + "/auth/token", Service: aurl.Host, }, @@ -2788,11 +2903,15 @@ func TestBearerAuth(t *testing.T) { // a repo for which clients do not have access, sync shouldn't be able to sync it unauthorizedNamespace := testCveImage + // Generate certificates for bearer auth + tempDir := t.TempDir() + _, serverCertPath, serverKeyPath, _, _, _ := setupTestCertsForSync(t, tempDir) + var authTestServer *httptest.Server if testCase.useLegacyAuthTestServer { - authTestServer = authutils.MakeAuthTestServerLegacy(ServerKey, unauthorizedNamespace) + authTestServer = authutils.MakeAuthTestServerLegacy(serverKeyPath, unauthorizedNamespace) } else { - authTestServer = authutils.MakeAuthTestServer(ServerKey, "RS256", unauthorizedNamespace) + authTestServer = authutils.MakeAuthTestServer(serverKeyPath, "RS256", unauthorizedNamespace) } defer authTestServer.Close() @@ -2803,7 +2922,7 @@ func TestBearerAuth(t *testing.T) { sctlr.Config.HTTP.Auth = &config.AuthConfig{ Bearer: &config.BearerConfig{ - Cert: ServerCert, + Cert: serverCertPath, Realm: authTestServer.URL + "/auth/token", Service: aurl.Host, }, @@ -3557,14 +3676,11 @@ func TestInvalidCerts(t *testing.T) { // copy client certs, use them in sync config clientCertDir := t.TempDir() - destFilePath := path.Join(clientCertDir, "ca.crt") + // Generate certificates + caCertPath, _, _, _, _, _ := setupTestCertsForSync(t, clientCertDir) - err := test.CopyFile(CACert, destFilePath) - if err != nil { - panic(err) - } - - dstfile, err := os.OpenFile(destFilePath, os.O_TRUNC|os.O_WRONLY|os.O_CREATE, 0o600) + // Modify the CA cert file to add invalid text for testing + dstfile, err := os.OpenFile(caCertPath, os.O_TRUNC|os.O_WRONLY|os.O_CREATE, 0o600) if err != nil { panic(err) } @@ -3575,20 +3691,6 @@ func TestInvalidCerts(t *testing.T) { panic(err) } - destFilePath = path.Join(clientCertDir, "client.cert") - - err = test.CopyFile(ClientCert, destFilePath) - if err != nil { - panic(err) - } - - destFilePath = path.Join(clientCertDir, "client.key") - - err = test.CopyFile(ClientKey, destFilePath) - if err != nil { - panic(err) - } - tlsVerify := true syncRegistryConfig := syncconf.RegistryConfig{ @@ -3626,33 +3728,15 @@ func TestInvalidCerts(t *testing.T) { func TestCertsWithWrongPerms(t *testing.T) { Convey("Verify sync with wrong permissions on certs", t, func() { updateDuration, _ := time.ParseDuration("1h") - // copy client certs, use them in sync config + // Generate certificates and copy them to sync config directory clientCertDir := t.TempDir() - destFilePath := path.Join(clientCertDir, "ca.crt") + caCertPath, _, _, _, _, _ := setupTestCertsForSync(t, clientCertDir) - err := test.CopyFile(CACert, destFilePath) - if err != nil { - panic(err) - } - - err = os.Chmod(destFilePath, 0o000) + // Change permissions on CA cert for testing + err := os.Chmod(caCertPath, 0o000) So(err, ShouldBeNil) - destFilePath = path.Join(clientCertDir, "client.cert") - - err = test.CopyFile(ClientCert, destFilePath) - if err != nil { - panic(err) - } - - destFilePath = path.Join(clientCertDir, "client.key") - - err = test.CopyFile(ClientKey, destFilePath) - if err != nil { - panic(err) - } - tlsVerify := true syncRegistryConfig := syncconf.RegistryConfig{ diff --git a/pkg/test/tls/tls.go b/pkg/test/tls/tls.go index a776b99d..5c2b1d21 100644 --- a/pkg/test/tls/tls.go +++ b/pkg/test/tls/tls.go @@ -1,6 +1,9 @@ package tls import ( + "crypto/ecdsa" + "crypto/ed25519" + "crypto/elliptic" "crypto/rand" "crypto/rsa" "crypto/x509" @@ -14,11 +17,25 @@ import ( ) var ( - ErrDecodeCAPEM = errors.New("failed to decode CA certificate PEM") - ErrInvalidCertificateType = errors.New("invalid certificate type") - ErrCertificateOptionsRequired = errors.New("CertificateOptions is required") - ErrHostnameRequired = errors.New("Hostname is required in CertificateOptions") - ErrNoCertificatesProvided = errors.New("at least one certificate is required") + ErrDecodeCAPEM = errors.New("failed to decode CA certificate PEM") + ErrInvalidCertificateType = errors.New("invalid certificate type") + ErrHostnameRequired = errors.New("Hostname is required in CertificateOptions") + ErrNoCertificatesProvided = errors.New("at least one certificate is required") + ErrInvalidKeyType = errors.New("invalid key type") + ErrUnsupportedPrivateKeyType = errors.New("unsupported private key type") + ErrFailedParsePrivateKey = errors.New("failed to parse private key: unsupported key format") + ErrFailedDecodeCertPEM = errors.New("failed to decode certificate PEM") + ErrFailedDecodeKeyPEM = errors.New("failed to decode private key PEM") + ErrPrivateKeyNotRSA = errors.New("private key is not RSA") +) + +// KeyType represents the type of cryptographic key to use for certificate generation. +type KeyType string + +const ( + KeyTypeRSA KeyType = "RSA" + KeyTypeECDSA KeyType = "ECDSA" + KeyTypeED25519 KeyType = "ED25519" ) const ( @@ -55,23 +72,68 @@ type CertificateOptions struct { // based on whether it's a valid IP address or a DNS name. Hostname string - // CommonName is the CommonName (CN) for client certificates. + // CommonName is the CommonName (CN) for certificates. // For client certificates, this is optional - if not provided, the certificate will not have a CN. CommonName string + + // OrganizationalUnit is the OrganizationalUnit (OU) for certificates. + // If not provided, the certificate will not have an OU. + OrganizationalUnit string + + // KeyType specifies the type of cryptographic key to use. + // Valid values: "RSA" (default), "ECDSA", "ED25519". + // If empty or "RSA", RSA keys will be generated. + KeyType KeyType } // generateCertificate is a helper function that generates a certificate and private key. // If signerCert and signerKey are nil, the certificate will be self-signed. +// signerKey can be *rsa.PrivateKey, *ecdsa.PrivateKey, or ed25519.PrivateKey. func generateCertificate( certType string, opts *CertificateOptions, signerCert *x509.Certificate, - signerKey *rsa.PrivateKey, + signerKey any, // Can be *rsa.PrivateKey, *ecdsa.PrivateKey, or ed25519.PrivateKey ) ([]byte, []byte, error) { - // Generate private key - privKey, err := rsa.GenerateKey(rand.Reader, 2048) - if err != nil { - return nil, nil, fmt.Errorf("failed to generate private key: %w", err) + var ( + issuerCert *x509.Certificate + issuerKey any + privKey any + publicKey any + err error + ) + + // Determine key type + keyType := KeyTypeRSA + if opts != nil && opts.KeyType != "" { + keyType = opts.KeyType + } + + // Generate private key based on key type + switch keyType { + case KeyTypeRSA: + rsaKey, err := rsa.GenerateKey(rand.Reader, 2048) + if err != nil { + return nil, nil, fmt.Errorf("failed to generate RSA private key: %w", err) + } + privKey = rsaKey + publicKey = &rsaKey.PublicKey + case KeyTypeECDSA: + ecKey, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader) + if err != nil { + return nil, nil, fmt.Errorf("failed to generate ECDSA private key: %w", err) + } + privKey = ecKey + publicKey = &ecKey.PublicKey + case KeyTypeED25519: + edPublicKey, edKey, err := ed25519.GenerateKey(rand.Reader) + if err != nil { + return nil, nil, fmt.Errorf("failed to generate ED25519 private key: %w", err) + } + privKey = edKey + publicKey = edPublicKey + default: + return nil, nil, fmt.Errorf("%w: %s", ErrInvalidKeyType, keyType) } // Initialize certificate template @@ -84,9 +146,7 @@ func generateCertificate( applyOptions(template, opts, certType) // Determine signer (self-signed if signerCert is nil) - var issuerCert *x509.Certificate - var issuerKey *rsa.PrivateKey if signerCert == nil { // Self-signed issuerCert = template @@ -98,7 +158,7 @@ func generateCertificate( } // Create the certificate - certDER, err := x509.CreateCertificate(rand.Reader, template, issuerCert, &privKey.PublicKey, issuerKey) + certDER, err := x509.CreateCertificate(rand.Reader, template, issuerCert, publicKey, issuerKey) if err != nil { return nil, nil, fmt.Errorf("failed to create certificate: %w", err) } @@ -109,17 +169,114 @@ func generateCertificate( Bytes: certDER, }) - // Encode private key to PEM - keyPEM := pem.EncodeToMemory(&pem.Block{ - Type: "RSA PRIVATE KEY", - Bytes: x509.MarshalPKCS1PrivateKey(privKey), - }) + // Encode private key to PEM based on key type + var keyPEM []byte + + switch privKeyType := privKey.(type) { + case *rsa.PrivateKey: + keyPEM = pem.EncodeToMemory(&pem.Block{ + Type: "RSA PRIVATE KEY", + Bytes: x509.MarshalPKCS1PrivateKey(privKeyType), + }) + case *ecdsa.PrivateKey: + keyBytes, err := x509.MarshalECPrivateKey(privKeyType) + if err != nil { + return nil, nil, fmt.Errorf("failed to marshal ECDSA private key: %w", err) + } + keyPEM = pem.EncodeToMemory(&pem.Block{ + Type: "EC PRIVATE KEY", + Bytes: keyBytes, + }) + case ed25519.PrivateKey: + keyBytes, err := x509.MarshalPKCS8PrivateKey(privKeyType) + if err != nil { + return nil, nil, fmt.Errorf("failed to marshal ED25519 private key: %w", err) + } + keyPEM = pem.EncodeToMemory(&pem.Block{ + Type: "PRIVATE KEY", + Bytes: keyBytes, + }) + default: + return nil, nil, fmt.Errorf("%w: %T", ErrUnsupportedPrivateKeyType, privKey) + } return certPEM, keyPEM, nil } +// parsePrivateKeyFromPEM parses a private key from PEM-encoded bytes. +// Tries PKCS8 first (handles RSA, ECDSA, and ED25519), then falls back to PKCS1 (RSA) and EC SEC1 (ECDSA). +func parsePrivateKeyFromPEM(keyBytes []byte) (any, error) { + // Try PKCS8 first (handles RSA, ECDSA, and ED25519) + if privKey, err := x509.ParsePKCS8PrivateKey(keyBytes); err == nil { + return privKey, nil + } + + // Fall back to PKCS1 (RSA only) + if rsaKey, err := x509.ParsePKCS1PrivateKey(keyBytes); err == nil { + return rsaKey, nil + } + + // Fall back to EC SEC1 format + if ecKey, err := x509.ParseECPrivateKey(keyBytes); err == nil { + return ecKey, nil + } + + return nil, ErrFailedParsePrivateKey +} + +// ExtractPublicKeyFromCert extracts the public key from a certificate in PEM format. +// Returns the public key in PKIX format (suitable for ECDSA and ED25519). +func ExtractPublicKeyFromCert(certPEM []byte) ([]byte, error) { + block, _ := pem.Decode(certPEM) + if block == nil { + return nil, ErrFailedDecodeCertPEM + } + + cert, err := x509.ParseCertificate(block.Bytes) + if err != nil { + return nil, fmt.Errorf("failed to parse certificate: %w", err) + } + + publicKeyBytes, err := x509.MarshalPKIXPublicKey(cert.PublicKey) + if err != nil { + return nil, fmt.Errorf("failed to marshal public key: %w", err) + } + + return pem.EncodeToMemory(&pem.Block{ + Type: "PUBLIC KEY", + Bytes: publicKeyBytes, + }), nil +} + +// ExtractRSAPublicKeyPKCS1 extracts the RSA public key from a private key in PEM format. +// Returns the public key in PKCS1 format (RSA-specific). +func ExtractRSAPublicKeyPKCS1(keyPEM []byte) ([]byte, error) { + block, _ := pem.Decode(keyPEM) + if block == nil { + return nil, ErrFailedDecodeKeyPEM + } + + privKey, err := parsePrivateKeyFromPEM(block.Bytes) + if err != nil { + return nil, fmt.Errorf("failed to parse RSA private key: %w", err) + } + + rsaKey, ok := privKey.(*rsa.PrivateKey) + if !ok { + return nil, fmt.Errorf("%w, got %T", ErrPrivateKeyNotRSA, privKey) + } + + publicKeyBytes := x509.MarshalPKCS1PublicKey(&rsaKey.PublicKey) + + return pem.EncodeToMemory(&pem.Block{ + Type: "RSA PUBLIC KEY", + Bytes: publicKeyBytes, + }), nil +} + // parseCA parses CA certificate and private key from PEM format. -func parseCA(caCertPEM, caKeyPEM []byte) (*x509.Certificate, *rsa.PrivateKey, error) { +// Returns the certificate and the private key (which can be *rsa.PrivateKey, *ecdsa.PrivateKey, or ed25519.PrivateKey). +func parseCA(caCertPEM, caKeyPEM []byte) (*x509.Certificate, any, error) { // Parse CA certificate caCertBlock, _ := pem.Decode(caCertPEM) if caCertBlock == nil { @@ -137,7 +294,7 @@ func parseCA(caCertPEM, caKeyPEM []byte) (*x509.Certificate, *rsa.PrivateKey, er return nil, nil, ErrDecodeCAPEM } - caPrivKey, err := x509.ParsePKCS1PrivateKey(caKeyBlock.Bytes) + caPrivKey, err := parsePrivateKeyFromPEM(caKeyBlock.Bytes) if err != nil { return nil, nil, fmt.Errorf("failed to parse CA private key: %w", err) } @@ -165,8 +322,7 @@ func initializeTemplate(certType string) (*x509.Certificate, error) { StreetAddress: []string{""}, PostalCode: []string{""}, } - template.NotBefore = time.Now() - template.NotAfter = time.Now().AddDate(10, 0, 0) // 10 years for CA + // NotBefore and NotAfter are set via CertificateOptions in test logic case certTypeServer: template.ExtKeyUsage = []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth} template.KeyUsage = x509.KeyUsageDigitalSignature @@ -178,8 +334,7 @@ func initializeTemplate(certType string) (*x509.Certificate, error) { StreetAddress: []string{""}, PostalCode: []string{""}, } - template.NotBefore = time.Now() - template.NotAfter = time.Now().AddDate(1, 0, 0) // 1 year for server + // NotBefore and NotAfter are set via CertificateOptions in test logic template.IPAddresses = []net.IP{net.ParseIP("127.0.0.1")} // Default IP for Server case certTypeClient: template.ExtKeyUsage = []x509.ExtKeyUsage{x509.ExtKeyUsageClientAuth} @@ -192,8 +347,7 @@ func initializeTemplate(certType string) (*x509.Certificate, error) { StreetAddress: []string{""}, PostalCode: []string{""}, } - template.NotBefore = time.Now() - template.NotAfter = time.Now().AddDate(1, 0, 0) // 1 year for client + // NotBefore and NotAfter are set via CertificateOptions in test logic default: return nil, fmt.Errorf("%w: %s", ErrInvalidCertificateType, certType) } @@ -208,14 +362,18 @@ func applyOptions(template *x509.Certificate, opts *CertificateOptions, certType opts = &CertificateOptions{} } - // Apply NotBefore if provided in options + // Apply NotBefore - default to time.Now() if not provided if !opts.NotBefore.IsZero() { template.NotBefore = opts.NotBefore + } else { + template.NotBefore = time.Now() } - // Apply NotAfter if provided in options + // Apply NotAfter - default to 1 year if not provided, matching gen_certs.sh if !opts.NotAfter.IsZero() { template.NotAfter = opts.NotAfter + } else { + template.NotAfter = time.Now().AddDate(1, 0, 0) } // Apply SAN (Subject Alternative Name) - handle IPAddresses @@ -246,12 +404,19 @@ func applyOptions(template *x509.Certificate, opts *CertificateOptions, certType template.EmailAddresses = opts.EmailAddresses } - // Apply CommonName - explicitly set to empty string if not provided to ensure it's empty + // Apply CommonName - if provided, override the default; otherwise keep default from initializeTemplate if opts.CommonName != "" { template.Subject.CommonName = opts.CommonName - } else { + } else if opts != nil && opts.CommonName == "" && certType == certTypeClient { + // Special case: For client certs, if opts is provided and CommonName is explicitly set to empty, + // use empty CN (for noidentity-style certs) template.Subject.CommonName = "" } + + // Apply OrganizationalUnit - if provided, set it + if opts.OrganizationalUnit != "" { + template.Subject.OrganizationalUnit = []string{opts.OrganizationalUnit} + } } // GenerateCACert generates a CA certificate and private key. diff --git a/pkg/test/tls/tls_test.go b/pkg/test/tls/tls_test.go index e5bceba7..613bcdcc 100644 --- a/pkg/test/tls/tls_test.go +++ b/pkg/test/tls/tls_test.go @@ -4,6 +4,7 @@ import ( "crypto/x509" "encoding/pem" "net" + "os" "path" "testing" "time" @@ -434,5 +435,270 @@ func TestErrorPaths(t *testing.T) { err := tls.GenerateClientSelfSignedCertToFile(certPath, keyPath, nil) So(err, ShouldNotBeNil) }) + + Convey("Test generateCertificate with invalid key type", func() { + // This tests the default case in generateCertificate switch + invalidKeyType := tls.KeyType("INVALID") + opts := &tls.CertificateOptions{ + KeyType: invalidKeyType, + } + _, _, err := tls.GenerateCACert(opts) + So(err, ShouldNotBeNil) + }) + + Convey("Test generateCertificate with ECDSA key type", func() { + // Test that ECDSA key generation works correctly + caCertPEM, caKeyPEM, err := tls.GenerateCACert() + So(err, ShouldBeNil) + + opts := &tls.CertificateOptions{ + Hostname: "localhost", + KeyType: tls.KeyTypeECDSA, + } + certPEM, keyPEM, err := tls.GenerateServerCert(caCertPEM, caKeyPEM, opts) + So(err, ShouldBeNil) + So(certPEM, ShouldNotBeNil) + So(keyPEM, ShouldNotBeNil) + + // Verify ECDSA key was generated + keyBlock, _ := pem.Decode(keyPEM) + So(keyBlock, ShouldNotBeNil) + So(keyBlock.Type, ShouldEqual, "EC PRIVATE KEY") + }) + + Convey("Test generateCertificate with ED25519 key type", func() { + caCertPEM, caKeyPEM, err := tls.GenerateCACert() + So(err, ShouldBeNil) + + opts := &tls.CertificateOptions{ + Hostname: "localhost", + KeyType: tls.KeyTypeED25519, + } + certPEM, keyPEM, err := tls.GenerateServerCert(caCertPEM, caKeyPEM, opts) + So(err, ShouldBeNil) + So(certPEM, ShouldNotBeNil) + So(keyPEM, ShouldNotBeNil) + + // Verify ED25519 key was generated + keyBlock, _ := pem.Decode(keyPEM) + So(keyBlock, ShouldNotBeNil) + So(keyBlock.Type, ShouldEqual, "PRIVATE KEY") + }) + + Convey("Test parsePrivateKeyFromPEM with PKCS8 format", func() { + // Generate a certificate with ED25519 (uses PKCS8) + opts := &tls.CertificateOptions{ + KeyType: tls.KeyTypeED25519, + } + _, keyPEM, err := tls.GenerateCACert(opts) + So(err, ShouldBeNil) + + // Parse it back - should work with PKCS8 + keyBlock, _ := pem.Decode(keyPEM) + So(keyBlock, ShouldNotBeNil) + + // This tests the PKCS8 path in parsePrivateKeyFromPEM + _, err = x509.ParsePKCS8PrivateKey(keyBlock.Bytes) + So(err, ShouldBeNil) + }) + + Convey("Test parsePrivateKeyFromPEM with EC SEC1 format", func() { + // Generate a certificate with ECDSA (uses SEC1) + opts := &tls.CertificateOptions{ + KeyType: tls.KeyTypeECDSA, + } + _, keyPEM, err := tls.GenerateCACert(opts) + So(err, ShouldBeNil) + + // Parse it back - should work + keyBlock, _ := pem.Decode(keyPEM) + So(keyBlock, ShouldNotBeNil) + + // This tests the EC SEC1 path in parsePrivateKeyFromPEM + _, err = x509.ParseECPrivateKey(keyBlock.Bytes) + So(err, ShouldBeNil) + }) + }) +} + +func TestExtractPublicKeyFromCert(t *testing.T) { + Convey("Test ExtractPublicKeyFromCert", t, func() { + caCertPEM, _, err := tls.GenerateCACert() + So(err, ShouldBeNil) + + Convey("Extract public key from valid certificate", func() { + publicKeyPEM, err := tls.ExtractPublicKeyFromCert(caCertPEM) + So(err, ShouldBeNil) + So(publicKeyPEM, ShouldNotBeNil) + + // Verify it's valid PEM + block, _ := pem.Decode(publicKeyPEM) + So(block, ShouldNotBeNil) + So(block.Type, ShouldEqual, "PUBLIC KEY") + }) + + Convey("Extract public key from invalid PEM", func() { + invalidPEM := []byte("not a valid PEM") + _, err := tls.ExtractPublicKeyFromCert(invalidPEM) + So(err, ShouldEqual, tls.ErrFailedDecodeCertPEM) + }) + + Convey("Extract public key from server certificate", func() { + caCertPEM, caKeyPEM, err := tls.GenerateCACert() + So(err, ShouldBeNil) + + opts := &tls.CertificateOptions{ + Hostname: "localhost", + } + serverCertPEM, _, err := tls.GenerateServerCert(caCertPEM, caKeyPEM, opts) + So(err, ShouldBeNil) + + publicKeyPEM, err := tls.ExtractPublicKeyFromCert(serverCertPEM) + So(err, ShouldBeNil) + So(publicKeyPEM, ShouldNotBeNil) + }) + + Convey("Extract public key from ECDSA certificate", func() { + caOpts := &tls.CertificateOptions{ + KeyType: tls.KeyTypeECDSA, + } + caCertPEM, caKeyPEM, err := tls.GenerateCACert(caOpts) + So(err, ShouldBeNil) + + opts := &tls.CertificateOptions{ + Hostname: "localhost", + KeyType: tls.KeyTypeECDSA, + } + serverCertPEM, _, err := tls.GenerateServerCert(caCertPEM, caKeyPEM, opts) + So(err, ShouldBeNil) + + publicKeyPEM, err := tls.ExtractPublicKeyFromCert(serverCertPEM) + So(err, ShouldBeNil) + So(publicKeyPEM, ShouldNotBeNil) + }) + + Convey("Extract public key from ED25519 certificate", func() { + caOpts := &tls.CertificateOptions{ + KeyType: tls.KeyTypeED25519, + } + caCertPEM, caKeyPEM, err := tls.GenerateCACert(caOpts) + So(err, ShouldBeNil) + + opts := &tls.CertificateOptions{ + Hostname: "localhost", + KeyType: tls.KeyTypeED25519, + } + serverCertPEM, _, err := tls.GenerateServerCert(caCertPEM, caKeyPEM, opts) + So(err, ShouldBeNil) + + publicKeyPEM, err := tls.ExtractPublicKeyFromCert(serverCertPEM) + So(err, ShouldBeNil) + So(publicKeyPEM, ShouldNotBeNil) + }) + + Convey("Extract public key from certificate with invalid certificate data", func() { + // Create a PEM block with invalid certificate data + invalidCertPEM := pem.EncodeToMemory(&pem.Block{ + Type: "CERTIFICATE", + Bytes: []byte("invalid certificate data"), + }) + + _, err := tls.ExtractPublicKeyFromCert(invalidCertPEM) + So(err, ShouldNotBeNil) + So(err.Error(), ShouldContainSubstring, "failed to parse certificate") + }) + }) +} + +func TestExtractRSAPublicKeyPKCS1(t *testing.T) { + Convey("Test ExtractRSAPublicKeyPKCS1", t, func() { + _, keyPEM, err := tls.GenerateCACert() + So(err, ShouldBeNil) + + Convey("Extract RSA public key in PKCS1 format", func() { + publicKeyPEM, err := tls.ExtractRSAPublicKeyPKCS1(keyPEM) + So(err, ShouldBeNil) + So(publicKeyPEM, ShouldNotBeNil) + + // Verify it's valid PEM + block, _ := pem.Decode(publicKeyPEM) + So(block, ShouldNotBeNil) + So(block.Type, ShouldEqual, "RSA PUBLIC KEY") + }) + + Convey("Extract RSA public key from invalid PEM", func() { + invalidPEM := []byte("not a valid PEM") + _, err := tls.ExtractRSAPublicKeyPKCS1(invalidPEM) + So(err, ShouldEqual, tls.ErrFailedDecodeKeyPEM) + }) + + Convey("Extract RSA public key from non-RSA key", func() { + opts := &tls.CertificateOptions{ + KeyType: tls.KeyTypeECDSA, + } + _, ecdsaKeyPEM, err := tls.GenerateCACert(opts) + So(err, ShouldBeNil) + + _, err = tls.ExtractRSAPublicKeyPKCS1(ecdsaKeyPEM) + So(err, ShouldNotBeNil) + So(err.Error(), ShouldContainSubstring, "private key is not RSA") + }) + }) +} + +func TestWriteCertificateChainToFile(t *testing.T) { + Convey("Test WriteCertificateChainToFile", t, func() { + Convey("Write certificate chain with multiple certificates", func() { + tempDir := t.TempDir() + chainPath := path.Join(tempDir, "chain.crt") + + // Generate root CA + rootCACert, rootCAKey, err := tls.GenerateCACert() + So(err, ShouldBeNil) + + // Generate intermediate CA + intermediateCACert, _, err := tls.GenerateIntermediateCACert(rootCACert, rootCAKey) + So(err, ShouldBeNil) + + // Generate leaf certificate + leafCert, _, err := tls.GenerateClientCert(rootCACert, rootCAKey, nil) + So(err, ShouldBeNil) + + // Write chain (leaf first, then intermediate) + err = tls.WriteCertificateChainToFile(chainPath, leafCert, intermediateCACert, rootCACert) + So(err, ShouldBeNil) + + // Verify file was created + chainData, err := os.ReadFile(chainPath) + So(err, ShouldBeNil) + So(len(chainData), ShouldBeGreaterThan, 0) + + // Verify it contains all certificates + So(string(chainData), ShouldContainSubstring, "BEGIN CERTIFICATE") + }) + + Convey("Write certificate chain with no certificates", func() { + tempDir := t.TempDir() + chainPath := path.Join(tempDir, "chain.crt") + + err := tls.WriteCertificateChainToFile(chainPath) + So(err, ShouldEqual, tls.ErrNoCertificatesProvided) + }) + + Convey("Write certificate chain with single certificate", func() { + tempDir := t.TempDir() + chainPath := path.Join(tempDir, "chain.crt") + + cert, _, err := tls.GenerateCACert() + So(err, ShouldBeNil) + + err = tls.WriteCertificateChainToFile(chainPath, cert) + So(err, ShouldBeNil) + + // Verify file was created + chainData, err := os.ReadFile(chainPath) + So(err, ShouldBeNil) + So(len(chainData), ShouldBeGreaterThan, 0) + }) }) }