Files
zot/pkg/api/controller_internal_test.go
T
Andrei Aaron 9425ca8b7d fix(auth): prevent open redirect via callback_ui (#3844)
Validate callback_ui and default invalid values to /.
Allow absolute callback_ui only when its origin is allowlisted via http.auth.openid.callbackAllowOrigins (and externalUrl).
Add/adjust unit and controller tests, and update examples/docs for relative vs. allowlisted absolute redirects.

Signed-off-by: Andrei Aaron <andreifdaaron@gmail.com>
2026-03-08 08:13:16 +02:00

471 lines
13 KiB
Go

//go:build sync && scrub && metrics && search && lint && userprefs && mgmt && imagetrust && ui
package api
import (
goerrors "errors"
"net/url"
"os"
"path"
"sync"
"testing"
"time"
"zotregistry.dev/zot/v2/pkg/log"
tlsutils "zotregistry.dev/zot/v2/pkg/test/tls"
)
// errGetCertificateFailed is the sentinel error tests in this file use to
// report that GetCertificate returned a nil certificate together with a nil
// error.
var errGetCertificateFailed = goerrors.New("GetCertificate failed")
// TestReloadCertificateStatFailureKeepsModTimes checks that, once a
// certificate has been loaded, a second ReloadCertificate call made while the
// package-level tlsFileStat hook always fails returns no error and leaves the
// recorded certificate and key modification times unchanged.
func TestReloadCertificateStatFailureKeepsModTimes(t *testing.T) {
	testLogger := log.NewLogger("debug", "")
	workDir := t.TempDir()
	certFile := path.Join(workDir, "cert.pem")
	keyFile := path.Join(workDir, "key.pem")

	// Generate a CA, then a server certificate/key pair written to disk.
	caCert, caKey, err := tlsutils.GenerateCACert(&tlsutils.CertificateOptions{
		CommonName: "*",
		NotAfter:   time.Now().AddDate(10, 0, 0),
		KeyType:    tlsutils.KeyTypeECDSA,
	})
	if err != nil {
		t.Fatalf("failed to generate CA cert: %v", err)
	}

	leafOpts := &tlsutils.CertificateOptions{
		Hostname:           "127.0.0.1",
		CommonName:         "*",
		OrganizationalUnit: "TestServer",
		NotAfter:           time.Now().AddDate(10, 0, 0),
		KeyType:            tlsutils.KeyTypeECDSA,
	}

	err = tlsutils.GenerateServerCertToFile(caCert, caKey, certFile, keyFile, leafOpts)
	if err != nil {
		t.Fatalf("failed to generate server cert: %v", err)
	}

	tlsWatcher := NewTlsConfigWatcher(certFile, keyFile, testLogger)

	err = tlsWatcher.ReloadCertificate()
	if err != nil {
		t.Fatalf("failed to load initial certificate: %v", err)
	}

	// Snapshot the mod times recorded by the successful initial load.
	tlsWatcher.mu.RLock()
	certBefore := tlsWatcher.tlsCertModTime
	keyBefore := tlsWatcher.tlsKeyModTime
	tlsWatcher.mu.RUnlock()

	bothRecorded := !certBefore.IsZero() && !keyBefore.IsZero()
	if !bothRecorded {
		t.Fatal("expected initial mod times to be set")
	}

	// Swap in a stat stub that always reports "not exist"; the real
	// implementation is put back when the test finishes.
	realStat := tlsFileStat
	t.Cleanup(func() {
		tlsFileStat = realStat
	})

	tlsFileStat = func(string) (os.FileInfo, error) {
		return nil, os.ErrNotExist
	}

	time.Sleep(10 * time.Millisecond)

	err = tlsWatcher.ReloadCertificate()
	if err != nil {
		t.Fatalf("unexpected reload error when stat fails: %v", err)
	}

	tlsWatcher.mu.RLock()
	certAfter := tlsWatcher.tlsCertModTime
	keyAfter := tlsWatcher.tlsKeyModTime
	tlsWatcher.mu.RUnlock()

	if unchanged := certAfter.Equal(certBefore); !unchanged {
		t.Fatal("certificate mod time changed despite stat failure")
	}

	if unchanged := keyAfter.Equal(keyBefore); !unchanged {
		t.Fatal("key mod time changed despite stat failure")
	}
}
// TestGetCertificateReloadConcurrency runs one goroutine that repeatedly calls
// ReloadCertificate alongside two goroutines that repeatedly call
// GetCertificate, and fails if any call returns an error or if GetCertificate
// yields a nil certificate.
func TestGetCertificateReloadConcurrency(t *testing.T) {
	logger := log.NewLogger("debug", "")
	tempDir := t.TempDir()
	certPath := path.Join(tempDir, "cert.pem")
	keyPath := path.Join(tempDir, "key.pem")

	caOpts := &tlsutils.CertificateOptions{
		CommonName: "*",
		NotAfter:   time.Now().AddDate(10, 0, 0),
		KeyType:    tlsutils.KeyTypeECDSA,
	}

	caCertPEM, caKeyPEM, err := tlsutils.GenerateCACert(caOpts)
	if err != nil {
		t.Fatalf("failed to generate CA cert: %v", err)
	}

	serverOpts := &tlsutils.CertificateOptions{
		Hostname:           "127.0.0.1",
		CommonName:         "*",
		OrganizationalUnit: "TestServer",
		NotAfter:           time.Now().AddDate(10, 0, 0),
		KeyType:            tlsutils.KeyTypeECDSA,
	}

	if err := tlsutils.GenerateServerCertToFile(caCertPEM, caKeyPEM, certPath, keyPath, serverOpts); err != nil {
		t.Fatalf("failed to generate server cert: %v", err)
	}

	watcher := NewTlsConfigWatcher(certPath, keyPath, logger)

	if err := watcher.ReloadCertificate(); err != nil {
		t.Fatalf("failed to load initial certificate: %v", err)
	}

	var wg sync.WaitGroup

	// Each worker sends at most one value before returning, so the buffer is
	// never filled and sends cannot block.
	errorCh := make(chan error, 32)

	reloadWorker := func(iterations int) {
		defer wg.Done()

		for i := 0; i < iterations; i++ {
			if err := watcher.ReloadCertificate(); err != nil {
				errorCh <- err

				return
			}
		}
	}

	getWorker := func(iterations int) {
		defer wg.Done()

		for i := 0; i < iterations; i++ {
			cert, err := watcher.GetCertificate(nil)
			if err != nil {
				errorCh <- err

				return
			}

			// A nil certificate with a nil error is still a failure. Report
			// it with a non-nil sentinel: forwarding the nil err would be
			// skipped by the drain loop below and the test would pass.
			if cert == nil {
				errorCh <- errGetCertificateFailed

				return
			}
		}
	}

	wg.Add(3)

	go reloadWorker(50)
	go getWorker(100)
	go getWorker(100)

	wg.Wait()
	close(errorCh)

	for err := range errorCh {
		if err != nil {
			t.Fatalf("concurrent TLS operations failed: %v", err)
		}
	}
}
// TestGetCertificateInitialLoadConcurrency calls GetCertificate from eight
// goroutines on a watcher that has not had ReloadCertificate called on it
// first, and fails if any call returns an error or a nil certificate.
func TestGetCertificateInitialLoadConcurrency(t *testing.T) {
	const (
		workerCount   = 8
		callsPerGoRtn = 10
	)

	testLogger := log.NewLogger("debug", "")
	workDir := t.TempDir()
	certFile := path.Join(workDir, "cert.pem")
	keyFile := path.Join(workDir, "key.pem")

	caCert, caKey, err := tlsutils.GenerateCACert(&tlsutils.CertificateOptions{
		CommonName: "*",
		NotAfter:   time.Now().AddDate(10, 0, 0),
		KeyType:    tlsutils.KeyTypeECDSA,
	})
	if err != nil {
		t.Fatalf("failed to generate CA cert: %v", err)
	}

	leafOpts := &tlsutils.CertificateOptions{
		Hostname:           "127.0.0.1",
		CommonName:         "*",
		OrganizationalUnit: "TestServer",
		NotAfter:           time.Now().AddDate(10, 0, 0),
		KeyType:            tlsutils.KeyTypeECDSA,
	}

	err = tlsutils.GenerateServerCertToFile(caCert, caKey, certFile, keyFile, leafOpts)
	if err != nil {
		t.Fatalf("failed to generate server cert: %v", err)
	}

	tlsWatcher := NewTlsConfigWatcher(certFile, keyFile, testLogger)

	var pending sync.WaitGroup

	// Every goroutine sends exactly one outcome (nil on success), so the
	// buffer of 16 is never exceeded by the eight workers.
	outcomes := make(chan error, 16)

	fetchRepeatedly := func(count int) {
		defer pending.Done()

		for remaining := count; remaining > 0; remaining-- {
			cert, getErr := tlsWatcher.GetCertificate(nil)

			switch {
			case getErr != nil:
				outcomes <- getErr

				return
			case cert == nil:
				// No error but no certificate either: report the sentinel.
				outcomes <- errGetCertificateFailed

				return
			}
		}

		outcomes <- nil
	}

	pending.Add(workerCount)

	for started := 0; started < workerCount; started++ {
		go fetchRepeatedly(callsPerGoRtn)
	}

	pending.Wait()
	close(outcomes)

	for outcome := range outcomes {
		if outcome == nil {
			continue
		}

		t.Fatalf("concurrent initial TLS load failed: %v", outcome)
	}
}
// TestStartCertificateWatcherAddFailureFallsBack starts a watcher whose
// certificate path does not exist and expects Start to return an error and
// UseInotify to report false afterwards.
func TestStartCertificateWatcherAddFailureFallsBack(t *testing.T) {
	testLogger := log.NewLogger("debug", "")
	workDir := t.TempDir()

	// Only the key file is created; the certificate is deliberately absent.
	keyFile := path.Join(workDir, "key.pem")

	writeErr := os.WriteFile(keyFile, []byte("key"), 0o600)
	if writeErr != nil {
		t.Fatalf("failed to write key file: %v", writeErr)
	}

	missingCert := path.Join(workDir, "missing-cert.pem")
	tlsWatcher := NewTlsConfigWatcher(missingCert, keyFile, testLogger)

	startErr := tlsWatcher.Start()
	if startErr == nil {
		t.Fatal("expected watcher setup to fail for missing certificate file")
	}

	if inotifyActive := tlsWatcher.UseInotify(); inotifyActive {
		t.Fatal("expected file watching to be disabled after watcher add failure")
	}
}
// TestCertificateWatcherHandlesAtomicRename replaces the certificate and key
// files by renaming freshly generated temp files over them while the watcher
// is running, then polls for up to two seconds until the watcher's recorded
// certificate modification time moves past its initial value.
func TestCertificateWatcherHandlesAtomicRename(t *testing.T) {
	testLogger := log.NewLogger("debug", "")
	workDir := t.TempDir()
	certFile := path.Join(workDir, "cert.pem")
	keyFile := path.Join(workDir, "key.pem")

	// Initial CA and server certificate.
	caCert, caKey, err := tlsutils.GenerateCACert(&tlsutils.CertificateOptions{
		CommonName: "*",
		NotAfter:   time.Now().AddDate(10, 0, 0),
		KeyType:    tlsutils.KeyTypeECDSA,
	})
	if err != nil {
		t.Fatalf("failed to generate CA cert: %v", err)
	}

	firstLeafOpts := &tlsutils.CertificateOptions{
		Hostname:           "127.0.0.1",
		CommonName:         "*",
		OrganizationalUnit: "TestServer-Initial",
		NotAfter:           time.Now().AddDate(10, 0, 0),
		KeyType:            tlsutils.KeyTypeECDSA,
	}

	err = tlsutils.GenerateServerCertToFile(caCert, caKey, certFile, keyFile, firstLeafOpts)
	if err != nil {
		t.Fatalf("failed to generate initial server cert: %v", err)
	}

	tlsWatcher := NewTlsConfigWatcher(certFile, keyFile, testLogger)

	err = tlsWatcher.ReloadCertificate()
	if err != nil {
		t.Fatalf("failed to load initial certificate: %v", err)
	}

	err = tlsWatcher.Start()
	if err != nil {
		t.Fatalf("failed to start certificate watcher: %v", err)
	}

	defer tlsWatcher.Stop()

	tlsWatcher.mu.RLock()
	modTimeBefore := tlsWatcher.tlsCertModTime
	tlsWatcher.mu.RUnlock()

	// Wait just over one second before rotating so that filesystems with
	// second-level timestamp precision still record a later mod time.
	time.Sleep(1100 * time.Millisecond)

	rotatedLeafOpts := &tlsutils.CertificateOptions{
		Hostname:           "127.0.0.2",
		CommonName:         "*",
		OrganizationalUnit: "TestServer-Rotated",
		NotAfter:           time.Now().AddDate(10, 0, 0),
		KeyType:            tlsutils.KeyTypeECDSA,
	}

	// Stage the rotated pair next to the live files...
	stagedCert := certFile + ".tmp"
	stagedKey := keyFile + ".tmp"

	err = tlsutils.GenerateServerCertToFile(caCert, caKey, stagedCert, stagedKey, rotatedLeafOpts)
	if err != nil {
		t.Fatalf("failed to generate temp server cert: %v", err)
	}

	// ...then rename them into place, mimicking the atomic-replace pattern
	// used by certificate rotation tooling (k8s, cert-manager, etc.).
	err = os.Rename(stagedCert, certFile)
	if err != nil {
		t.Fatalf("failed to rename cert: %v", err)
	}

	err = os.Rename(stagedKey, keyFile)
	if err != nil {
		t.Fatalf("failed to rename key: %v", err)
	}

	// Poll rather than sleep a fixed amount: fsnotify delivery plus the 150ms
	// debounce interval can take longer on slower systems.
	var (
		modTimeNow time.Time
		sawReload  bool
	)

	giveUpAt := time.Now().Add(2 * time.Second)

	for !sawReload && time.Now().Before(giveUpAt) {
		tlsWatcher.mu.RLock()
		modTimeNow = tlsWatcher.tlsCertModTime
		tlsWatcher.mu.RUnlock()

		sawReload = modTimeNow.After(modTimeBefore)
		if !sawReload {
			time.Sleep(50 * time.Millisecond)
		}
	}

	if sawReload {
		return
	}

	t.Fatalf("expected certificate to be reloaded after atomic rename within 2s, initial: %v, new: %v",
		modTimeBefore, modTimeNow)
}
// TestCertificateWatcherCanRestart exercises a Start, Stop, Start cycle on a
// single watcher, checking that each Start succeeds and that UseInotify
// reports true while started and false after Stop.
func TestCertificateWatcherCanRestart(t *testing.T) {
	testLogger := log.NewLogger("debug", "")
	workDir := t.TempDir()
	certFile := path.Join(workDir, "cert.pem")
	keyFile := path.Join(workDir, "key.pem")

	caCert, caKey, err := tlsutils.GenerateCACert(&tlsutils.CertificateOptions{
		CommonName: "*",
		NotAfter:   time.Now().AddDate(10, 0, 0),
		KeyType:    tlsutils.KeyTypeECDSA,
	})
	if err != nil {
		t.Fatalf("failed to generate CA cert: %v", err)
	}

	leafOpts := &tlsutils.CertificateOptions{
		Hostname:           "127.0.0.1",
		CommonName:         "*",
		OrganizationalUnit: "TestServer",
		NotAfter:           time.Now().AddDate(10, 0, 0),
		KeyType:            tlsutils.KeyTypeECDSA,
	}

	err = tlsutils.GenerateServerCertToFile(caCert, caKey, certFile, keyFile, leafOpts)
	if err != nil {
		t.Fatalf("failed to generate server cert: %v", err)
	}

	tlsWatcher := NewTlsConfigWatcher(certFile, keyFile, testLogger)

	// First cycle: load, start, confirm watching is on.
	err = tlsWatcher.ReloadCertificate()
	if err != nil {
		t.Fatalf("failed to load initial certificate: %v", err)
	}

	err = tlsWatcher.Start()
	if err != nil {
		t.Fatalf("failed to start certificate watcher: %v", err)
	}

	if active := tlsWatcher.UseInotify(); !active {
		t.Fatalf("expected UseInotify() to be true after Start()")
	}

	// Stop and give the watcher a moment to finish its cleanup before
	// checking that watching is reported as off.
	tlsWatcher.Stop()
	time.Sleep(100 * time.Millisecond)

	if active := tlsWatcher.UseInotify(); active {
		t.Fatalf("expected UseInotify() to be false after Stop()")
	}

	// Second cycle: a stopped watcher must accept another Start.
	err = tlsWatcher.Start()
	if err != nil {
		t.Fatalf("failed to restart certificate watcher: %v", err)
	}

	if active := tlsWatcher.UseInotify(); !active {
		t.Fatalf("expected UseInotify() to be true after restart")
	}

	tlsWatcher.Stop()
}
// TestCanonicalOrigin is a table-driven test of canonicalOrigin. It expects
// ("", false) for a nil URL, non-http(s) schemes, a missing scheme and an
// empty hostname, and "scheme://host:port" with the port always present
// (80 for http, 443 for https when not given) for valid inputs.
func TestCanonicalOrigin(t *testing.T) {
	type originCase struct {
		label      string
		input      *url.URL
		wantOrigin string
		wantValid  bool
	}

	cases := []originCase{
		{label: "nil URL", input: nil, wantOrigin: "", wantValid: false},
		{
			label:      "non-http(s) scheme (ftp)",
			input:      mustParseURL("ftp://example.com"),
			wantOrigin: "",
			wantValid:  false,
		},
		{
			label:      "non-http(s) scheme (javascript)",
			input:      mustParseURL("javascript:alert(1)"),
			wantOrigin: "",
			wantValid:  false,
		},
		{
			label:      "empty scheme",
			input:      mustParseURL("//example.com"),
			wantOrigin: "",
			wantValid:  false,
		},
		{
			label:      "empty hostname (port only)",
			input:      mustParseURL("http://:8080/"),
			wantOrigin: "",
			wantValid:  false,
		},
		{
			label:      "valid http default port",
			input:      mustParseURL("http://example.com"),
			wantOrigin: "http://example.com:80",
			wantValid:  true,
		},
		{
			label:      "valid http explicit port",
			input:      mustParseURL("http://example.com:8080"),
			wantOrigin: "http://example.com:8080",
			wantValid:  true,
		},
		{
			label:      "valid https default port",
			input:      mustParseURL("https://example.com"),
			wantOrigin: "https://example.com:443",
			wantValid:  true,
		},
		{
			label:      "valid https explicit port",
			input:      mustParseURL("https://example.com:8443"),
			wantOrigin: "https://example.com:8443",
			wantValid:  true,
		},
	}

	for _, tc := range cases {
		t.Run(tc.label, func(t *testing.T) {
			origin, valid := canonicalOrigin(tc.input)
			if origin == tc.wantOrigin && valid == tc.wantValid {
				return
			}

			t.Errorf("canonicalOrigin() = %q, %v, want %q, %v", origin, valid, tc.wantOrigin, tc.wantValid)
		})
	}
}
// TestCanonicalOriginString is a table-driven test of canonicalOriginString.
// It expects ("", false) for empty, whitespace-only, relative, host-less and
// non-http(s) inputs, and a "scheme://host:port" origin for valid http(s)
// URLs, including ones surrounded by whitespace.
func TestCanonicalOriginString(t *testing.T) {
	type stringCase struct {
		label      string
		input      string
		wantOrigin string
		wantValid  bool
	}

	cases := []stringCase{
		{label: "empty", input: "", wantOrigin: "", wantValid: false},
		{label: "whitespace only", input: " \t ", wantOrigin: "", wantValid: false},
		{label: "relative (no scheme)", input: "example.com/path", wantOrigin: "", wantValid: false},
		{label: "path only", input: "/v2/", wantOrigin: "", wantValid: false},
		{label: "scheme but no host", input: "http://", wantOrigin: "", wantValid: false},
		{label: "non-http(s) URL", input: "ftp://example.com", wantOrigin: "", wantValid: false},
		{label: "empty hostname with port", input: "http://:80/", wantOrigin: "", wantValid: false},
		{label: "invalid host", input: "http://:/", wantOrigin: "", wantValid: false},
		{label: "valid https", input: "https://example.com", wantOrigin: "https://example.com:443", wantValid: true},
		{label: "valid http with port", input: "http://localhost:3000", wantOrigin: "http://localhost:3000", wantValid: true},
		{label: "trimmed", input: " https://example.com ", wantOrigin: "https://example.com:443", wantValid: true},
	}

	for _, tc := range cases {
		t.Run(tc.label, func(t *testing.T) {
			origin, valid := canonicalOriginString(tc.input)
			if origin == tc.wantOrigin && valid == tc.wantValid {
				return
			}

			t.Errorf("canonicalOriginString(%q) = %q, %v, want %q, %v",
				tc.input, origin, valid, tc.wantOrigin, tc.wantValid)
		})
	}
}
// mustParseURL parses s with url.Parse and returns the result. It panics with
// the parse error if s cannot be parsed, which makes it usable inside test
// table literals where an error return cannot be handled.
func mustParseURL(s string) *url.URL {
	parsed, parseErr := url.Parse(s)
	if parseErr == nil {
		return parsed
	}

	panic(parseErr)
}