refactor: enhance TLS cert generation and refactor HTTP client architecture (#3638)

- Refactored HTTP client from global cache to struct-based approach (global state was shared between tests, including what certificates to use)
- Enhanced pkg/test/tls to support ECDSA and ED25519 key types
- Replaced static certificate files with dynamic generation in golang tests
- Fixed test cleanup issues and improved resource management

This eliminates dependency on external cert generation scripts and
improves test maintainability.

Signed-off-by: Andrei Aaron <andreifdaaron@gmail.com>
This commit is contained in:
Andrei Aaron
2025-12-13 09:47:32 +02:00
committed by GitHub
parent 1447bb24b4
commit cf8b0bdbf9
22 changed files with 1590 additions and 554 deletions
+196 -31
View File
@@ -1,6 +1,9 @@
package tls
import (
"crypto/ecdsa"
"crypto/ed25519"
"crypto/elliptic"
"crypto/rand"
"crypto/rsa"
"crypto/x509"
@@ -14,11 +17,25 @@ import (
)
var (
ErrDecodeCAPEM = errors.New("failed to decode CA certificate PEM")
ErrInvalidCertificateType = errors.New("invalid certificate type")
ErrCertificateOptionsRequired = errors.New("CertificateOptions is required")
ErrHostnameRequired = errors.New("Hostname is required in CertificateOptions")
ErrNoCertificatesProvided = errors.New("at least one certificate is required")
ErrDecodeCAPEM = errors.New("failed to decode CA certificate PEM")
ErrInvalidCertificateType = errors.New("invalid certificate type")
ErrHostnameRequired = errors.New("Hostname is required in CertificateOptions")
ErrNoCertificatesProvided = errors.New("at least one certificate is required")
ErrInvalidKeyType = errors.New("invalid key type")
ErrUnsupportedPrivateKeyType = errors.New("unsupported private key type")
ErrFailedParsePrivateKey = errors.New("failed to parse private key: unsupported key format")
ErrFailedDecodeCertPEM = errors.New("failed to decode certificate PEM")
ErrFailedDecodeKeyPEM = errors.New("failed to decode private key PEM")
ErrPrivateKeyNotRSA = errors.New("private key is not RSA")
)
// KeyType represents the type of cryptographic key to use for certificate generation.
type KeyType string
const (
KeyTypeRSA KeyType = "RSA"
KeyTypeECDSA KeyType = "ECDSA"
KeyTypeED25519 KeyType = "ED25519"
)
const (
@@ -55,23 +72,68 @@ type CertificateOptions struct {
// based on whether it's a valid IP address or a DNS name.
Hostname string
// CommonName is the CommonName (CN) for client certificates.
// CommonName is the CommonName (CN) for certificates.
// For client certificates, this is optional - if not provided, the certificate will not have a CN.
CommonName string
// OrganizationalUnit is the OrganizationalUnit (OU) for certificates.
// If not provided, the certificate will not have an OU.
OrganizationalUnit string
// KeyType specifies the type of cryptographic key to use.
// Valid values: "RSA" (default), "ECDSA", "ED25519".
// If empty or "RSA", RSA keys will be generated.
KeyType KeyType
}
// generateCertificate is a helper function that generates a certificate and private key.
// If signerCert and signerKey are nil, the certificate will be self-signed.
// signerKey can be *rsa.PrivateKey, *ecdsa.PrivateKey, or ed25519.PrivateKey.
func generateCertificate(
certType string,
opts *CertificateOptions,
signerCert *x509.Certificate,
signerKey *rsa.PrivateKey,
signerKey any, // Can be *rsa.PrivateKey, *ecdsa.PrivateKey, or ed25519.PrivateKey
) ([]byte, []byte, error) {
// Generate private key
privKey, err := rsa.GenerateKey(rand.Reader, 2048)
if err != nil {
return nil, nil, fmt.Errorf("failed to generate private key: %w", err)
var (
issuerCert *x509.Certificate
issuerKey any
privKey any
publicKey any
err error
)
// Determine key type
keyType := KeyTypeRSA
if opts != nil && opts.KeyType != "" {
keyType = opts.KeyType
}
// Generate private key based on key type
switch keyType {
case KeyTypeRSA:
rsaKey, err := rsa.GenerateKey(rand.Reader, 2048)
if err != nil {
return nil, nil, fmt.Errorf("failed to generate RSA private key: %w", err)
}
privKey = rsaKey
publicKey = &rsaKey.PublicKey
case KeyTypeECDSA:
ecKey, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
if err != nil {
return nil, nil, fmt.Errorf("failed to generate ECDSA private key: %w", err)
}
privKey = ecKey
publicKey = &ecKey.PublicKey
case KeyTypeED25519:
edPublicKey, edKey, err := ed25519.GenerateKey(rand.Reader)
if err != nil {
return nil, nil, fmt.Errorf("failed to generate ED25519 private key: %w", err)
}
privKey = edKey
publicKey = edPublicKey
default:
return nil, nil, fmt.Errorf("%w: %s", ErrInvalidKeyType, keyType)
}
// Initialize certificate template
@@ -84,9 +146,7 @@ func generateCertificate(
applyOptions(template, opts, certType)
// Determine signer (self-signed if signerCert is nil)
var issuerCert *x509.Certificate
var issuerKey *rsa.PrivateKey
if signerCert == nil {
// Self-signed
issuerCert = template
@@ -98,7 +158,7 @@ func generateCertificate(
}
// Create the certificate
certDER, err := x509.CreateCertificate(rand.Reader, template, issuerCert, &privKey.PublicKey, issuerKey)
certDER, err := x509.CreateCertificate(rand.Reader, template, issuerCert, publicKey, issuerKey)
if err != nil {
return nil, nil, fmt.Errorf("failed to create certificate: %w", err)
}
@@ -109,17 +169,114 @@ func generateCertificate(
Bytes: certDER,
})
// Encode private key to PEM
keyPEM := pem.EncodeToMemory(&pem.Block{
Type: "RSA PRIVATE KEY",
Bytes: x509.MarshalPKCS1PrivateKey(privKey),
})
// Encode private key to PEM based on key type
var keyPEM []byte
switch privKeyType := privKey.(type) {
case *rsa.PrivateKey:
keyPEM = pem.EncodeToMemory(&pem.Block{
Type: "RSA PRIVATE KEY",
Bytes: x509.MarshalPKCS1PrivateKey(privKeyType),
})
case *ecdsa.PrivateKey:
keyBytes, err := x509.MarshalECPrivateKey(privKeyType)
if err != nil {
return nil, nil, fmt.Errorf("failed to marshal ECDSA private key: %w", err)
}
keyPEM = pem.EncodeToMemory(&pem.Block{
Type: "EC PRIVATE KEY",
Bytes: keyBytes,
})
case ed25519.PrivateKey:
keyBytes, err := x509.MarshalPKCS8PrivateKey(privKeyType)
if err != nil {
return nil, nil, fmt.Errorf("failed to marshal ED25519 private key: %w", err)
}
keyPEM = pem.EncodeToMemory(&pem.Block{
Type: "PRIVATE KEY",
Bytes: keyBytes,
})
default:
return nil, nil, fmt.Errorf("%w: %T", ErrUnsupportedPrivateKeyType, privKey)
}
return certPEM, keyPEM, nil
}
// parsePrivateKeyFromPEM parses a private key from PEM-encoded bytes.
// Tries PKCS8 first (handles RSA, ECDSA, and ED25519), then falls back to PKCS1 (RSA) and EC SEC1 (ECDSA).
func parsePrivateKeyFromPEM(keyBytes []byte) (any, error) {
// Try PKCS8 first (handles RSA, ECDSA, and ED25519)
if privKey, err := x509.ParsePKCS8PrivateKey(keyBytes); err == nil {
return privKey, nil
}
// Fall back to PKCS1 (RSA only)
if rsaKey, err := x509.ParsePKCS1PrivateKey(keyBytes); err == nil {
return rsaKey, nil
}
// Fall back to EC SEC1 format
if ecKey, err := x509.ParseECPrivateKey(keyBytes); err == nil {
return ecKey, nil
}
return nil, ErrFailedParsePrivateKey
}
// ExtractPublicKeyFromCert extracts the public key from a certificate in PEM format.
// Returns the public key in PKIX format (suitable for ECDSA and ED25519).
func ExtractPublicKeyFromCert(certPEM []byte) ([]byte, error) {
block, _ := pem.Decode(certPEM)
if block == nil {
return nil, ErrFailedDecodeCertPEM
}
cert, err := x509.ParseCertificate(block.Bytes)
if err != nil {
return nil, fmt.Errorf("failed to parse certificate: %w", err)
}
publicKeyBytes, err := x509.MarshalPKIXPublicKey(cert.PublicKey)
if err != nil {
return nil, fmt.Errorf("failed to marshal public key: %w", err)
}
return pem.EncodeToMemory(&pem.Block{
Type: "PUBLIC KEY",
Bytes: publicKeyBytes,
}), nil
}
// ExtractRSAPublicKeyPKCS1 extracts the RSA public key from a private key in PEM format.
// Returns the public key in PKCS1 format (RSA-specific).
func ExtractRSAPublicKeyPKCS1(keyPEM []byte) ([]byte, error) {
block, _ := pem.Decode(keyPEM)
if block == nil {
return nil, ErrFailedDecodeKeyPEM
}
privKey, err := parsePrivateKeyFromPEM(block.Bytes)
if err != nil {
return nil, fmt.Errorf("failed to parse RSA private key: %w", err)
}
rsaKey, ok := privKey.(*rsa.PrivateKey)
if !ok {
return nil, fmt.Errorf("%w, got %T", ErrPrivateKeyNotRSA, privKey)
}
publicKeyBytes := x509.MarshalPKCS1PublicKey(&rsaKey.PublicKey)
return pem.EncodeToMemory(&pem.Block{
Type: "RSA PUBLIC KEY",
Bytes: publicKeyBytes,
}), nil
}
// parseCA parses CA certificate and private key from PEM format.
func parseCA(caCertPEM, caKeyPEM []byte) (*x509.Certificate, *rsa.PrivateKey, error) {
// Returns the certificate and the private key (which can be *rsa.PrivateKey, *ecdsa.PrivateKey, or ed25519.PrivateKey).
func parseCA(caCertPEM, caKeyPEM []byte) (*x509.Certificate, any, error) {
// Parse CA certificate
caCertBlock, _ := pem.Decode(caCertPEM)
if caCertBlock == nil {
@@ -137,7 +294,7 @@ func parseCA(caCertPEM, caKeyPEM []byte) (*x509.Certificate, *rsa.PrivateKey, er
return nil, nil, ErrDecodeCAPEM
}
caPrivKey, err := x509.ParsePKCS1PrivateKey(caKeyBlock.Bytes)
caPrivKey, err := parsePrivateKeyFromPEM(caKeyBlock.Bytes)
if err != nil {
return nil, nil, fmt.Errorf("failed to parse CA private key: %w", err)
}
@@ -165,8 +322,7 @@ func initializeTemplate(certType string) (*x509.Certificate, error) {
StreetAddress: []string{""},
PostalCode: []string{""},
}
template.NotBefore = time.Now()
template.NotAfter = time.Now().AddDate(10, 0, 0) // 10 years for CA
// NotBefore and NotAfter are set via CertificateOptions in test logic
case certTypeServer:
template.ExtKeyUsage = []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth}
template.KeyUsage = x509.KeyUsageDigitalSignature
@@ -178,8 +334,7 @@ func initializeTemplate(certType string) (*x509.Certificate, error) {
StreetAddress: []string{""},
PostalCode: []string{""},
}
template.NotBefore = time.Now()
template.NotAfter = time.Now().AddDate(1, 0, 0) // 1 year for server
// NotBefore and NotAfter are set via CertificateOptions in test logic
template.IPAddresses = []net.IP{net.ParseIP("127.0.0.1")} // Default IP for Server
case certTypeClient:
template.ExtKeyUsage = []x509.ExtKeyUsage{x509.ExtKeyUsageClientAuth}
@@ -192,8 +347,7 @@ func initializeTemplate(certType string) (*x509.Certificate, error) {
StreetAddress: []string{""},
PostalCode: []string{""},
}
template.NotBefore = time.Now()
template.NotAfter = time.Now().AddDate(1, 0, 0) // 1 year for client
// NotBefore and NotAfter are set via CertificateOptions in test logic
default:
return nil, fmt.Errorf("%w: %s", ErrInvalidCertificateType, certType)
}
@@ -208,14 +362,18 @@ func applyOptions(template *x509.Certificate, opts *CertificateOptions, certType
opts = &CertificateOptions{}
}
// Apply NotBefore if provided in options
// Apply NotBefore - default to time.Now() if not provided
if !opts.NotBefore.IsZero() {
template.NotBefore = opts.NotBefore
} else {
template.NotBefore = time.Now()
}
// Apply NotAfter if provided in options
// Apply NotAfter - default to 1 year if not provided, matching gen_certs.sh
if !opts.NotAfter.IsZero() {
template.NotAfter = opts.NotAfter
} else {
template.NotAfter = time.Now().AddDate(1, 0, 0)
}
// Apply SAN (Subject Alternative Name) - handle IPAddresses
@@ -246,12 +404,19 @@ func applyOptions(template *x509.Certificate, opts *CertificateOptions, certType
template.EmailAddresses = opts.EmailAddresses
}
// Apply CommonName - explicitly set to empty string if not provided to ensure it's empty
// Apply CommonName - if provided, override the default; otherwise keep default from initializeTemplate
if opts.CommonName != "" {
template.Subject.CommonName = opts.CommonName
} else {
} else if opts != nil && opts.CommonName == "" && certType == certTypeClient {
// Special case: For client certs, if opts is provided and CommonName is explicitly set to empty,
// use empty CN (for noidentity-style certs)
template.Subject.CommonName = ""
}
// Apply OrganizationalUnit - if provided, set it
if opts.OrganizationalUnit != "" {
template.Subject.OrganizationalUnit = []string{opts.OrganizationalUnit}
}
}
// GenerateCACert generates a CA certificate and private key.
+266
View File
@@ -4,6 +4,7 @@ import (
"crypto/x509"
"encoding/pem"
"net"
"os"
"path"
"testing"
"time"
@@ -434,5 +435,270 @@ func TestErrorPaths(t *testing.T) {
err := tls.GenerateClientSelfSignedCertToFile(certPath, keyPath, nil)
So(err, ShouldNotBeNil)
})
Convey("Test generateCertificate with invalid key type", func() {
// This tests the default case in generateCertificate switch
invalidKeyType := tls.KeyType("INVALID")
opts := &tls.CertificateOptions{
KeyType: invalidKeyType,
}
_, _, err := tls.GenerateCACert(opts)
So(err, ShouldNotBeNil)
})
Convey("Test generateCertificate with ECDSA key type", func() {
// Test that ECDSA key generation works correctly
caCertPEM, caKeyPEM, err := tls.GenerateCACert()
So(err, ShouldBeNil)
opts := &tls.CertificateOptions{
Hostname: "localhost",
KeyType: tls.KeyTypeECDSA,
}
certPEM, keyPEM, err := tls.GenerateServerCert(caCertPEM, caKeyPEM, opts)
So(err, ShouldBeNil)
So(certPEM, ShouldNotBeNil)
So(keyPEM, ShouldNotBeNil)
// Verify ECDSA key was generated
keyBlock, _ := pem.Decode(keyPEM)
So(keyBlock, ShouldNotBeNil)
So(keyBlock.Type, ShouldEqual, "EC PRIVATE KEY")
})
Convey("Test generateCertificate with ED25519 key type", func() {
caCertPEM, caKeyPEM, err := tls.GenerateCACert()
So(err, ShouldBeNil)
opts := &tls.CertificateOptions{
Hostname: "localhost",
KeyType: tls.KeyTypeED25519,
}
certPEM, keyPEM, err := tls.GenerateServerCert(caCertPEM, caKeyPEM, opts)
So(err, ShouldBeNil)
So(certPEM, ShouldNotBeNil)
So(keyPEM, ShouldNotBeNil)
// Verify ED25519 key was generated
keyBlock, _ := pem.Decode(keyPEM)
So(keyBlock, ShouldNotBeNil)
So(keyBlock.Type, ShouldEqual, "PRIVATE KEY")
})
Convey("Test parsePrivateKeyFromPEM with PKCS8 format", func() {
// Generate a certificate with ED25519 (uses PKCS8)
opts := &tls.CertificateOptions{
KeyType: tls.KeyTypeED25519,
}
_, keyPEM, err := tls.GenerateCACert(opts)
So(err, ShouldBeNil)
// Parse it back - should work with PKCS8
keyBlock, _ := pem.Decode(keyPEM)
So(keyBlock, ShouldNotBeNil)
// This tests the PKCS8 path in parsePrivateKeyFromPEM
_, err = x509.ParsePKCS8PrivateKey(keyBlock.Bytes)
So(err, ShouldBeNil)
})
Convey("Test parsePrivateKeyFromPEM with EC SEC1 format", func() {
// Generate a certificate with ECDSA (uses SEC1)
opts := &tls.CertificateOptions{
KeyType: tls.KeyTypeECDSA,
}
_, keyPEM, err := tls.GenerateCACert(opts)
So(err, ShouldBeNil)
// Parse it back - should work
keyBlock, _ := pem.Decode(keyPEM)
So(keyBlock, ShouldNotBeNil)
// This tests the EC SEC1 path in parsePrivateKeyFromPEM
_, err = x509.ParseECPrivateKey(keyBlock.Bytes)
So(err, ShouldBeNil)
})
})
}
func TestExtractPublicKeyFromCert(t *testing.T) {
Convey("Test ExtractPublicKeyFromCert", t, func() {
caCertPEM, _, err := tls.GenerateCACert()
So(err, ShouldBeNil)
Convey("Extract public key from valid certificate", func() {
publicKeyPEM, err := tls.ExtractPublicKeyFromCert(caCertPEM)
So(err, ShouldBeNil)
So(publicKeyPEM, ShouldNotBeNil)
// Verify it's valid PEM
block, _ := pem.Decode(publicKeyPEM)
So(block, ShouldNotBeNil)
So(block.Type, ShouldEqual, "PUBLIC KEY")
})
Convey("Extract public key from invalid PEM", func() {
invalidPEM := []byte("not a valid PEM")
_, err := tls.ExtractPublicKeyFromCert(invalidPEM)
So(err, ShouldEqual, tls.ErrFailedDecodeCertPEM)
})
Convey("Extract public key from server certificate", func() {
caCertPEM, caKeyPEM, err := tls.GenerateCACert()
So(err, ShouldBeNil)
opts := &tls.CertificateOptions{
Hostname: "localhost",
}
serverCertPEM, _, err := tls.GenerateServerCert(caCertPEM, caKeyPEM, opts)
So(err, ShouldBeNil)
publicKeyPEM, err := tls.ExtractPublicKeyFromCert(serverCertPEM)
So(err, ShouldBeNil)
So(publicKeyPEM, ShouldNotBeNil)
})
Convey("Extract public key from ECDSA certificate", func() {
caOpts := &tls.CertificateOptions{
KeyType: tls.KeyTypeECDSA,
}
caCertPEM, caKeyPEM, err := tls.GenerateCACert(caOpts)
So(err, ShouldBeNil)
opts := &tls.CertificateOptions{
Hostname: "localhost",
KeyType: tls.KeyTypeECDSA,
}
serverCertPEM, _, err := tls.GenerateServerCert(caCertPEM, caKeyPEM, opts)
So(err, ShouldBeNil)
publicKeyPEM, err := tls.ExtractPublicKeyFromCert(serverCertPEM)
So(err, ShouldBeNil)
So(publicKeyPEM, ShouldNotBeNil)
})
Convey("Extract public key from ED25519 certificate", func() {
caOpts := &tls.CertificateOptions{
KeyType: tls.KeyTypeED25519,
}
caCertPEM, caKeyPEM, err := tls.GenerateCACert(caOpts)
So(err, ShouldBeNil)
opts := &tls.CertificateOptions{
Hostname: "localhost",
KeyType: tls.KeyTypeED25519,
}
serverCertPEM, _, err := tls.GenerateServerCert(caCertPEM, caKeyPEM, opts)
So(err, ShouldBeNil)
publicKeyPEM, err := tls.ExtractPublicKeyFromCert(serverCertPEM)
So(err, ShouldBeNil)
So(publicKeyPEM, ShouldNotBeNil)
})
Convey("Extract public key from certificate with invalid certificate data", func() {
// Create a PEM block with invalid certificate data
invalidCertPEM := pem.EncodeToMemory(&pem.Block{
Type: "CERTIFICATE",
Bytes: []byte("invalid certificate data"),
})
_, err := tls.ExtractPublicKeyFromCert(invalidCertPEM)
So(err, ShouldNotBeNil)
So(err.Error(), ShouldContainSubstring, "failed to parse certificate")
})
})
}
func TestExtractRSAPublicKeyPKCS1(t *testing.T) {
Convey("Test ExtractRSAPublicKeyPKCS1", t, func() {
_, keyPEM, err := tls.GenerateCACert()
So(err, ShouldBeNil)
Convey("Extract RSA public key in PKCS1 format", func() {
publicKeyPEM, err := tls.ExtractRSAPublicKeyPKCS1(keyPEM)
So(err, ShouldBeNil)
So(publicKeyPEM, ShouldNotBeNil)
// Verify it's valid PEM
block, _ := pem.Decode(publicKeyPEM)
So(block, ShouldNotBeNil)
So(block.Type, ShouldEqual, "RSA PUBLIC KEY")
})
Convey("Extract RSA public key from invalid PEM", func() {
invalidPEM := []byte("not a valid PEM")
_, err := tls.ExtractRSAPublicKeyPKCS1(invalidPEM)
So(err, ShouldEqual, tls.ErrFailedDecodeKeyPEM)
})
Convey("Extract RSA public key from non-RSA key", func() {
opts := &tls.CertificateOptions{
KeyType: tls.KeyTypeECDSA,
}
_, ecdsaKeyPEM, err := tls.GenerateCACert(opts)
So(err, ShouldBeNil)
_, err = tls.ExtractRSAPublicKeyPKCS1(ecdsaKeyPEM)
So(err, ShouldNotBeNil)
So(err.Error(), ShouldContainSubstring, "private key is not RSA")
})
})
}
func TestWriteCertificateChainToFile(t *testing.T) {
Convey("Test WriteCertificateChainToFile", t, func() {
Convey("Write certificate chain with multiple certificates", func() {
tempDir := t.TempDir()
chainPath := path.Join(tempDir, "chain.crt")
// Generate root CA
rootCACert, rootCAKey, err := tls.GenerateCACert()
So(err, ShouldBeNil)
// Generate intermediate CA
intermediateCACert, _, err := tls.GenerateIntermediateCACert(rootCACert, rootCAKey)
So(err, ShouldBeNil)
// Generate leaf certificate
leafCert, _, err := tls.GenerateClientCert(rootCACert, rootCAKey, nil)
So(err, ShouldBeNil)
// Write chain (leaf first, then intermediate)
err = tls.WriteCertificateChainToFile(chainPath, leafCert, intermediateCACert, rootCACert)
So(err, ShouldBeNil)
// Verify file was created
chainData, err := os.ReadFile(chainPath)
So(err, ShouldBeNil)
So(len(chainData), ShouldBeGreaterThan, 0)
// Verify it contains all certificates
So(string(chainData), ShouldContainSubstring, "BEGIN CERTIFICATE")
})
Convey("Write certificate chain with no certificates", func() {
tempDir := t.TempDir()
chainPath := path.Join(tempDir, "chain.crt")
err := tls.WriteCertificateChainToFile(chainPath)
So(err, ShouldEqual, tls.ErrNoCertificatesProvided)
})
Convey("Write certificate chain with single certificate", func() {
tempDir := t.TempDir()
chainPath := path.Join(tempDir, "chain.crt")
cert, _, err := tls.GenerateCACert()
So(err, ShouldBeNil)
err = tls.WriteCertificateChainToFile(chainPath, cert)
So(err, ShouldBeNil)
// Verify file was created
chainData, err := os.ReadFile(chainPath)
So(err, ShouldBeNil)
So(len(chainData), ShouldBeGreaterThan, 0)
})
})
}