Files
zot/pkg/api/authn.go
T
uaggarwa 66e9cfb01f feat(metrics): anonymous access when enabled in accessControl config (#4110)
* feat: add anonymouspolicy support in metrics

Signed-off-by: uaggarwa <uaggarwa@akamai.com>

* test: add unit tests

Signed-off-by: uaggarwa <uaggarwa@akamai.com>

---------

Signed-off-by: uaggarwa <uaggarwa@akamai.com>
2026-06-10 10:19:28 +03:00

1534 lines
45 KiB
Go

package api
import (
"context"
"crypto"
"crypto/sha256"
"crypto/x509"
"encoding/base64"
"encoding/hex"
"encoding/json"
"encoding/pem"
"errors"
"fmt"
"net"
"net/http"
"net/url"
"os"
"regexp"
"slices"
"strconv"
"strings"
"time"
"github.com/go-jose/go-jose/v4"
guuid "github.com/gofrs/uuid"
"github.com/golang-jwt/jwt/v5"
"github.com/google/go-github/v62/github"
"github.com/google/uuid"
"github.com/gorilla/mux"
"github.com/gorilla/sessions"
godigest "github.com/opencontainers/go-digest"
"github.com/zitadel/oidc/v3/pkg/client/rp"
httphelper "github.com/zitadel/oidc/v3/pkg/http"
"github.com/zitadel/oidc/v3/pkg/oidc"
"golang.org/x/oauth2"
githubOAuth "golang.org/x/oauth2/github"
zerr "zotregistry.dev/zot/v2/errors"
"zotregistry.dev/zot/v2/pkg/api/config"
"zotregistry.dev/zot/v2/pkg/api/constants"
apiErr "zotregistry.dev/zot/v2/pkg/api/errors"
zcommon "zotregistry.dev/zot/v2/pkg/common"
extconf "zotregistry.dev/zot/v2/pkg/extensions/config"
"zotregistry.dev/zot/v2/pkg/log"
reqCtx "zotregistry.dev/zot/v2/pkg/requestcontext"
)
const (
issuedAtOffset = 5 * time.Second
relyingPartyCookieMaxAge = 120
)
type AuthnMiddleware struct {
htpasswd *HTPasswd
ldapClient *LDAPClient
log log.Logger
}
func AuthHandler(ctlr *Controller) mux.MiddlewareFunc {
authnMiddleware := &AuthnMiddleware{
htpasswd: ctlr.HTPasswd,
log: ctlr.Log,
}
authConfig := ctlr.Config.CopyAuthConfig()
if authConfig.IsBearerAuthEnabled() {
return bearerAuthHandler(ctlr)
}
return authnMiddleware.tryAuthnHandlers(ctlr)
}
func (amw *AuthnMiddleware) sessionAuthn(ctlr *Controller, userAc *reqCtx.UserAccessControl,
response http.ResponseWriter, request *http.Request,
) (bool, error) {
identity, ok := GetAuthUserFromRequestSession(ctlr.CookieStore, request, ctlr.Log)
if !ok {
// let the client know that this session is invalid/expired
cookie := &http.Cookie{ //nolint: gosec
Name: "session",
Value: "",
Path: "/",
Expires: time.Unix(0, 0),
HttpOnly: true,
}
http.SetCookie(response, cookie)
return false, nil
}
userAc.SetUsername(identity)
userAc.SaveOnRequest(request)
groups, err := ctlr.MetaDB.GetUserGroups(request.Context())
if err != nil {
ctlr.Log.Err(err).Str("identity", identity).Msg("failed to get user profile in DB")
if errors.Is(err, zerr.ErrUserDataNotFound) {
// we handle this case as an authentication failure, not an internal server error
err = nil
}
return false, err
}
userAc.AddGroups(groups)
userAc.SaveOnRequest(request)
return true, nil
}
func (amw *AuthnMiddleware) mTLSAuthn(ctlr *Controller, userAc *reqCtx.UserAccessControl,
request *http.Request,
) (bool, error) {
// Check if mTLS is configured and client certificates are present
if request.TLS == nil || len(request.TLS.PeerCertificates) == 0 {
return false, nil
}
// Check if client certificate has verified chain
verifiedChains := request.TLS.VerifiedChains
if len(verifiedChains) == 0 || len(verifiedChains[0]) == 0 {
ctlr.Log.Debug().Msg("mTLS authentication failed - user provided certificate not signed by CA")
return false, nil
}
// Extract identity from certificate
leafCert := request.TLS.PeerCertificates[0]
// Get mTLS config from auth config
authConfig := ctlr.Config.CopyAuthConfig()
mtlsConfig := authConfig.GetMTLSConfig()
identity, err := extractMTLSIdentity(leafCert, mtlsConfig)
if err != nil || identity == "" {
ctlr.Log.Debug().Err(err).Msg("mTLS authentication failed - could not extract identity")
return false, nil
}
// Process request with mTLS identity
var groups []string
accessControl := ctlr.Config.CopyAccessControlConfig()
if accessControl != nil {
ac := NewAccessController(ctlr.Config)
groups = ac.getUserGroups(identity)
}
userAc.SetUsername(identity)
userAc.AddGroups(groups)
userAc.SaveOnRequest(request)
// Update user groups in MetaDB if available
if ctlr.MetaDB != nil {
if err := ctlr.MetaDB.SetUserGroups(request.Context(), groups); err != nil {
ctlr.Log.Error().Err(err).Str("identity", identity).Msg("failed to update user profile")
return false, err
}
}
ctlr.Log.Debug().Str("identity", identity).Msg("mTLS authentication successful")
return true, nil
}
func (amw *AuthnMiddleware) basicAuthn(ctlr *Controller, userAc *reqCtx.UserAccessControl,
response http.ResponseWriter, request *http.Request,
) (bool, error) {
cookieStore := ctlr.CookieStore
// Get auth config once to avoid multiple calls
authConfig := ctlr.Config.CopyAuthConfig()
if authConfig == nil {
return false, nil
}
identity, passphrase, err := getUsernamePasswordBasicAuth(request)
if err != nil {
ctlr.Log.Error().Err(err).Msg("failed to parse authorization header")
return false, nil
}
// first, HTTPPassword authN (which is local)
htOk, _ := amw.htpasswd.Authenticate(identity, passphrase)
if htOk {
// Process request
var groups []string
accessControl := ctlr.Config.CopyAccessControlConfig()
if accessControl != nil {
ac := NewAccessController(ctlr.Config)
groups = ac.getUserGroups(identity)
}
userAc.SetUsername(identity)
userAc.AddGroups(groups)
userAc.SaveOnRequest(request)
// saved logged session only if the request comes from web (has UI session header value)
if hasSessionHeader(request) {
secure := ctlr.Config.UseSecureSession()
if err := saveUserLoggedSession(cookieStore, response, request, identity, "", secure, ctlr.Log); err != nil {
return false, err
}
}
// we have already populated the request context with userAc
if err := ctlr.MetaDB.SetUserGroups(request.Context(), groups); err != nil {
ctlr.Log.Error().Err(err).Str("identity", identity).Msg("failed to update user profile")
return false, err
}
ctlr.Log.Info().Str("identity", identity).Msgf("user profile successfully set")
return true, nil
}
// next, LDAP if configured (network-based which can lose connectivity)
if authConfig.IsLdapAuthEnabled() {
ok, _, ldapgroups, err := amw.ldapClient.Authenticate(identity, passphrase)
if ok && err == nil {
// Process request
var groups []string
accessControl := ctlr.Config.CopyAccessControlConfig()
if accessControl != nil {
ac := NewAccessController(ctlr.Config)
groups = ac.getUserGroups(identity)
}
groups = append(groups, ldapgroups...)
userAc.SetUsername(identity)
userAc.AddGroups(groups)
userAc.SaveOnRequest(request)
// saved logged session only if the request comes from web (has UI session header value)
if hasSessionHeader(request) {
secure := ctlr.Config.UseSecureSession()
if err := saveUserLoggedSession(cookieStore, response, request, identity, "", secure, ctlr.Log); err != nil {
return false, err
}
}
// we have already populated the request context with userAc
if err := ctlr.MetaDB.SetUserGroups(request.Context(), groups); err != nil {
ctlr.Log.Error().Err(err).Str("identity", identity).Msg("failed to update user profile")
return false, err
}
return true, nil
}
}
// last try API keys
if authConfig.IsAPIKeyEnabled() {
apiKey := passphrase
if !strings.HasPrefix(apiKey, constants.APIKeysPrefix) {
ctlr.Log.Error().Msg("invalid api token format")
return false, nil
}
trimmedAPIKey := strings.TrimPrefix(apiKey, constants.APIKeysPrefix)
hashedKey := hashUUID(trimmedAPIKey)
storedIdentity, err := ctlr.MetaDB.GetUserAPIKeyInfo(hashedKey)
if err != nil {
if errors.Is(err, zerr.ErrUserAPIKeyNotFound) {
ctlr.Log.Info().Err(err).Msgf("failed to find any user info for hashed key %s in DB", hashedKey)
return false, nil
}
ctlr.Log.Error().Err(err).Msgf("failed to get user info for hashed key %s in DB", hashedKey)
return false, err
}
if storedIdentity == identity {
userAc.SetUsername(identity)
userAc.SaveOnRequest(request)
// check if api key expired
isExpired, err := ctlr.MetaDB.IsAPIKeyExpired(request.Context(), hashedKey)
if err != nil {
ctlr.Log.Err(err).Str("identity", identity).Msg("failed to verify if api key expired")
return false, err
}
if isExpired {
return false, nil
}
err = ctlr.MetaDB.UpdateUserAPIKeyLastUsed(request.Context(), hashedKey)
if err != nil {
ctlr.Log.Err(err).Str("identity", identity).Msg("failed to update user profile in DB")
return false, err
}
groups, err := ctlr.MetaDB.GetUserGroups(request.Context())
if err != nil {
ctlr.Log.Err(err).Str("identity", identity).Msg("failed to get user's groups in DB")
return false, err
}
userAc.AddGroups(groups)
userAc.SaveOnRequest(request)
return true, nil
}
}
return false, nil
}
func (amw *AuthnMiddleware) tryAuthnHandlers(ctlr *Controller) mux.MiddlewareFunc { //nolint: gocyclo
// Get auth config once to avoid multiple calls
authConfig := ctlr.Config.CopyAuthConfig()
// ldap and htpasswd based authN
if authConfig.IsLdapAuthEnabled() {
ldapConfig := authConfig.LDAP
ctlr.LDAPClient = &LDAPClient{
Host: ldapConfig.Address,
Port: ldapConfig.Port,
UseSSL: !ldapConfig.Insecure,
SkipTLS: !ldapConfig.StartTLS,
Base: ldapConfig.BaseDN,
BindDN: ldapConfig.BindDN(),
BindPassword: ldapConfig.BindPassword(),
UserGroupAttribute: ldapConfig.UserGroupAttribute, // from config
UserAttribute: ldapConfig.UserAttribute,
UserFilter: ldapConfig.UserFilter,
InsecureSkipVerify: ldapConfig.SkipVerify,
ServerName: ldapConfig.Address,
Log: ctlr.Log,
SubtreeSearch: ldapConfig.SubtreeSearch,
}
amw.ldapClient = ctlr.LDAPClient
if authConfig.LDAP.CACert != "" {
caCert, err := os.ReadFile(authConfig.LDAP.CACert)
if err != nil {
amw.log.Panic().Err(err).Str("caCert", authConfig.LDAP.CACert).
Msg("failed to read caCert")
}
caCertPool := x509.NewCertPool()
if !caCertPool.AppendCertsFromPEM(caCert) {
amw.log.Panic().Err(zerr.ErrBadCACert).Str("caCert", authConfig.LDAP.CACert).
Msg("failed to read caCert")
}
amw.ldapClient.ClientCAs = caCertPool
} else {
// default to system cert pool
caCertPool, err := x509.SystemCertPool()
if err != nil {
amw.log.Panic().Err(zerr.ErrBadCACert).Str("caCert", authConfig.LDAP.CACert).
Msg("failed to get system cert pool")
}
amw.ldapClient.ClientCAs = caCertPool
}
}
if authConfig.IsHtpasswdAuthEnabled() {
err := amw.htpasswd.Reload(authConfig.HTPasswd.Path)
if err != nil {
amw.log.Panic().Err(err).Str("credsFile", authConfig.HTPasswd.Path).
Msg("failed to open creds-file")
}
}
// openid based authN
if authConfig.IsOpenIDAuthEnabled() {
ctlr.RelyingParties = make(map[string]rp.RelyingParty)
for provider := range authConfig.OpenID.Providers {
if config.IsOpenIDSupported(provider) {
rp := NewRelyingPartyOIDC(context.TODO(), ctlr.Config, provider, authConfig.SessionHashKey,
authConfig.SessionEncryptKey, ctlr.Log)
ctlr.RelyingParties[provider] = rp
} else if config.IsOauth2Supported(provider) {
rp := NewRelyingPartyGithub(ctlr.Config, provider, authConfig.SessionHashKey,
authConfig.SessionEncryptKey, ctlr.Log)
ctlr.RelyingParties[provider] = rp
}
}
}
return func(next http.Handler) http.Handler {
return http.HandlerFunc(func(response http.ResponseWriter, request *http.Request) {
if request.Method == http.MethodOptions {
next.ServeHTTP(response, request)
response.WriteHeader(http.StatusNoContent)
return
}
isMgmtRequested := request.RequestURI == constants.FullMgmt
isV2Requested := strings.TrimSuffix(request.URL.Path, "/") == constants.RoutePrefix
// Match Docker daemon-proxied requests regardless of the upstream client tool.
// The Docker daemon always prefixes its UA with "docker/<version>" when proxying,
// while the upstream tool (docker CLI, compose, buildx, etc.) appears inside
// "UpstreamClient(...)". Direct Docker CLI requests use "Docker-Client/...".
ua := request.Header.Get("User-Agent")
isDockerClient := strings.Contains(ua, "Docker-Client") || strings.HasPrefix(ua, "docker/")
// Get auth config safely
authConfig := ctlr.Config.CopyAuthConfig()
delay := authConfig.GetFailDelay()
realm := ctlr.Config.GetRealm()
// Get access control config safely
accessControlConfig := ctlr.Config.CopyAccessControlConfig()
allowAnonymous := accessControlConfig != nil && accessControlConfig.AnonymousPolicyExists()
// Allow anonymous access to the metrics endpoint only if configured.
extensionsConfig := ctlr.Config.CopyExtensionsConfig()
isMetricsRequestedWithAnonymousAccess := isAnonymousMetricsRequest(request, accessControlConfig, extensionsConfig)
// build user access control info
userAc := reqCtx.NewUserAccessControl()
// if it will not be populated by authn handlers, this represents an anonymous user
userAc.SaveOnRequest(request)
authenticated := false
var err error
// Switch authentication methods based on provided request context
switch {
// Reject requests with multiple Authorization headers as a security measure
case hasMultipleAuthorizationHeaders(request):
authenticated = false
// The authorization header presence is an explicit attempt to use basic authentication
case !isAuthorizationHeaderEmpty(request) && authConfig.IsBasicAuthnEnabled():
authenticated, err = amw.basicAuthn(ctlr, userAc, response, request)
// The session header is an explicit attempt to use session authentication
case hasSessionHeader(request):
authenticated, err = amw.sessionAuthn(ctlr, userAc, response, request)
if err != nil {
break
}
// If session authentication fails, but anonymous or management access is allowed,
// treat the request as authenticated. This fallback is necessary because the session
// header may be present for anonymous or management requests.
authenticated = authenticated || allowAnonymous || isMgmtRequested || isMetricsRequestedWithAnonymousAccess
// Try mTLS authentication if client certificates are present
case ctlr.Config.IsMTLSAuthEnabled() && request.TLS != nil && len(request.TLS.PeerCertificates) > 0:
authenticated, err = amw.mTLSAuthn(ctlr, userAc, request)
// If no auth methods enabled at all - then just authenticate anything
case !authConfig.IsBasicAuthnEnabled() && !ctlr.Config.IsMTLSAuthEnabled():
authenticated = true
// If no credentials provided - check for anonymous / mgmt / metrics requests
case allowAnonymous || isMgmtRequested || isMetricsRequestedWithAnonymousAccess:
// Docker workaround: force 401 on /v2/ when anonymous policies coexist with
// authenticated-only policies. Otherwise Docker treats 200 on /v2/ as "no auth"
// and will not send stored credentials for protected repositories.
// See: https://github.com/opencontainers/wg-auth/blob/main/docs/implementations/moby.md
hasMixedPolicy := accessControlConfig.HasMixedAnonymousAndAuthenticatedPolicies()
if isDockerClient && isV2Requested && hasMixedPolicy && authConfig.CanAuthenticateWithBasicCredentials() {
authenticated = false
} else {
authenticated = true
}
}
// If error occurred during authn process - return 500 error
if err != nil {
response.WriteHeader(http.StatusInternalServerError)
return
}
if authenticated {
next.ServeHTTP(response, request)
} else {
authFail(response, request, realm, delay)
}
})
}
}
func bearerAuthHandler(ctlr *Controller) mux.MiddlewareFunc {
// Get auth config safely
authConfig := ctlr.Config.CopyAuthConfig()
var traditionalAuthorizerKeyFunc BearerAuthorizerKeyFunc
// Traditional bearer auth with public key/certificate
if authConfig.Bearer.Cert != "" {
// although the configuration option is called 'cert', this function will also parse a public key directly
// see https://github.com/project-zot/zot/issues/3173 for info
publicKey, err := loadPublicKeyFromFile(authConfig.Bearer.Cert)
if err != nil {
ctlr.Log.Panic().Err(err).Msg("failed to load public key for bearer authentication")
}
traditionalAuthorizerKeyFunc = func(_ context.Context, token *jwt.Token) (any, error) {
return publicKey, nil
}
}
// Traditional bearer auth with AWS Secrets Manager
if authConfig.Bearer.AWSSecretsManager != nil {
asmAuthz, err := NewAWSSecretsManager(
authConfig.Bearer.AWSSecretsManager, AWSSecretsManagerProviderImplementation{}, ctlr.Log)
if err != nil {
ctlr.Log.Panic().Err(err).Msg("failed to create AWS Secrets Manager key function for bearer authentication")
}
traditionalAuthorizerKeyFunc = asmAuthz.GetPublicKey
}
// Initialize authorizers based on configuration
var traditionalAuthorizer *BearerAuthorizer
if traditionalAuthorizerKeyFunc != nil {
traditionalAuthorizer = NewBearerAuthorizer(
authConfig.Bearer.Realm,
authConfig.Bearer.Service,
traditionalAuthorizerKeyFunc,
)
}
// OIDC bearer auth for workload identity
var oidcAuthorizer *OIDCBearerAuthorizer
if len(authConfig.Bearer.OIDC) > 0 {
var err error
oidcAuthorizer, err = NewOIDCBearerAuthorizer(authConfig.Bearer.OIDC, ctlr.Log)
if err != nil {
ctlr.Log.Panic().Err(err).Msg("failed to initialize OIDC bearer authorizer")
}
}
return func(next http.Handler) http.Handler {
return http.HandlerFunc(func(response http.ResponseWriter, request *http.Request) {
if request.Method == http.MethodOptions {
next.ServeHTTP(response, request)
response.WriteHeader(http.StatusNoContent)
return
}
// Reject requests with multiple Authorization headers as a security measure
if hasMultipleAuthorizationHeaders(request) {
ctlr.Log.Error().Msg("failed to parse Authorization header: multiple Authorization headers detected")
response.Header().Set("Content-Type", "application/json")
zcommon.WriteJSON(response, http.StatusUnauthorized, apiErr.NewError(apiErr.UNSUPPORTED))
return
}
acCtrlr := NewAccessController(ctlr.Config)
// we want to bypass auth for mgmt route
isMgmtRequested := request.RequestURI == constants.FullMgmt
header := request.Header.Get("Authorization")
if isAuthorizationHeaderEmpty(request) && isMgmtRequested {
next.ServeHTTP(response, request)
return
}
// Allow anonymous access to the metrics endpoint only if configured.
accessControlConfig := ctlr.Config.CopyAccessControlConfig()
extensionsConfig := ctlr.Config.CopyExtensionsConfig()
if isAnonymousMetricsRequest(request, accessControlConfig, extensionsConfig) {
next.ServeHTTP(response, request)
return
}
var requestedAccess *ResourceAction
if request.RequestURI != "/v2/" {
// if this is not the base route, the requested repository/action must be authorized
vars := mux.Vars(request)
name := vars["name"]
action := "pull"
if m := request.Method; m != http.MethodGet && m != http.MethodHead {
action = "push"
}
requestedAccess = &ResourceAction{
Type: "repository",
Name: name,
Action: action,
}
}
// Try OIDC authentication first if configured
if oidcAuthorizer != nil {
res, err := oidcAuthorizer.Authenticate(request.Context(), header)
if err == nil && res != nil && res.Username != "" {
// OIDC authentication succeeded
identity := res.Username
groups := res.Groups
ctlr.Log.Debug().Str("identity", identity).Msg("the OIDC bearer authentication was successful")
// Set user context for authorization
userAc := reqCtx.NewUserAccessControl()
userAc.SetUsername(identity)
userAc.AddGroups(groups)
userAc.SetClaims(res.Claims)
userAc.SaveOnRequest(request)
// Update user groups in MetaDB if available
if ctlr.MetaDB != nil {
if err := ctlr.MetaDB.SetUserGroups(request.Context(), groups); err != nil {
ctlr.Log.Error().Err(err).Str("identity", identity).Msg("failed to update user profile")
response.WriteHeader(http.StatusInternalServerError)
return
}
}
// Use BEARER_OIDC to enable authorization via accessControl config.
// Unlike traditional bearer tokens (which contain 'access' claims with permissions),
// OIDC tokens contain identity only, so authorization must come from the config.
amCtx := acCtrlr.getAuthnMiddlewareContext(BEARER_OIDC, request)
next.ServeHTTP(response, request.WithContext(amCtx)) //nolint:contextcheck
return
}
}
// Fall back to traditional bearer token auth if OIDC didn't succeed
if traditionalAuthorizer != nil {
err := traditionalAuthorizer.Authorize(request.Context(), header, requestedAccess)
if err != nil {
var challenge *AuthChallengeError
if errors.As(err, &challenge) {
ctlr.Log.Debug().Err(challenge).Msg("bearer token authorization failed")
response.Header().Set("Content-Type", "application/json")
response.Header().Set("WWW-Authenticate", challenge.Header())
zcommon.WriteJSON(response, http.StatusUnauthorized, apiErr.NewError(apiErr.UNAUTHORIZED))
return
}
ctlr.Log.Error().Err(err).Msg("failed to parse Authorization header")
response.Header().Set("Content-Type", "application/json")
zcommon.WriteJSON(response, http.StatusUnauthorized, apiErr.NewError(apiErr.UNSUPPORTED))
return
}
amCtx := acCtrlr.getAuthnMiddlewareContext(BEARER, request)
next.ServeHTTP(response, request.WithContext(amCtx)) //nolint:contextcheck
return
}
// No authentication succeeded
if isAuthorizationHeaderEmpty(request) {
// No bearer token provided and no authentication method configured
ctlr.Log.Debug().Msg("no bearer token provided")
} else {
// Bearer token provided but authentication failed
ctlr.Log.Error().Msg("failed to authenticate with bearer token")
}
response.Header().Set("Content-Type", "application/json")
zcommon.WriteJSON(response, http.StatusUnauthorized, apiErr.NewError(apiErr.UNAUTHORIZED))
})
}
}
func canonicalOrigin(parsedURL *url.URL) (string, bool) {
if parsedURL == nil {
return "", false
}
scheme := strings.ToLower(parsedURL.Scheme)
if scheme != constants.SchemeHTTP && scheme != constants.SchemeHTTPS {
return "", false
}
host := strings.ToLower(parsedURL.Hostname())
if host == "" {
return "", false
}
port := parsedURL.Port()
if port == "" {
if scheme == constants.SchemeHTTP {
port = "80"
} else {
port = "443"
}
}
return scheme + "://" + net.JoinHostPort(host, port), true
}
func canonicalOriginString(raw string) (string, bool) {
raw = strings.TrimSpace(raw)
if raw == "" {
return "", false
}
parsed, err := url.Parse(raw)
if err != nil {
return "", false
}
// Only accept absolute http(s) URLs for allowlist entries.
if parsed.Scheme == "" || parsed.Host == "" {
return "", false
}
return canonicalOrigin(parsed)
}
// ValidateCallbackUI validates the callback_ui parameter used for post-login redirects.
// - Relative paths (starting with "/") are always allowed.
// - Absolute http(s) URLs are allowed only when their origin matches allowOrigins.
// It returns the validated redirect target, or "/" as fallback, or "" if the input is empty.
func ValidateCallbackUI(callbackUI string, allowOrigins []string) string {
if callbackUI == "" {
return ""
}
// Prevent header injection.
if strings.ContainsAny(callbackUI, "\r\n") {
return "/"
}
parsed, err := url.Parse(callbackUI)
if err != nil {
return "/"
}
// Reject protocol-relative URLs (e.g. //evil.com/path)
if strings.HasPrefix(callbackUI, "//") {
return "/"
}
// Relative path to root (safe default).
if parsed.Scheme == "" && parsed.Host == "" {
if !strings.HasPrefix(callbackUI, "/") {
return "/"
}
return callbackUI
}
// Absolute URL: only allow http(s) and only when origin is allowlisted.
if parsed.Scheme != constants.SchemeHTTP && parsed.Scheme != constants.SchemeHTTPS {
return "/"
}
if parsed.Host == "" {
return "/"
}
origin, ok := canonicalOrigin(parsed)
if !ok {
return "/"
}
for _, rawAllowed := range allowOrigins {
allowedOrigin, ok := canonicalOriginString(rawAllowed)
if !ok {
continue
}
if allowedOrigin == origin {
return callbackUI
}
}
return "/"
}
func (rh *RouteHandler) AuthURLHandler() http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) {
query := r.URL.Query()
allowOrigins := []string{}
if authCfg := rh.c.Config.CopyAuthConfig(); authCfg != nil {
if authCfg.OpenID != nil {
allowOrigins = append(allowOrigins, authCfg.OpenID.CallbackAllowOrigins...)
}
}
// If an ExternalURL is configured, allow redirects back to that origin.
if rh.c.Config.HTTP.ExternalURL != "" {
allowOrigins = append(allowOrigins, rh.c.Config.HTTP.ExternalURL)
}
callbackUI := ValidateCallbackUI(query.Get(constants.CallbackUIQueryParam), allowOrigins)
provider := query.Get("provider")
client, ok := rh.c.RelyingParties[provider]
if !ok {
rh.c.Log.Error().Msg("failed to authenticate due to unrecognized openid provider")
w.WriteHeader(http.StatusBadRequest)
return
}
/* save cookie containing state to later verify it and
callback ui where we will redirect after openid/oauth2 logic is completed*/
session, _ := rh.c.CookieStore.Get(r, "statecookie")
session.Options.Secure = rh.c.Config.UseSecureSession()
session.Options.HttpOnly = true
session.Options.SameSite = http.SameSiteDefaultMode
session.Options.Path = constants.CallbackBasePath
state := uuid.New().String()
session.Values["state"] = state
session.Values["callback"] = callbackUI
// let the session set its own id
err := session.Save(r, w)
if err != nil {
rh.c.Log.Error().Err(err).Msg("failed to save http session")
w.WriteHeader(http.StatusInternalServerError)
return
}
stateFunc := func() string {
return state
}
rp.AuthURLHandler(stateFunc, client)(w, r)
}
}
func NewRelyingPartyOIDC(ctx context.Context, config *config.Config, provider string,
hashKey, encryptKey []byte, log log.Logger,
) rp.RelyingParty {
issuer, clientID, clientSecret, redirectURI, scopes, options := getRelyingPartyArgs(config,
provider, hashKey, encryptKey, log)
relyingParty, err := rp.NewRelyingPartyOIDC(ctx, issuer, clientID, clientSecret, redirectURI, scopes, options...)
if err != nil {
log.Panic().Err(err).Str("issuer", issuer).Str("redirectURI", redirectURI).Strs("scopes", scopes).
Msg("failed to initialize new relying party oidc")
}
return relyingParty
}
func NewRelyingPartyGithub(config *config.Config, provider string, hashKey, encryptKey []byte, log log.Logger,
) rp.RelyingParty {
_, clientID, clientSecret, redirectURI, scopes,
options := getRelyingPartyArgs(config, provider, hashKey, encryptKey, log)
var endpoint oauth2.Endpoint
// Use custom endpoints if provided, otherwise fallback to GitHub's endpoints
if provider := config.HTTP.Auth.OpenID.Providers[provider]; provider.AuthURL != "" && provider.TokenURL != "" {
endpoint = oauth2.Endpoint{
AuthURL: provider.AuthURL,
TokenURL: provider.TokenURL,
}
} else {
endpoint = githubOAuth.Endpoint
}
rpConfig := &oauth2.Config{
ClientID: clientID,
ClientSecret: clientSecret,
RedirectURL: redirectURI,
Scopes: scopes,
Endpoint: endpoint,
}
relyingParty, err := rp.NewRelyingPartyOAuth(rpConfig, options...)
if err != nil {
log.Panic().Err(err).Str("redirectURI", redirectURI).Strs("scopes", scopes).
Msg("failed to initialize new relying party oauth")
}
return relyingParty
}
// originFromConfig returns the server's base URL (scheme + host[:port], no trailing
// slash) used as the origin for OIDC redirect URIs. It prefers ExternalURL; otherwise
// it derives the origin from cfg.HTTP.Address, cfg.HTTP.Port, and cfg.HTTP.TLS. Using
// a single source for both the login redirect_uri and the logout
// post_logout_redirect_uri ensures the IdP sees matching origins.
func originFromConfig(cfg *config.Config) string {
if trimmed := strings.TrimSuffix(cfg.HTTP.ExternalURL, "/"); trimmed != "" {
return trimmed
}
scheme := constants.SchemeHTTP
if cfg.HTTP.TLS != nil {
scheme = constants.SchemeHTTPS
}
return scheme + "://" + net.JoinHostPort(cfg.HTTP.Address, cfg.HTTP.Port)
}
func getRelyingPartyArgs(cfg *config.Config, provider string, hashKey, encryptKey []byte, log log.Logger) (
string, string, string, string, []string, []rp.Option,
) {
if _, ok := cfg.HTTP.Auth.OpenID.Providers[provider]; !ok {
log.Panic().Err(zerr.ErrOpenIDProviderDoesNotExist).Str("provider", provider).Msg("")
}
providerConfig := cfg.HTTP.Auth.OpenID.Providers[provider]
clientID := providerConfig.ClientID
clientSecret := providerConfig.ClientSecret
scopes := providerConfig.Scopes
// openid scope must be the first one in list
if !slices.Contains(scopes, oidc.ScopeOpenID) && config.IsOpenIDSupported(provider) {
scopes = append([]string{oidc.ScopeOpenID}, scopes...)
}
issuer := providerConfig.Issuer
keyPath := providerConfig.KeyPath
callback := constants.CallbackBasePath + "/" + provider
redirectURI := originFromConfig(cfg) + callback
options := []rp.Option{
rp.WithVerifierOpts(rp.WithIssuedAtOffset(issuedAtOffset)),
}
cookieHandler := httphelper.NewCookieHandler(hashKey, encryptKey, httphelper.WithMaxAge(relyingPartyCookieMaxAge))
options = append(options, rp.WithCookieHandler(cookieHandler))
if clientSecret == "" {
options = append(options, rp.WithPKCE(cookieHandler))
}
if keyPath != "" {
options = append(options, rp.WithJWTProfile(rp.SignerFromKeyPath(keyPath)))
}
return issuer, clientID, clientSecret, redirectURI, scopes, options
}
func authFail(w http.ResponseWriter, r *http.Request, realm string, delay int) {
if !isAuthorizationHeaderEmpty(r) || hasSessionHeader(r) {
time.Sleep(time.Duration(delay) * time.Second)
}
// don't send auth headers if request is coming from UI
if r.Header.Get(constants.SessionClientHeaderName) != constants.SessionClientHeaderValue {
if realm == "" {
realm = "Authorization Required"
}
realm = "Basic realm=" + strconv.Quote(realm)
w.Header().Set("WWW-Authenticate", realm)
}
w.Header().Set("Content-Type", "application/json")
zcommon.WriteJSON(w, http.StatusUnauthorized, apiErr.NewError(apiErr.UNAUTHORIZED))
}
func isAnonymousMetricsRequest(request *http.Request, accessControlConfig *config.AccessControlConfig,
extensionsConfig *extconf.ExtensionConfig,
) bool {
prometheusConfig := extensionsConfig.GetMetricsPrometheusConfig()
return isAuthorizationHeaderEmpty(request) &&
slices.Contains(accessControlConfig.GetMetrics().AnonymousPolicy, constants.ReadPermission) &&
extensionsConfig.IsMetricsEnabled() &&
prometheusConfig != nil &&
strings.HasPrefix(request.URL.Path, prometheusConfig.Path)
}
func isAuthorizationHeaderEmpty(request *http.Request) bool {
header := request.Header.Get("Authorization")
if header == "" || (strings.ToLower(header) == "basic og==") {
return true
}
return false
}
// hasMultipleAuthorizationHeaders checks if the request has multiple Authorization headers.
// This is a security concern as it could be used to bypass authentication or cause confusion.
func hasMultipleAuthorizationHeaders(request *http.Request) bool {
authHeaders := request.Header.Values("Authorization")
return len(authHeaders) > 1
}
func hasSessionHeader(request *http.Request) bool {
clientHeader := request.Header.Get(constants.SessionClientHeaderName)
return clientHeader == constants.SessionClientHeaderValue
}
func getUsernamePasswordBasicAuth(request *http.Request) (string, string, error) {
basicAuth := request.Header.Get("Authorization")
if basicAuth == "" {
return "", "", zerr.ErrParsingAuthHeader
}
splitStr := strings.SplitN(basicAuth, " ", 2) //nolint:mnd
if len(splitStr) != 2 || strings.ToLower(splitStr[0]) != "basic" {
return "", "", zerr.ErrParsingAuthHeader
}
decodedStr, err := base64.StdEncoding.DecodeString(splitStr[1])
if err != nil {
return "", "", err
}
pair := strings.SplitN(string(decodedStr), ":", 2) //nolint:mnd
if len(pair) != 2 { //nolint:mnd
return "", "", zerr.ErrParsingAuthHeader
}
identity := pair[0]
passphrase := pair[1]
return identity, passphrase, nil
}
func GetGithubUserInfo(ctx context.Context, client *github.Client, log log.Logger) (string, []string, error) {
var primaryEmail string
userEmails, _, err := client.Users.ListEmails(ctx, nil)
if err != nil {
log.Error().Msg("failed to set user record for empty email value")
return "", []string{}, err
}
if len(userEmails) != 0 {
for _, email := range userEmails { // should have at least one primary email, if any
if email.GetPrimary() { // check if it's primary email
primaryEmail = email.GetEmail()
break
}
}
}
orgs, _, err := client.Organizations.List(ctx, "", nil)
if err != nil {
log.Error().Msg("failed to set user record for empty email value")
return "", []string{}, err
}
groups := []string{}
for _, org := range orgs {
groups = append(groups, *org.Login)
}
return primaryEmail, groups, nil
}
func saveUserLoggedSession(cookieStore sessions.Store, response http.ResponseWriter,
request *http.Request, identity, provider string, secure bool, log log.Logger,
) error {
session, _ := cookieStore.Get(request, "session")
session.Options.Secure = secure
session.Options.HttpOnly = true
session.Options.SameSite = http.SameSiteDefaultMode
session.Values["authStatus"] = true
session.Values["user"] = identity
if provider != "" {
session.Values["provider"] = provider
} else {
delete(session.Values, "provider")
}
// let the session set its own id
err := session.Save(request, response)
if err != nil {
log.Error().Err(err).Str("identity", identity).Msg("failed to save http session")
return err
}
userInfoCookie := sessions.NewCookie("user", identity, &sessions.Options{
Secure: secure,
HttpOnly: false,
MaxAge: cookiesMaxAge,
SameSite: http.SameSiteDefaultMode,
Path: "/",
})
http.SetCookie(response, userInfoCookie)
return nil
}
const (
defaultUsernameClaim = "email"
defaultGroupsClaim = "groups"
)
// getOpenIDClaimMapping resolves which OIDC claims supply identity (see ClaimMapping.Username)
// and groups for a given provider.
// The third return value reports whether the identity username claim was explicitly configured
// (false means the default "email" claim is being used).
func getOpenIDClaimMapping(authConfig *config.AuthConfig, providerName string) (string, string, bool) {
identityClaim := defaultUsernameClaim
groupsClaim := defaultGroupsClaim
identityConfigured := false
if authConfig == nil || authConfig.OpenID == nil || providerName == "" {
return identityClaim, groupsClaim, identityConfigured
}
providerConfig, ok := authConfig.OpenID.Providers[providerName]
if !ok || providerConfig.ClaimMapping == nil {
return identityClaim, groupsClaim, identityConfigured
}
if providerConfig.ClaimMapping.Username != "" {
identityClaim = providerConfig.ClaimMapping.Username
identityConfigured = true
}
if providerConfig.ClaimMapping.Groups != "" {
groupsClaim = providerConfig.ClaimMapping.Groups
}
return identityClaim, groupsClaim, identityConfigured
}
func getOpenIDIdentity(info *oidc.UserInfo, claimName string) string {
if info == nil {
return ""
}
switch claimName {
case "preferred_username":
return info.PreferredUsername
case defaultUsernameClaim:
return info.UserInfoEmail.Email
case "sub":
return info.Subject
case "name":
return info.Name
default:
if val, ok := info.Claims[claimName].(string); ok {
return val
}
}
return ""
}
func appendOpenIDGroups(groups []string, claims map[string]any, claimName string) ([]string, bool) {
switch val := claims[claimName].(type) {
case []any:
for _, group := range val {
if group == nil {
continue
}
if str := fmt.Sprint(group); str != "" {
groups = append(groups, str)
}
}
return groups, true
case []string:
for _, group := range val {
if group != "" {
groups = append(groups, group)
}
}
return groups, true
case string:
if val != "" {
groups = append(groups, val)
}
return groups, true
}
return groups, false
}
// extractOpenIDIdentity resolves identity and groups for an OIDC callback
// based on the provider's configured claim mapping. It returns the resolved
// identity string, the deduplicated/sorted groups, and a boolean reporting whether
// identity could be resolved at all (false means callers should reject).
func extractOpenIDIdentity(logger log.Logger, authConfig *config.AuthConfig, providerName string,
info *oidc.UserInfo, idTokenClaims map[string]any,
) (string, []string, bool) {
identityClaim, groupsClaim, identityConfigured := getOpenIDClaimMapping(authConfig, providerName)
identity := getOpenIDIdentity(info, identityClaim)
fellBackToDefaultClaim := false
if identity == "" && identityConfigured && identityClaim != defaultUsernameClaim {
fellBackToDefaultClaim = true
configuredClaim := identityClaim
identityClaim = defaultUsernameClaim
identity = getOpenIDIdentity(info, identityClaim)
logger.Warn().
Str("provider", providerName).
Str("claim", configuredClaim).
Msgf("configured username claim missing or empty, falling back to %q claim", defaultUsernameClaim)
}
if identity == "" {
return "", nil, false
}
logger.Debug().
Str("provider", providerName).
Str("claim", identityClaim).
Str("identity", identity).
Bool("fellBackToDefaultClaim", fellBackToDefaultClaim).
Msg("extracted identity")
var groups []string
if info != nil {
groups, _ = appendOpenIDGroups(groups, info.Claims, groupsClaim)
}
if idTokenClaims != nil {
groups, _ = appendOpenIDGroups(groups, idTokenClaims, groupsClaim)
}
slices.Sort(groups)
groups = slices.Compact(groups)
if len(groups) == 0 {
logger.Debug().
Str("provider", providerName).
Str("groupsClaim", groupsClaim).
Str("identity", identity).
Msg("no groups claim values found in UserInfo or ID token claims")
}
return identity, groups, true
}
// OAuth2Callback is the callback logic where openid/oauth2 will redirect back to our app.
func OAuth2Callback(ctlr *Controller, w http.ResponseWriter, r *http.Request, state, identity, provider string,
groups []string,
) (string, error) {
stateCookie, _ := ctlr.CookieStore.Get(r, "statecookie")
stateOrigin, ok := stateCookie.Values["state"].(string)
if !ok {
ctlr.Log.Error().Err(zerr.ErrInvalidStateCookie).Str("component", "openID").
Msg("failed to get 'state' cookie from request")
return "", zerr.ErrInvalidStateCookie
}
if stateOrigin != state {
ctlr.Log.Error().Err(zerr.ErrInvalidStateCookie).Str("component", "openID").
Msg("'state' cookie differs from the actual one")
return "", zerr.ErrInvalidStateCookie
}
userAc := reqCtx.NewUserAccessControl()
userAc.SetUsername(identity)
userAc.AddGroups(groups)
userAc.SaveOnRequest(r)
// if this line has been reached, then a new session should be created
// if the `session` key is already on the cookie, it's not a valid one
secure := ctlr.Config.UseSecureSession()
if err := saveUserLoggedSession(ctlr.CookieStore, w, r, identity, provider, secure, ctlr.Log); err != nil {
return "", err
}
if err := ctlr.MetaDB.SetUserGroups(r.Context(), groups); err != nil {
ctlr.Log.Error().Err(err).Str("identity", identity).Msg("failed to update the user profile")
return "", err
}
ctlr.Log.Info().Str("identity", identity).Msg("user profile set successfully")
// redirect to UI
callbackUI, _ := stateCookie.Values["callback"].(string)
return callbackUI, nil
}
func hashUUID(uuid string) string {
digester := sha256.New()
digester.Write([]byte(uuid))
return godigest.NewDigestFromEncoded(godigest.SHA256, hex.EncodeToString(digester.Sum(nil))).Encoded()
}
/*
GetAuthUserFromRequestSession returns the authenticated user identifier
and auth status if on the request's cookie session is a logged in user.
*/
func GetAuthUserFromRequestSession(cookieStore sessions.Store, request *http.Request, log log.Logger,
) (string, bool) {
session, err := cookieStore.Get(request, "session")
if err != nil {
log.Error().Err(err).Msg("failed to decode existing session")
// expired cookie, no need to return err
return "", false
}
// at this point we should have a session set on cookie.
// if created in the earlier Get() call then user is not logged in with sessions.
if session.IsNew {
return "", false
}
authenticated := session.Values["authStatus"]
if authenticated != true {
log.Error().Msg("failed to get `user` session value")
return "", false
}
identity, ok := session.Values["user"].(string)
if !ok {
log.Error().Msg("failed to get `user` session value")
return "", false
}
return identity, true
}
func GenerateAPIKey(uuidGenerator guuid.Generator, log log.Logger,
) (string, string, error) {
apiKeyBase, err := uuidGenerator.NewV4()
if err != nil {
log.Error().Err(err).Msg("failed to generate uuid for api key base")
return "", "", err
}
apiKey := strings.ReplaceAll(apiKeyBase.String(), "-", "")
// will be used for identifying a specific api key
apiKeyID, err := uuidGenerator.NewV4()
if err != nil {
log.Error().Err(err).Msg("failed to generate uuid for api key id")
return "", "", err
}
return apiKey, apiKeyID.String(), err
}
// extractIdentityFromCertificate attempts to extract identity from a specific identity attribute.
func extractIdentityFromCertificate(cert *x509.Certificate, identityAttribute string, mtlsConfig *config.MTLSConfig,
) (string, error) {
// Normalize to lowercase for case-insensitive matching
normalizedIdentityAttribute := strings.ToLower(strings.TrimSpace(identityAttribute))
switch normalizedIdentityAttribute {
case "commonname", "cn":
if cert.Subject.CommonName == "" {
return "", zerr.ErrNoIdentityInCommonName
}
return cert.Subject.CommonName, nil
case "subject", "dn":
return cert.Subject.String(), nil
case "url", "uri":
if len(cert.URIs) == 0 {
return "", zerr.ErrNoURISANFound
}
idx := 0
if mtlsConfig != nil {
idx = mtlsConfig.URISANIndex
}
if idx < 0 || idx >= len(cert.URIs) {
return "", fmt.Errorf("%w: %d", zerr.ErrURISANIndexOutOfRange, idx)
}
uri := cert.URIs[idx].String()
// Apply pattern if specified
if mtlsConfig != nil && mtlsConfig.URISANPattern != "" {
re, err := regexp.Compile(mtlsConfig.URISANPattern)
if err != nil {
return "", fmt.Errorf("%w: %w", zerr.ErrInvalidURISANPattern, err)
}
matches := re.FindStringSubmatch(uri)
if len(matches) < 2 {
return "", fmt.Errorf("%w", zerr.ErrURISANPatternDidNotMatch)
}
return matches[1], nil // Return first capture group
}
return uri, nil
case "dnsname", "dns":
if len(cert.DNSNames) == 0 {
return "", zerr.ErrNoDNSANFound
}
idx := 0
if mtlsConfig != nil {
idx = mtlsConfig.DNSANIndex
}
if idx < 0 || idx >= len(cert.DNSNames) {
return "", fmt.Errorf("%w: %d", zerr.ErrDNSANIndexOutOfRange, idx)
}
return cert.DNSNames[idx], nil
case "email", "rfc822name":
if len(cert.EmailAddresses) == 0 {
return "", zerr.ErrNoEmailSANFound
}
idx := 0
if mtlsConfig != nil {
idx = mtlsConfig.EmailSANIndex
}
if idx < 0 || idx >= len(cert.EmailAddresses) {
return "", fmt.Errorf("%w: %d", zerr.ErrEmailSANIndexOutOfRange, idx)
}
return cert.EmailAddresses[idx], nil
default:
return "", fmt.Errorf("%w: %s", zerr.ErrUnsupportedIdentityAttribute, identityAttribute)
}
}
// extractMTLSIdentity extracts identity from certificate using configured identity attributes with fallback chain.
func extractMTLSIdentity(cert *x509.Certificate, mtlsConfig *config.MTLSConfig) (string, error) {
identityAttributes := []string{"CommonName"} // Default
if mtlsConfig != nil && len(mtlsConfig.IdentityAttibutes) > 0 {
identityAttributes = mtlsConfig.IdentityAttibutes
}
var cummulatedErr error
for _, identityAttribute := range identityAttributes {
identity, err := extractIdentityFromCertificate(cert, identityAttribute, mtlsConfig)
if err == nil {
return identity, nil
}
cummulatedErr = errors.Join(cummulatedErr, err)
}
return "", fmt.Errorf("no identity found in any configured identity attributes: %w", cummulatedErr)
}
func loadPublicKeyFromFile(path string) (crypto.PublicKey, error) {
raw, err := os.ReadFile(path)
if err != nil {
return nil, fmt.Errorf("%w: %w, path %s", zerr.ErrCouldNotLoadPublicKey, err, path)
}
return loadPublicKeyFromBytes(raw)
}
func loadPublicKeyFromBytes(raw []byte) (crypto.PublicKey, error) {
var keySet jose.JSONWebKeySet
if err := json.Unmarshal(raw, &keySet); err == nil {
if len(keySet.Keys) != 1 {
return nil, fmt.Errorf("%w: expected 1 key in JWKS, found %d", zerr.ErrCouldNotLoadPublicKey, len(keySet.Keys))
}
return keySet.Keys[0].Key, nil
}
block, _ := pem.Decode(raw)
if block == nil {
return nil, fmt.Errorf("%w: no valid PEM data found", zerr.ErrCouldNotLoadPublicKey)
}
pemBytes := block.Bytes
if cert, err := x509.ParseCertificate(pemBytes); err == nil {
return cert.PublicKey, nil
}
if key, err := x509.ParsePKIXPublicKey(pemBytes); err == nil {
return key, nil
}
if key, err := x509.ParsePKCS1PublicKey(pemBytes); err == nil {
return key, nil
}
return nil, fmt.Errorf("%w: no valid x509 certificate or public key found", zerr.ErrCouldNotLoadPublicKey)
}