From 9425ca8b7d0d121ae4d32a03beae34e7da0fb7fe Mon Sep 17 00:00:00 2001 From: Andrei Aaron Date: Sun, 8 Mar 2026 08:13:16 +0200 Subject: [PATCH] fix(auth): prevent open redirect via callback_ui (#3844) Validate callback_ui and default invalid values to /. Allow absolute callback_ui only when its origin is allowlisted via http.auth.openid.callbackAllowOrigins (and externalUrl). Add/adjust unit + controller tests and update examples/docs for relative vs allowlisted absolute redirect Signed-off-by: Andrei Aaron --- examples/README.md | 20 +++- examples/config-openid-claim-mapping.json | 1 + examples/config-openid.json | 1 + pkg/api/authn.go | 127 +++++++++++++++++++++- pkg/api/authn_test.go | 55 ++++++++++ pkg/api/config/config.go | 9 +- pkg/api/constants/consts.go | 2 + pkg/api/controller_internal_test.go | 65 +++++++++++ pkg/api/controller_test.go | 96 +++++++++++++++- pkg/api/proxy.go | 4 +- 10 files changed, 368 insertions(+), 12 deletions(-) diff --git a/examples/README.md b/examples/README.md index 9dc66f39..c3af4416 100644 --- a/examples/README.md +++ b/examples/README.md @@ -332,11 +332,27 @@ zot can be configured to use the above providers with: } ``` -To login with either provider use http://127.0.0.1:8080/zot/auth/login?provider=\&callback_ui=http://127.0.0.1:8080/home -for example to login with github use http://127.0.0.1:8080/zot/auth/login?provider=github&callback_ui=http://127.0.0.1:8080/home +To login with either provider use http://127.0.0.1:8080/zot/auth/login?provider=\&callback_ui=/home +for example to login with github use http://127.0.0.1:8080/zot/auth/login?provider=github&callback_ui=/home callback_ui query parameter is used by zot to redirect to UI after a successful openid/oauth2 authentication +By default, `callback_ui` must be a relative path (starting with `/`) to prevent open redirects. +If your UI runs on a different origin (e.g. different port during development), you can allowlist +absolute redirect origins via: + +``` +{ + "http": { + "auth": { + "openid": { + "callbackAllowOrigins": ["http://127.0.0.1:3000"] + } + } + } +} +``` + The callback url which should be used when making oauth2 provider setup is http://127.0.0.1:8080/zot/auth/callback/\ for example github callback url would be http://127.0.0.1:8080/zot/auth/callback/github diff --git a/examples/config-openid-claim-mapping.json b/examples/config-openid-claim-mapping.json index 3896e8b8..bc0e9c24 100644 --- a/examples/config-openid-claim-mapping.json +++ b/examples/config-openid-claim-mapping.json @@ -12,6 +12,7 @@ "auth": { "sessionKeysFile": "examples/sessionKeys.json", "openid": { + "callbackAllowOrigins": ["http://127.0.0.1:3000"], "providers": { "oidc": { "name": "Zitadel", diff --git a/examples/config-openid.json b/examples/config-openid.json index 20899c4c..32813f46 100644 --- a/examples/config-openid.json +++ b/examples/config-openid.json @@ -16,6 +16,7 @@ "sessionKeysFile": "examples/sessionKeys.json", "apikey": true, "openid": { + "callbackAllowOrigins": ["http://127.0.0.1:3000"], "providers": { "github": { "credentialsFile": "examples/config-openid-github-credentials.json", diff --git a/pkg/api/authn.go b/pkg/api/authn.go index c382f56c..be77a73d 100644 --- a/pkg/api/authn.go +++ b/pkg/api/authn.go @@ -13,6 +13,7 @@ import ( "fmt" "net" "net/http" + "net/url" "os" "regexp" "slices" @@ -667,10 +668,130 @@ func bearerAuthHandler(ctlr *Controller) mux.MiddlewareFunc { } } +func canonicalOrigin(parsedURL *url.URL) (string, bool) { + if parsedURL == nil { + return "", false + } + + scheme := strings.ToLower(parsedURL.Scheme) + if scheme != constants.SchemeHTTP && scheme != constants.SchemeHTTPS { + return "", false + } + + host := strings.ToLower(parsedURL.Hostname()) + if host == "" { + return "", false + } + + port := parsedURL.Port() + if port == "" { + if scheme == constants.SchemeHTTP { + port = "80" + } else { + port = "443" + } + } + + return scheme + "://" + net.JoinHostPort(host, port), true +} + +func canonicalOriginString(raw string) (string, bool) { + raw = strings.TrimSpace(raw) + if raw == "" { + return "", false + } + + parsed, err := url.Parse(raw) + if err != nil { + return "", false + } + + // Only accept absolute http(s) URLs for allowlist entries. + if parsed.Scheme == "" || parsed.Host == "" { + return "", false + } + + return canonicalOrigin(parsed) +} + +// ValidateCallbackUI validates the callback_ui parameter used for post-login redirects. +// - Relative paths (starting with "/") are always allowed. +// - Absolute http(s) URLs are allowed only when their origin matches allowOrigins. +// It returns the validated redirect target, or "/" as fallback, or "" if the input is empty. +func ValidateCallbackUI(callbackUI string, allowOrigins []string) string { + if callbackUI == "" { + return "" + } + + // Prevent header injection. + if strings.ContainsAny(callbackUI, "\r\n") { + return "/" + } + + parsed, err := url.Parse(callbackUI) + if err != nil { + return "/" + } + + // Reject protocol-relative URLs (e.g. //evil.com/path) + if strings.HasPrefix(callbackUI, "//") { + return "/" + } + + // Relative path to root (safe default). + if parsed.Scheme == "" && parsed.Host == "" { + if !strings.HasPrefix(callbackUI, "/") { + return "/" + } + + return callbackUI + } + + // Absolute URL: only allow http(s) and only when origin is allowlisted. + if parsed.Scheme != constants.SchemeHTTP && parsed.Scheme != constants.SchemeHTTPS { + return "/" + } + + if parsed.Host == "" { + return "/" + } + + origin, ok := canonicalOrigin(parsed) + if !ok { + return "/" + } + + for _, rawAllowed := range allowOrigins { + allowedOrigin, ok := canonicalOriginString(rawAllowed) + if !ok { + continue + } + + if allowedOrigin == origin { + return callbackUI + } + } + + return "/" +} + func (rh *RouteHandler) AuthURLHandler() http.HandlerFunc { return func(w http.ResponseWriter, r *http.Request) { query := r.URL.Query() - callbackUI := query.Get(constants.CallbackUIQueryParam) + + allowOrigins := []string{} + if authCfg := rh.c.Config.CopyAuthConfig(); authCfg != nil { + if authCfg.OpenID != nil { + allowOrigins = append(allowOrigins, authCfg.OpenID.CallbackAllowOrigins...) + } + } + + // If an ExternalURL is configured, allow redirects back to that origin. + if rh.c.Config.HTTP.ExternalURL != "" { + allowOrigins = append(allowOrigins, rh.c.Config.HTTP.ExternalURL) + } + + callbackUI := ValidateCallbackUI(query.Get(constants.CallbackUIQueryParam), allowOrigins) provider := query.Get("provider") @@ -794,9 +915,9 @@ func getRelyingPartyArgs(cfg *config.Config, provider string, hashKey, encryptKe externalURL := strings.TrimSuffix(cfg.HTTP.ExternalURL, "/") redirectURI = fmt.Sprintf("%s%s", externalURL, callback) } else { - scheme := "http" + scheme := constants.SchemeHTTP if cfg.HTTP.TLS != nil { - scheme = "https" + scheme = constants.SchemeHTTPS } redirectURI = fmt.Sprintf("%s://%s%s", scheme, baseURL, callback) diff --git a/pkg/api/authn_test.go b/pkg/api/authn_test.go index ee1bff7b..fa708fee 100644 --- a/pkg/api/authn_test.go +++ b/pkg/api/authn_test.go @@ -96,6 +96,61 @@ func TestAllowedMethodsHeaderAPIKey(t *testing.T) { }) } +func TestValidateCallbackUI(t *testing.T) { + tests := []struct { + name string + input string + allowOrigins []string + expected string + }{ + {name: "empty", input: "", expected: ""}, + {name: "relative path", input: "/v2/", expected: "/v2/"}, + {name: "root path", input: "/", expected: "/"}, + {name: "relative with path", input: "/zot/auth/login", expected: "/zot/auth/login"}, + {name: "absolute URL rejected (not allowlisted)", input: "https://evil.com/phish", expected: "/"}, + { + name: "absolute URL allowed when allowlisted (https default port)", + input: "https://example.com/home", + allowOrigins: []string{"https://example.com"}, + expected: "https://example.com/home", + }, + { + name: "absolute URL allowed when allowlisted (explicit port)", + input: "http://localhost:3000/home", + allowOrigins: []string{"http://localhost:3000"}, + expected: "http://localhost:3000/home", + }, + { + name: "absolute URL rejected when port differs", + input: "http://localhost:3001/home", + allowOrigins: []string{"http://localhost:3000"}, + expected: "/", + }, + {name: "protocol-relative rejected", input: "//evil.com/path", expected: "/"}, + {name: "no leading slash rejected", input: "v2/", expected: "/"}, + {name: "relative path without leading slash rejected", input: "path/segment", expected: "/"}, + {name: "javascript scheme rejected", input: "javascript:alert(1)", expected: "/"}, + {name: "absolute URL with empty host rejected", input: "http:///path", expected: "/"}, + { + name: "allowlist entry invalid causes continue then match", + input: "https://example.com/home", + allowOrigins: []string{" \t ", "https://example.com"}, + expected: "https://example.com/home", + }, + {name: "header injection rejected (newline)", input: "/v2/\nSet-Cookie: x=y", expected: "/"}, + {name: "header injection rejected (carriage return)", input: "/v2/\rSet-Cookie: x=y", expected: "/"}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := api.ValidateCallbackUI(tt.input, tt.allowOrigins) + if got != tt.expected { + t.Errorf("ValidateCallbackUI(%q) = %q, want %q", tt.input, got, tt.expected) + } + }) + } +} + func TestAPIKeys(t *testing.T) { Convey("Make a new controller", t, func() { port := test.GetFreePort() diff --git a/pkg/api/config/config.go b/pkg/api/config/config.go index 5644ed31..b2b1db69 100644 --- a/pkg/api/config/config.go +++ b/pkg/api/config/config.go @@ -4,6 +4,7 @@ import ( "encoding/json" "maps" "os" + "slices" "sync" "time" @@ -279,6 +280,11 @@ type SessionKeys struct { type OpenIDConfig struct { Providers map[string]OpenIDProviderConfig + // CallbackAllowOrigins is an allowlist of absolute URL origins that are permitted in the + // callback_ui query parameter during the OpenID/OAuth2 login flow. If empty, callback_ui must + // be a same-origin relative path (e.g. "/v2/") to prevent open redirects. + // Example: ["http://localhost:3000", "https://ui.example.com"] + CallbackAllowOrigins []string `mapstructure:"callbackAllowOrigins,omitempty"` } type OpenIDCredentials struct { @@ -720,7 +726,8 @@ func (c *Config) Sanitize() *Config { // Sanitize OpenID client secrets if c.HTTP.Auth.OpenID != nil { sanitizedConfig.HTTP.Auth.OpenID = &OpenIDConfig{ - Providers: make(map[string]OpenIDProviderConfig), + Providers: make(map[string]OpenIDProviderConfig), + CallbackAllowOrigins: slices.Clone(c.HTTP.Auth.OpenID.CallbackAllowOrigins), } for provider, config := range c.HTTP.Auth.OpenID.Providers { diff --git a/pkg/api/constants/consts.go b/pkg/api/constants/consts.go index a9896012..ef45dda0 100644 --- a/pkg/api/constants/consts.go +++ b/pkg/api/constants/consts.go @@ -22,6 +22,8 @@ const ( SessionClientHeaderValue = "zot-ui" APIKeysPrefix = "zak_" CallbackUIQueryParam = "callback_ui" + SchemeHTTP = "http" + SchemeHTTPS = "https" APIKeyTimeFormat = time.RFC3339 // CreatePermission is an authz permission for create actions. CreatePermission = "create" diff --git a/pkg/api/controller_internal_test.go b/pkg/api/controller_internal_test.go index f43d3d93..bce4656a 100644 --- a/pkg/api/controller_internal_test.go +++ b/pkg/api/controller_internal_test.go @@ -4,6 +4,7 @@ package api import ( goerrors "errors" + "net/url" "os" "path" "sync" @@ -403,3 +404,67 @@ func TestCertificateWatcherCanRestart(t *testing.T) { watcher.Stop() } + +func TestCanonicalOrigin(t *testing.T) { + tests := []struct { + name string + parsed *url.URL + wantOrig string + wantOK bool + }{ + {"nil URL", nil, "", false}, + {"non-http(s) scheme (ftp)", mustParseURL("ftp://example.com"), "", false}, + {"non-http(s) scheme (javascript)", mustParseURL("javascript:alert(1)"), "", false}, + {"empty scheme", mustParseURL("//example.com"), "", false}, + {"empty hostname (port only)", mustParseURL("http://:8080/"), "", false}, + {"valid http default port", mustParseURL("http://example.com"), "http://example.com:80", true}, + {"valid http explicit port", mustParseURL("http://example.com:8080"), "http://example.com:8080", true}, + {"valid https default port", mustParseURL("https://example.com"), "https://example.com:443", true}, + {"valid https explicit port", mustParseURL("https://example.com:8443"), "https://example.com:8443", true}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + gotOrig, gotOK := canonicalOrigin(tt.parsed) + if gotOrig != tt.wantOrig || gotOK != tt.wantOK { + t.Errorf("canonicalOrigin() = %q, %v, want %q, %v", gotOrig, gotOK, tt.wantOrig, tt.wantOK) + } + }) + } +} + +func TestCanonicalOriginString(t *testing.T) { + tests := []struct { + name string + raw string + want string + ok bool + }{ + {"empty", "", "", false}, + {"whitespace only", " \t ", "", false}, + {"relative (no scheme)", "example.com/path", "", false}, + {"path only", "/v2/", "", false}, + {"scheme but no host", "http://", "", false}, + {"non-http(s) URL", "ftp://example.com", "", false}, + {"empty hostname with port", "http://:80/", "", false}, + {"invalid host", "http://:/", "", false}, + {"valid https", "https://example.com", "https://example.com:443", true}, + {"valid http with port", "http://localhost:3000", "http://localhost:3000", true}, + {"trimmed", " https://example.com ", "https://example.com:443", true}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got, ok := canonicalOriginString(tt.raw) + if got != tt.want || ok != tt.ok { + t.Errorf("canonicalOriginString(%q) = %q, %v, want %q, %v", tt.raw, got, ok, tt.want, tt.ok) + } + }) + } +} + +func mustParseURL(s string) *url.URL { + u, err := url.Parse(s) + if err != nil { + panic(err) + } + return u +} diff --git a/pkg/api/controller_test.go b/pkg/api/controller_test.go index 7e34fc11..a9e496e3 100644 --- a/pkg/api/controller_test.go +++ b/pkg/api/controller_test.go @@ -4208,8 +4208,21 @@ func TestOpenIDMiddleware(t *testing.T) { client.SetRedirectPolicy(test.CustomRedirectPolicy(20)) client.SetHeader(constants.SessionClientHeaderName, constants.SessionClientHeaderValue) - Convey("with callback_ui value provided", func() { - // first login user + Convey("with relative callback_ui value provided", func() { + // first login user (callback_ui must be relative path to prevent open redirect) + resp, err := client.R(). + SetQueryParam("provider", "oidc"). + SetQueryParam("callback_ui", "/v2/"). + Get(baseURL + constants.LoginPath) + So(err, ShouldBeNil) + So(resp, ShouldNotBeNil) + So(resp.StatusCode(), ShouldEqual, http.StatusOK) + }) + + Convey("with absolute callback_ui value provided and allowlisted", func() { + // allow absolute redirects only to allowlisted UI origins + conf.HTTP.Auth.OpenID.CallbackAllowOrigins = []string{baseURL} + resp, err := client.R(). SetQueryParam("provider", "oidc"). SetQueryParam("callback_ui", baseURL+"/v2/"). @@ -4219,6 +4232,37 @@ func TestOpenIDMiddleware(t *testing.T) { So(resp.StatusCode(), ShouldEqual, http.StatusOK) }) + Convey("with absolute callback_ui value provided and NOT allowlisted", func() { + // If an external redirect is attempted, resty would try to connect to this unreachable address. + evil := "http://127.0.0.1:1/phished" + + resp, err := client.R(). + SetQueryParam("provider", "oidc"). + SetQueryParam("callback_ui", evil). + Get(baseURL + constants.LoginPath) + So(err, ShouldBeNil) + So(resp, ShouldNotBeNil) + So(resp.StatusCode(), ShouldEqual, http.StatusOK) + So(resp.RawResponse, ShouldNotBeNil) + So(resp.RawResponse.Request, ShouldNotBeNil) + So(resp.RawResponse.Request.URL.String(), ShouldStartWith, baseURL) + }) + + Convey("with protocol-relative callback_ui value provided", func() { + evil := "//127.0.0.1:1/phished" + + resp, err := client.R(). + SetQueryParam("provider", "oidc"). + SetQueryParam("callback_ui", evil). + Get(baseURL + constants.LoginPath) + So(err, ShouldBeNil) + So(resp, ShouldNotBeNil) + So(resp.StatusCode(), ShouldEqual, http.StatusOK) + So(resp.RawResponse, ShouldNotBeNil) + So(resp.RawResponse.Request, ShouldNotBeNil) + So(resp.RawResponse.Request.URL.String(), ShouldStartWith, baseURL) + }) + // first login user resp, err := client.R(). SetQueryParam("provider", "oidc"). @@ -4617,8 +4661,21 @@ func TestOpenIDMiddlewareWithRedisSessionDriver(t *testing.T) { client.SetRedirectPolicy(test.CustomRedirectPolicy(20)) client.SetHeader(constants.SessionClientHeaderName, constants.SessionClientHeaderValue) - Convey("with callback_ui value provided", func() { - // first login user + Convey("with relative callback_ui value provided", func() { + // first login user (callback_ui must be relative path to prevent open redirect) + resp, err := client.R(). + SetQueryParam("provider", "oidc"). + SetQueryParam("callback_ui", "/v2/"). + Get(baseURL + constants.LoginPath) + So(err, ShouldBeNil) + So(resp, ShouldNotBeNil) + So(resp.StatusCode(), ShouldEqual, http.StatusOK) + }) + + Convey("with absolute callback_ui value provided and allowlisted", func() { + // allow absolute redirects only to allowlisted UI origins + conf.HTTP.Auth.OpenID.CallbackAllowOrigins = []string{baseURL} + resp, err := client.R(). SetQueryParam("provider", "oidc"). SetQueryParam("callback_ui", baseURL+"/v2/"). @@ -4628,6 +4685,37 @@ func TestOpenIDMiddlewareWithRedisSessionDriver(t *testing.T) { So(resp.StatusCode(), ShouldEqual, http.StatusOK) }) + Convey("with absolute callback_ui value provided and NOT allowlisted", func() { + // If an external redirect is attempted, resty would try to connect to this unreachable address. + evil := "http://127.0.0.1:1/phished" + + resp, err := client.R(). + SetQueryParam("provider", "oidc"). + SetQueryParam("callback_ui", evil). + Get(baseURL + constants.LoginPath) + So(err, ShouldBeNil) + So(resp, ShouldNotBeNil) + So(resp.StatusCode(), ShouldEqual, http.StatusOK) + So(resp.RawResponse, ShouldNotBeNil) + So(resp.RawResponse.Request, ShouldNotBeNil) + So(resp.RawResponse.Request.URL.String(), ShouldStartWith, baseURL) + }) + + Convey("with protocol-relative callback_ui value provided", func() { + evil := "//127.0.0.1:1/phished" + + resp, err := client.R(). + SetQueryParam("provider", "oidc"). + SetQueryParam("callback_ui", evil). + Get(baseURL + constants.LoginPath) + So(err, ShouldBeNil) + So(resp, ShouldNotBeNil) + So(resp.StatusCode(), ShouldEqual, http.StatusOK) + So(resp.RawResponse, ShouldNotBeNil) + So(resp.RawResponse.Request, ShouldNotBeNil) + So(resp.RawResponse.Request.URL.String(), ShouldStartWith, baseURL) + }) + // first login user resp, err := client.R(). SetQueryParam("provider", "oidc"). diff --git a/pkg/api/proxy.go b/pkg/api/proxy.go index 3ede34ac..cac55943 100644 --- a/pkg/api/proxy.go +++ b/pkg/api/proxy.go @@ -123,9 +123,9 @@ func proxyHTTPRequest(ctx context.Context, req *http.Request, // Get HTTP TLS config safely httpTLSConfig := ctrlr.Config.CopyTLSConfig() - proxyQueryScheme := "http" + proxyQueryScheme := constants.SchemeHTTP if httpTLSConfig != nil { - proxyQueryScheme = "https" + proxyQueryScheme = constants.SchemeHTTPS } cloneURL.Scheme = proxyQueryScheme