mirror of
https://github.com/project-zot/zot.git
synced 2026-06-17 12:58:02 +08:00
9425ca8b7d
Validate callback_ui and default invalid values to `/`. Allow an absolute callback_ui only when its origin is allowlisted via http.auth.openid.callbackAllowOrigins (and externalUrl). Add/adjust unit and controller tests, and update examples/docs for relative vs. allowlisted absolute redirects. Signed-off-by: Andrei Aaron <andreifdaaron@gmail.com>
471 lines
13 KiB
Go
471 lines
13 KiB
Go
//go:build sync && scrub && metrics && search && lint && userprefs && mgmt && imagetrust && ui
|
|
|
|
package api
|
|
|
|
import (
|
|
goerrors "errors"
|
|
"net/url"
|
|
"os"
|
|
"path"
|
|
"sync"
|
|
"testing"
|
|
"time"
|
|
|
|
"zotregistry.dev/zot/v2/pkg/log"
|
|
tlsutils "zotregistry.dev/zot/v2/pkg/test/tls"
|
|
)
|
|
|
|
var (
	// errGetCertificateFailed is reported by the concurrency tests when
	// GetCertificate returns a nil certificate together with a nil error.
	errGetCertificateFailed = goerrors.New("GetCertificate failed")
)
|
|
|
|
// TestReloadCertificateStatFailureKeepsModTimes verifies two things when
// stat'ing the certificate and key files fails:
//   - ReloadCertificate still succeeds;
//   - the modification times recorded by the previous successful reload are
//     left unchanged.
func TestReloadCertificateStatFailureKeepsModTimes(t *testing.T) {
	dir := t.TempDir()
	certFile := path.Join(dir, "cert.pem")
	keyFile := path.Join(dir, "key.pem")

	caCertPEM, caKeyPEM, err := tlsutils.GenerateCACert(&tlsutils.CertificateOptions{
		CommonName: "*",
		NotAfter:   time.Now().AddDate(10, 0, 0),
		KeyType:    tlsutils.KeyTypeECDSA,
	})
	if err != nil {
		t.Fatalf("failed to generate CA cert: %v", err)
	}

	err = tlsutils.GenerateServerCertToFile(caCertPEM, caKeyPEM, certFile, keyFile, &tlsutils.CertificateOptions{
		Hostname:           "127.0.0.1",
		CommonName:         "*",
		OrganizationalUnit: "TestServer",
		NotAfter:           time.Now().AddDate(10, 0, 0),
		KeyType:            tlsutils.KeyTypeECDSA,
	})
	if err != nil {
		t.Fatalf("failed to generate server cert: %v", err)
	}

	w := NewTlsConfigWatcher(certFile, keyFile, log.NewLogger("debug", ""))
	if err = w.ReloadCertificate(); err != nil {
		t.Fatalf("failed to load initial certificate: %v", err)
	}

	// recordedModTimes snapshots the watcher's recorded timestamps under its
	// read lock.
	recordedModTimes := func() (time.Time, time.Time) {
		w.mu.RLock()
		defer w.mu.RUnlock()

		return w.tlsCertModTime, w.tlsKeyModTime
	}

	certBefore, keyBefore := recordedModTimes()
	if certBefore.IsZero() || keyBefore.IsZero() {
		t.Fatal("expected initial mod times to be set")
	}

	// tlsFileStat is package-level state, so always restore the real
	// implementation when this test finishes.
	realStat := tlsFileStat
	t.Cleanup(func() { tlsFileStat = realStat })

	// From here on, every stat call fails.
	tlsFileStat = func(string) (os.FileInfo, error) {
		return nil, os.ErrNotExist
	}

	// NOTE(review): stat is stubbed above, so this delay probably has no
	// effect on the outcome; confirm before removing it.
	time.Sleep(10 * time.Millisecond)

	if err = w.ReloadCertificate(); err != nil {
		t.Fatalf("unexpected reload error when stat fails: %v", err)
	}

	certAfter, keyAfter := recordedModTimes()
	if !certAfter.Equal(certBefore) {
		t.Fatal("certificate mod time changed despite stat failure")
	}

	if !keyAfter.Equal(keyBefore) {
		t.Fatal("key mod time changed despite stat failure")
	}
}
|
|
|
|
// TestGetCertificateReloadConcurrency runs three goroutines at the same time:
// one calls ReloadCertificate repeatedly, and two call GetCertificate
// repeatedly. The test fails if any call returns an error or if
// GetCertificate ever yields a nil certificate. It is most useful when run
// with -race.
func TestGetCertificateReloadConcurrency(t *testing.T) {
	logger := log.NewLogger("debug", "")
	tempDir := t.TempDir()

	certPath := path.Join(tempDir, "cert.pem")
	keyPath := path.Join(tempDir, "key.pem")

	caOpts := &tlsutils.CertificateOptions{
		CommonName: "*",
		NotAfter:   time.Now().AddDate(10, 0, 0),
		KeyType:    tlsutils.KeyTypeECDSA,
	}

	caCertPEM, caKeyPEM, err := tlsutils.GenerateCACert(caOpts)
	if err != nil {
		t.Fatalf("failed to generate CA cert: %v", err)
	}

	serverOpts := &tlsutils.CertificateOptions{
		Hostname:           "127.0.0.1",
		CommonName:         "*",
		OrganizationalUnit: "TestServer",
		NotAfter:           time.Now().AddDate(10, 0, 0),
		KeyType:            tlsutils.KeyTypeECDSA,
	}

	if err := tlsutils.GenerateServerCertToFile(caCertPEM, caKeyPEM, certPath, keyPath, serverOpts); err != nil {
		t.Fatalf("failed to generate server cert: %v", err)
	}

	watcher := NewTlsConfigWatcher(certPath, keyPath, logger)

	if err := watcher.ReloadCertificate(); err != nil {
		t.Fatalf("failed to load initial certificate: %v", err)
	}

	var wg sync.WaitGroup

	// Each worker sends at most one value before returning, so with three
	// workers this buffer can never fill up and block a sender.
	errorCh := make(chan error, 32)

	reloadWorker := func(iterations int) {
		defer wg.Done()

		for range iterations {
			if err := watcher.ReloadCertificate(); err != nil {
				errorCh <- err

				return
			}
		}
	}

	getWorker := func(iterations int) {
		defer wg.Done()

		for range iterations {
			cert, err := watcher.GetCertificate(nil)
			if err != nil {
				errorCh <- err

				return
			}

			if cert == nil {
				// A nil certificate with a nil error is still a failure.
				// Send a non-nil sentinel so the collector below cannot
				// skip it.
				errorCh <- errGetCertificateFailed

				return
			}
		}
	}

	wg.Add(3)

	go reloadWorker(50)
	go getWorker(100)
	go getWorker(100)

	wg.Wait()
	close(errorCh)

	for err := range errorCh {
		if err != nil {
			t.Fatalf("concurrent TLS operations failed: %v", err)
		}
	}
}
|
|
|
|
// TestGetCertificateInitialLoadConcurrency starts several goroutines that call
// GetCertificate on a freshly constructed watcher. ReloadCertificate is never
// called explicitly first. Every call must return a non-nil certificate
// without error.
func TestGetCertificateInitialLoadConcurrency(t *testing.T) {
	const (
		workerCount    = 8
		callsPerWorker = 10
	)

	dir := t.TempDir()
	certFile := path.Join(dir, "cert.pem")
	keyFile := path.Join(dir, "key.pem")

	caCertPEM, caKeyPEM, err := tlsutils.GenerateCACert(&tlsutils.CertificateOptions{
		CommonName: "*",
		NotAfter:   time.Now().AddDate(10, 0, 0),
		KeyType:    tlsutils.KeyTypeECDSA,
	})
	if err != nil {
		t.Fatalf("failed to generate CA cert: %v", err)
	}

	err = tlsutils.GenerateServerCertToFile(caCertPEM, caKeyPEM, certFile, keyFile, &tlsutils.CertificateOptions{
		Hostname:           "127.0.0.1",
		CommonName:         "*",
		OrganizationalUnit: "TestServer",
		NotAfter:           time.Now().AddDate(10, 0, 0),
		KeyType:            tlsutils.KeyTypeECDSA,
	})
	if err != nil {
		t.Fatalf("failed to generate server cert: %v", err)
	}

	w := NewTlsConfigWatcher(certFile, keyFile, log.NewLogger("debug", ""))

	var wg sync.WaitGroup

	// Every worker sends exactly one value (nil on success). That stays within
	// the buffer, so no send ever blocks.
	outcomes := make(chan error, 16)

	callRepeatedly := func(n int) {
		defer wg.Done()

		for range n {
			cert, callErr := w.GetCertificate(nil)

			switch {
			case callErr != nil:
				outcomes <- callErr

				return
			case cert == nil:
				outcomes <- errGetCertificateFailed

				return
			}
		}

		outcomes <- nil
	}

	wg.Add(workerCount)

	for range workerCount {
		go callRepeatedly(callsPerWorker)
	}

	wg.Wait()
	close(outcomes)

	for outcome := range outcomes {
		if outcome != nil {
			t.Fatalf("concurrent initial TLS load failed: %v", outcome)
		}
	}
}
|
|
|
|
// TestStartCertificateWatcherAddFailureFallsBack checks two things when the
// certificate path does not exist:
//   - Start returns an error;
//   - UseInotify reports that file watching is disabled afterwards.
func TestStartCertificateWatcherAddFailureFallsBack(t *testing.T) {
	dir := t.TempDir()

	keyFile := path.Join(dir, "key.pem")
	if writeErr := os.WriteFile(keyFile, []byte("key"), 0o600); writeErr != nil {
		t.Fatalf("failed to write key file: %v", writeErr)
	}

	missingCert := path.Join(dir, "missing-cert.pem")
	w := NewTlsConfigWatcher(missingCert, keyFile, log.NewLogger("debug", ""))

	startErr := w.Start()
	if startErr == nil {
		t.Fatal("expected watcher setup to fail for missing certificate file")
	}

	if w.UseInotify() {
		t.Fatal("expected file watching to be disabled after watcher add failure")
	}
}
|
|
|
|
// TestCertificateWatcherHandlesAtomicRename simulates certificate rotation in
// three steps:
//  1. write a new certificate/key pair to temporary files;
//  2. rename them over the watched paths;
//  3. poll until the started watcher records a newer certificate modification
//     time, failing after 2s.
func TestCertificateWatcherHandlesAtomicRename(t *testing.T) {
	dir := t.TempDir()
	certFile := path.Join(dir, "cert.pem")
	keyFile := path.Join(dir, "key.pem")

	// Issue the initial CA and server certificate.
	caCertPEM, caKeyPEM, err := tlsutils.GenerateCACert(&tlsutils.CertificateOptions{
		CommonName: "*",
		NotAfter:   time.Now().AddDate(10, 0, 0),
		KeyType:    tlsutils.KeyTypeECDSA,
	})
	if err != nil {
		t.Fatalf("failed to generate CA cert: %v", err)
	}

	err = tlsutils.GenerateServerCertToFile(caCertPEM, caKeyPEM, certFile, keyFile, &tlsutils.CertificateOptions{
		Hostname:           "127.0.0.1",
		CommonName:         "*",
		OrganizationalUnit: "TestServer-Initial",
		NotAfter:           time.Now().AddDate(10, 0, 0),
		KeyType:            tlsutils.KeyTypeECDSA,
	})
	if err != nil {
		t.Fatalf("failed to generate initial server cert: %v", err)
	}

	w := NewTlsConfigWatcher(certFile, keyFile, log.NewLogger("debug", ""))

	if err = w.ReloadCertificate(); err != nil {
		t.Fatalf("failed to load initial certificate: %v", err)
	}

	if err = w.Start(); err != nil {
		t.Fatalf("failed to start certificate watcher: %v", err)
	}
	defer w.Stop()

	// certModTime reads the recorded certificate mod time under the read lock.
	certModTime := func() time.Time {
		w.mu.RLock()
		defer w.mu.RUnlock()

		return w.tlsCertModTime
	}

	initialModTime := certModTime()

	// Wait more than one second so that filesystems with second-level
	// timestamp precision record a distinct mod time for the rotated files.
	time.Sleep(1100 * time.Millisecond)

	// Stage the rotated pair next to the live files...
	stagedCert, stagedKey := certFile+".tmp", keyFile+".tmp"

	err = tlsutils.GenerateServerCertToFile(caCertPEM, caKeyPEM, stagedCert, stagedKey, &tlsutils.CertificateOptions{
		Hostname:           "127.0.0.2",
		CommonName:         "*",
		OrganizationalUnit: "TestServer-Rotated",
		NotAfter:           time.Now().AddDate(10, 0, 0),
		KeyType:            tlsutils.KeyTypeECDSA,
	})
	if err != nil {
		t.Fatalf("failed to generate temp server cert: %v", err)
	}

	// ...then swap them in with rename, the rotation pattern used by k8s,
	// cert-manager and similar tooling.
	if err = os.Rename(stagedCert, certFile); err != nil {
		t.Fatalf("failed to rename cert: %v", err)
	}

	if err = os.Rename(stagedKey, keyFile); err != nil {
		t.Fatalf("failed to rename key: %v", err)
	}

	// Poll instead of sleeping a fixed amount. fsnotify delivery plus the
	// 150ms debounce interval can take longer than a fixed sleep on slow
	// machines.
	var latest time.Time

	reloaded := false

	for deadline := time.Now().Add(2 * time.Second); time.Now().Before(deadline); time.Sleep(50 * time.Millisecond) {
		latest = certModTime()
		if latest.After(initialModTime) {
			reloaded = true

			break
		}
	}

	if !reloaded {
		t.Fatalf("expected certificate to be reloaded after atomic rename within 2s, initial: %v, new: %v",
			initialModTime, latest)
	}
}
|
|
|
|
// TestCertificateWatcherCanRestart checks the Start/Stop lifecycle:
//   - UseInotify is true after Start;
//   - UseInotify is false after Stop;
//   - a stopped watcher can be started again without error.
func TestCertificateWatcherCanRestart(t *testing.T) {
	dir := t.TempDir()
	certFile := path.Join(dir, "cert.pem")
	keyFile := path.Join(dir, "key.pem")

	caCertPEM, caKeyPEM, err := tlsutils.GenerateCACert(&tlsutils.CertificateOptions{
		CommonName: "*",
		NotAfter:   time.Now().AddDate(10, 0, 0),
		KeyType:    tlsutils.KeyTypeECDSA,
	})
	if err != nil {
		t.Fatalf("failed to generate CA cert: %v", err)
	}

	err = tlsutils.GenerateServerCertToFile(caCertPEM, caKeyPEM, certFile, keyFile, &tlsutils.CertificateOptions{
		Hostname:           "127.0.0.1",
		CommonName:         "*",
		OrganizationalUnit: "TestServer",
		NotAfter:           time.Now().AddDate(10, 0, 0),
		KeyType:            tlsutils.KeyTypeECDSA,
	})
	if err != nil {
		t.Fatalf("failed to generate server cert: %v", err)
	}

	w := NewTlsConfigWatcher(certFile, keyFile, log.NewLogger("debug", ""))

	// requireInotify fails the test with msg unless UseInotify() reports want.
	requireInotify := func(want bool, msg string) {
		t.Helper()

		if w.UseInotify() != want {
			t.Fatal(msg)
		}
	}

	if err = w.ReloadCertificate(); err != nil {
		t.Fatalf("failed to load initial certificate: %v", err)
	}

	// First run.
	if err = w.Start(); err != nil {
		t.Fatalf("failed to start certificate watcher: %v", err)
	}

	requireInotify(true, "expected UseInotify() to be true after Start()")

	w.Stop()
	time.Sleep(100 * time.Millisecond) // give Stop's cleanup time to complete

	requireInotify(false, "expected UseInotify() to be false after Stop()")

	// A second Start on the same watcher must succeed.
	if err = w.Start(); err != nil {
		t.Fatalf("failed to restart certificate watcher: %v", err)
	}

	requireInotify(true, "expected UseInotify() to be true after restart")

	w.Stop()
}
|
|
|
|
// TestCanonicalOrigin is a table test for canonicalOrigin. The table expects:
//   - nil URLs, non-http(s) or empty schemes, and empty hostnames to be
//     rejected;
//   - http/https URLs to produce an origin with an explicit port (80/443
//     when the URL omits one).
func TestCanonicalOrigin(t *testing.T) {
	cases := []struct {
		name   string
		in     *url.URL
		origin string
		ok     bool
	}{
		// Rejected inputs: zero-valued origin and ok.
		{name: "nil URL"},
		{name: "non-http(s) scheme (ftp)", in: mustParseURL("ftp://example.com")},
		{name: "non-http(s) scheme (javascript)", in: mustParseURL("javascript:alert(1)")},
		{name: "empty scheme", in: mustParseURL("//example.com")},
		{name: "empty hostname (port only)", in: mustParseURL("http://:8080/")},

		// Accepted inputs.
		{name: "valid http default port", in: mustParseURL("http://example.com"), origin: "http://example.com:80", ok: true},
		{name: "valid http explicit port", in: mustParseURL("http://example.com:8080"), origin: "http://example.com:8080", ok: true},
		{name: "valid https default port", in: mustParseURL("https://example.com"), origin: "https://example.com:443", ok: true},
		{name: "valid https explicit port", in: mustParseURL("https://example.com:8443"), origin: "https://example.com:8443", ok: true},
	}

	for _, tc := range cases {
		t.Run(tc.name, func(t *testing.T) {
			origin, ok := canonicalOrigin(tc.in)
			if origin == tc.origin && ok == tc.ok {
				return
			}

			t.Errorf("canonicalOrigin() = %q, %v, want %q, %v", origin, ok, tc.origin, tc.ok)
		})
	}
}
|
|
|
|
// TestCanonicalOriginString is a table test for canonicalOriginString. The
// table expects:
//   - input to be trimmed of surrounding whitespace;
//   - empty, relative, host-less and non-http(s) strings to be rejected;
//   - valid http/https URLs to be canonicalised with an explicit port.
func TestCanonicalOriginString(t *testing.T) {
	cases := []struct {
		name     string
		input    string
		expected string
		valid    bool
	}{
		// Rejected inputs: zero-valued expected and valid.
		{name: "empty", input: ""},
		{name: "whitespace only", input: " \t "},
		{name: "relative (no scheme)", input: "example.com/path"},
		{name: "path only", input: "/v2/"},
		{name: "scheme but no host", input: "http://"},
		{name: "non-http(s) URL", input: "ftp://example.com"},
		{name: "empty hostname with port", input: "http://:80/"},
		{name: "invalid host", input: "http://:/"},

		// Accepted inputs.
		{name: "valid https", input: "https://example.com", expected: "https://example.com:443", valid: true},
		{name: "valid http with port", input: "http://localhost:3000", expected: "http://localhost:3000", valid: true},
		{name: "trimmed", input: " https://example.com ", expected: "https://example.com:443", valid: true},
	}

	for _, tc := range cases {
		t.Run(tc.name, func(t *testing.T) {
			got, valid := canonicalOriginString(tc.input)
			if got == tc.expected && valid == tc.valid {
				return
			}

			t.Errorf("canonicalOriginString(%q) = %q, %v, want %q, %v", tc.input, got, valid, tc.expected, tc.valid)
		})
	}
}
|
|
|
|
// mustParseURL returns url.Parse(s) and panics with the parse error if
// parsing fails. It is intended for building static test tables, where a
// parse failure is a bug in the test itself.
func mustParseURL(s string) *url.URL {
	parsed, err := url.Parse(s)
	if err == nil {
		return parsed
	}

	panic(err)
}
|