feat: integrate openID auth logic and user profile management (#1381)

This change introduces OpenID authn by using providers such as Github,
Gitlab, Google and Dex.
User sessions are now used for web clients to identify
and persist an authenticated users session, thus not requiring every request to
use credentials.
Another change is apikey feature, users can create/revoke their api keys and use them
to authenticate when using cli clients such as skopeo.

eg:
login:
/auth/login?provider=github
/auth/login?provider=gitlab
and so on

logout:
/auth/logout

redirectURL:
/auth/callback/github
/auth/callback/gitlab
and so on

If network policy doesn't allow inbound connections, this callback wont work!

for more info read documentation added in this commit.

Signed-off-by: Alex Stan <alexandrustan96@yahoo.ro>
Signed-off-by: Petu Eusebiu <peusebiu@cisco.com>
Co-authored-by: Alex Stan <alexandrustan96@yahoo.ro>
This commit is contained in:
peusebiu
2023-07-07 19:27:10 +03:00
committed by GitHub
parent 5494a1b8d6
commit 17d1338af1
51 changed files with 5467 additions and 624 deletions
+678 -163
View File
@@ -3,139 +3,315 @@ package api
import (
"bufio"
"context"
"crypto/sha256"
"crypto/x509"
"encoding/base64"
"encoding/gob"
"errors"
"fmt"
"net"
"net/http"
"os"
"path"
"strconv"
"strings"
"time"
"github.com/chartmuseum/auth"
"github.com/google/go-github/v52/github"
"github.com/google/uuid"
"github.com/gorilla/mux"
"github.com/gorilla/securecookie"
"github.com/gorilla/sessions"
godigest "github.com/opencontainers/go-digest"
"github.com/zitadel/oidc/pkg/client/rp"
httphelper "github.com/zitadel/oidc/pkg/http"
"github.com/zitadel/oidc/pkg/oidc"
"golang.org/x/crypto/bcrypt"
"golang.org/x/oauth2"
githubOAuth "golang.org/x/oauth2/github"
"zotregistry.io/zot/errors"
zerr "zotregistry.io/zot/errors"
"zotregistry.io/zot/pkg/api/config"
"zotregistry.io/zot/pkg/api/constants"
apiErr "zotregistry.io/zot/pkg/api/errors"
"zotregistry.io/zot/pkg/common"
"zotregistry.io/zot/pkg/log"
localCtx "zotregistry.io/zot/pkg/requestcontext"
storageConstants "zotregistry.io/zot/pkg/storage/constants"
)
const (
bearerAuthDefaultAccessEntryType = "repository"
issuedAtOffset = 5 * time.Second
relyingPartyCookieMaxAge = 120
)
func AuthHandler(c *Controller) mux.MiddlewareFunc {
if isBearerAuthEnabled(c.Config) {
return bearerAuthHandler(c)
}
return basicAuthHandler(c)
type AuthnMiddleware struct {
credMap map[string]string
ldapClient *LDAPClient
}
func bearerAuthHandler(ctlr *Controller) mux.MiddlewareFunc {
authorizer, err := auth.NewAuthorizer(&auth.AuthorizerOptions{
Realm: ctlr.Config.HTTP.Auth.Bearer.Realm,
Service: ctlr.Config.HTTP.Auth.Bearer.Service,
PublicKeyPath: ctlr.Config.HTTP.Auth.Bearer.Cert,
AccessEntryType: bearerAuthDefaultAccessEntryType,
EmptyDefaultNamespace: true,
})
func AuthHandler(ctlr *Controller) mux.MiddlewareFunc {
authnMiddleware := &AuthnMiddleware{}
if isBearerAuthEnabled(ctlr.Config) {
return bearerAuthHandler(ctlr)
}
return authnMiddleware.TryAuthnHandlers(ctlr)
}
func (amw *AuthnMiddleware) sessionAuthn(ctlr *Controller, next http.Handler, response http.ResponseWriter,
request *http.Request, delay int,
) {
clientHeader := request.Header.Get(constants.SessionClientHeaderName)
if clientHeader != constants.SessionClientHeaderValue {
authFail(response, request, ctlr.Config.HTTP.Realm, delay)
return
}
identity, ok := common.GetAuthUserFromRequestSession(ctlr.CookieStore, request, ctlr.Log)
if !ok {
// let the client know that this session is invalid/expired
cookie := &http.Cookie{
Name: "session",
Value: "",
Path: "/",
Expires: time.Unix(0, 0),
HttpOnly: true,
}
http.SetCookie(response, cookie)
authFail(response, request, ctlr.Config.HTTP.Realm, delay)
return
}
ctx := getReqContextWithAuthorization(identity, []string{}, request)
groups, err := ctlr.RepoDB.GetUserGroups(ctx)
if err != nil {
ctlr.Log.Panic().Err(err).Msg("error creating bearer authorizer")
if errors.Is(err, zerr.ErrUserDataNotFound) {
ctlr.Log.Err(err).Str("identity", identity).Msg("can not find user profile in DB")
authFail(response, request, ctlr.Config.HTTP.Realm, delay)
return
}
ctlr.Log.Err(err).Str("identity", identity).Msg("can not get user profile in DB")
response.WriteHeader(http.StatusInternalServerError)
return
}
return func(next http.Handler) http.Handler {
return http.HandlerFunc(func(response http.ResponseWriter, request *http.Request) {
if request.Method == http.MethodOptions {
next.ServeHTTP(response, request)
response.WriteHeader(http.StatusNoContent)
ctx = getReqContextWithAuthorization(identity, groups, request)
return
}
vars := mux.Vars(request)
name := vars["name"]
// we want to bypass auth for mgmt route
isMgmtRequested := request.RequestURI == constants.FullMgmtPrefix
header := request.Header.Get("Authorization")
if (header == "" || header == "Basic Og==") && isMgmtRequested {
next.ServeHTTP(response, request)
return
}
action := auth.PullAction
if m := request.Method; m != http.MethodGet && m != http.MethodHead {
action = auth.PushAction
}
permissions, err := authorizer.Authorize(header, action, name)
if err != nil {
ctlr.Log.Error().Err(err).Msg("issue parsing Authorization header")
response.Header().Set("Content-Type", "application/json")
common.WriteJSON(response, http.StatusInternalServerError, apiErr.NewErrorList(apiErr.NewError(apiErr.UNSUPPORTED)))
return
}
if !permissions.Allowed {
authFail(response, permissions.WWWAuthenticateHeader, 0)
return
}
next.ServeHTTP(response, request)
})
}
next.ServeHTTP(response, request.WithContext(ctx))
}
func noPasswdAuth(realm string, config *config.Config) mux.MiddlewareFunc {
return func(next http.Handler) http.Handler {
return http.HandlerFunc(func(response http.ResponseWriter, request *http.Request) {
if request.Method == http.MethodOptions {
next.ServeHTTP(response, request)
response.WriteHeader(http.StatusNoContent)
func (amw *AuthnMiddleware) basicAuthn(ctlr *Controller, response http.ResponseWriter,
request *http.Request,
) (bool, http.ResponseWriter, *http.Request, error) {
cookieStore := ctlr.CookieStore
return
}
// we want to bypass auth for mgmt route
isMgmtRequested := request.RequestURI == constants.FullMgmtPrefix
// Process request
if request.Header.Get("Authorization") == "" {
if ctlr.Config.HTTP.AccessControl.AnonymousPolicyExists() || isMgmtRequested {
ctx := getReqContextWithAuthorization("", []string{}, request)
next.ServeHTTP(response, request.WithContext(ctx)) //nolint:contextcheck
})
// Process request
return true, response, request.WithContext(ctx), nil
}
}
identity, passphrase, err := getUsernamePasswordBasicAuth(request)
if err != nil {
ctlr.Log.Error().Err(err).Msg("failed to parse authorization header")
return false, nil, nil, nil
}
// some client tools might send Authorization: Basic Og== (decoded into ":")
// empty username and password
if identity == "" && passphrase == "" {
if ctlr.Config.HTTP.AccessControl.AnonymousPolicyExists() || isMgmtRequested {
ctx := getReqContextWithAuthorization("", []string{}, request)
return true, response, request.WithContext(ctx), nil
}
}
passphraseHash, ok := amw.credMap[identity]
if ok {
// first, HTTPPassword authN (which is local)
if err := bcrypt.CompareHashAndPassword([]byte(passphraseHash), []byte(passphrase)); err == nil {
// Process request
var groups []string
if ctlr.Config.HTTP.AccessControl != nil {
ac := NewAccessController(ctlr.Config)
groups = ac.getUserGroups(identity)
}
ctx := getReqContextWithAuthorization(identity, groups, request)
// saved logged session
if err := saveUserLoggedSession(cookieStore, response, request, identity, ctlr.Log); err != nil {
return false, response, request, err
}
if err := ctlr.RepoDB.SetUserGroups(ctx, groups); err != nil {
ctlr.Log.Error().Err(err).Str("identity", identity).Msg("couldn't update user profile")
return false, response, request, err
}
ctlr.Log.Info().Str("identity", identity).Msgf("user profile successfully set")
return true, response, request.WithContext(ctx), nil
}
}
// next, LDAP if configured (network-based which can lose connectivity)
if ctlr.Config.HTTP.Auth != nil && ctlr.Config.HTTP.Auth.LDAP != nil {
ok, _, ldapgroups, err := amw.ldapClient.Authenticate(identity, passphrase)
if ok && err == nil {
// Process request
var groups []string
if ctlr.Config.HTTP.AccessControl != nil {
ac := NewAccessController(ctlr.Config)
groups = ac.getUserGroups(identity)
}
groups = append(groups, ldapgroups...)
ctx := getReqContextWithAuthorization(identity, groups, request)
if err := saveUserLoggedSession(cookieStore, response, request, identity, ctlr.Log); err != nil {
return false, response, request, err
}
if err := ctlr.RepoDB.SetUserGroups(ctx, groups); err != nil {
ctlr.Log.Error().Err(err).Str("identity", identity).Msg("couldn't update user profile")
return false, response, request, err
}
return true, response, request.WithContext(ctx), nil
}
}
// last try API keys
if isAPIKeyEnabled(ctlr.Config) {
apiKey := passphrase
if !strings.HasPrefix(apiKey, constants.APIKeysPrefix) {
ctlr.Log.Error().Msg("api token has invalid format")
return false, nil, nil, nil
}
trimmedAPIKey := strings.TrimPrefix(apiKey, constants.APIKeysPrefix)
hashedKey := hashUUID(trimmedAPIKey)
storedIdentity, err := ctlr.RepoDB.GetUserAPIKeyInfo(hashedKey)
if err != nil {
if errors.Is(err, zerr.ErrUserAPIKeyNotFound) {
ctlr.Log.Info().Err(err).Msgf("can not find any user info for hashed key %s in DB", hashedKey)
return false, nil, nil, nil
}
ctlr.Log.Error().Err(err).Msgf("can not get user info for hashed key %s in DB", hashedKey)
return false, nil, nil, err
}
if storedIdentity == identity {
ctx := getReqContextWithAuthorization(identity, []string{}, request)
err := ctlr.RepoDB.UpdateUserAPIKeyLastUsed(ctx, hashedKey)
if err != nil {
ctlr.Log.Err(err).Str("identity", identity).Msg("can not update user profile in DB")
return false, nil, nil, err
}
groups, err := ctlr.RepoDB.GetUserGroups(ctx)
if err != nil {
ctlr.Log.Err(err).Str("identity", identity).Msg("can not get user's groups in DB")
return false, nil, nil, err
}
ctx = getReqContextWithAuthorization(identity, groups, request)
return true, response, request.WithContext(ctx), nil
}
}
return false, nil, nil, nil
}
//nolint:gocyclo // we use closure making this a complex subroutine
func basicAuthHandler(ctlr *Controller) mux.MiddlewareFunc {
realm := ctlr.Config.HTTP.Realm
if realm == "" {
realm = "Authorization Required"
}
realm = "Basic realm=" + strconv.Quote(realm)
func (amw *AuthnMiddleware) TryAuthnHandlers(ctlr *Controller) mux.MiddlewareFunc { //nolint: gocyclo
// no password based authN, if neither LDAP nor HTTP BASIC is enabled
if ctlr.Config.HTTP.Auth == nil ||
(ctlr.Config.HTTP.Auth.HTPasswd.Path == "" && ctlr.Config.HTTP.Auth.LDAP == nil) {
return noPasswdAuth(realm, ctlr.Config)
(ctlr.Config.HTTP.Auth.HTPasswd.Path == "" && ctlr.Config.HTTP.Auth.LDAP == nil &&
ctlr.Config.HTTP.Auth.OpenID == nil) {
return noPasswdAuth(ctlr.Config)
}
credMap := make(map[string]string)
amw.credMap = make(map[string]string)
delay := ctlr.Config.HTTP.Auth.FailDelay
var ldapClient *LDAPClient
// setup sessions cookie store used to preserve logged in user in web sessions
if isAuthnEnabled(ctlr.Config) || isOpenIDAuthEnabled(ctlr.Config) {
// To store custom types in our cookies,
// we must first register them using gob.Register
gob.Register(map[string]interface{}{})
cookieStoreHashKey := securecookie.GenerateRandomKey(64)
if cookieStoreHashKey == nil {
panic(zerr.ErrHashKeyNotCreated)
}
// if storage is filesystem then use zot's rootDir to store sessions
if ctlr.Config.Storage.StorageDriver == nil {
sessionsDir := path.Join(ctlr.Config.Storage.RootDirectory, "_sessions")
if err := os.MkdirAll(sessionsDir, storageConstants.DefaultDirPerms); err != nil {
panic(err)
}
cookieStore := sessions.NewFilesystemStore(sessionsDir, cookieStoreHashKey)
cookieStore.MaxAge(cookiesMaxAge)
ctlr.CookieStore = cookieStore
} else {
cookieStore := sessions.NewCookieStore(cookieStoreHashKey)
cookieStore.MaxAge(cookiesMaxAge)
ctlr.CookieStore = cookieStore
}
}
// ldap and htpasswd based authN
if ctlr.Config.HTTP.Auth != nil {
if ctlr.Config.HTTP.Auth.LDAP != nil {
ldapConfig := ctlr.Config.HTTP.Auth.LDAP
ldapClient = &LDAPClient{
amw.ldapClient = &LDAPClient{
Host: ldapConfig.Address,
Port: ldapConfig.Port,
UseSSL: !ldapConfig.Insecure,
@@ -160,18 +336,18 @@ func basicAuthHandler(ctlr *Controller) mux.MiddlewareFunc {
caCertPool := x509.NewCertPool()
if !caCertPool.AppendCertsFromPEM(caCert) {
panic(errors.ErrBadCACert)
panic(zerr.ErrBadCACert)
}
ldapClient.ClientCAs = caCertPool
amw.ldapClient.ClientCAs = caCertPool
} else {
// default to system cert pool
caCertPool, err := x509.SystemCertPool()
if err != nil {
panic(errors.ErrBadCACert)
panic(zerr.ErrBadCACert)
}
ldapClient.ClientCAs = caCertPool
amw.ldapClient.ClientCAs = caCertPool
}
}
@@ -188,12 +364,27 @@ func basicAuthHandler(ctlr *Controller) mux.MiddlewareFunc {
line := scanner.Text()
if strings.Contains(line, ":") {
tokens := strings.Split(scanner.Text(), ":")
credMap[tokens[0]] = tokens[1]
amw.credMap[tokens[0]] = tokens[1]
}
}
}
}
// openid based authN
if ctlr.Config.HTTP.Auth.OpenID != nil {
ctlr.RelyingParties = make(map[string]rp.RelyingParty)
for provider := range ctlr.Config.HTTP.Auth.OpenID.Providers {
if IsOpenIDSupported(provider) {
rp := NewRelyingPartyOIDC(ctlr.Config, provider)
ctlr.RelyingParties[provider] = rp
} else if IsOauth2Supported(provider) {
rp := NewRelyingPartyGithub(ctlr.Config, provider)
ctlr.RelyingParties[provider] = rp
}
}
}
return func(next http.Handler) http.Handler {
return http.HandlerFunc(func(response http.ResponseWriter, request *http.Request) {
if request.Method == http.MethodOptions {
@@ -203,84 +394,231 @@ func basicAuthHandler(ctlr *Controller) mux.MiddlewareFunc {
return
}
// we want to bypass auth for mgmt route
isMgmtRequested := request.RequestURI == constants.FullMgmtPrefix
if request.Header.Get("Authorization") == "" {
if ctlr.Config.HTTP.AccessControl.AnonymousPolicyExists() || isMgmtRequested {
// Process request
ctx := getReqContextWithAuthorization("", []string{}, request)
next.ServeHTTP(response, request.WithContext(ctx)) //nolint:contextcheck
return
}
}
username, passphrase, err := getUsernamePasswordBasicAuth(request)
//nolint: contextcheck
authenticated, cloneResp, cloneReq, err := amw.basicAuthn(ctlr, response, request)
if err != nil {
ctlr.Log.Error().Err(err).Msg("failed to parse authorization header")
authFail(response, realm, delay)
response.WriteHeader(http.StatusInternalServerError)
return
}
// some client tools might send Authorization: Basic Og== (decoded into ":")
// empty username and password
if username == "" && passphrase == "" {
if ctlr.Config.HTTP.AccessControl.AnonymousPolicyExists() || isMgmtRequested {
// Process request
ctx := getReqContextWithAuthorization("", []string{}, request)
next.ServeHTTP(response, request.WithContext(ctx)) //nolint:contextcheck
if authenticated && cloneResp != nil && cloneReq != nil {
next.ServeHTTP(cloneResp, cloneReq)
return
}
return
}
// first, HTTPPassword authN (which is local)
passphraseHash, ok := credMap[username]
if ok {
if err := bcrypt.CompareHashAndPassword([]byte(passphraseHash), []byte(passphrase)); err == nil {
// Process request
var userGroups []string
if ctlr.Config.HTTP.AccessControl != nil {
ac := NewAccessController(ctlr.Config)
userGroups = ac.getUserGroups(username)
}
ctx := getReqContextWithAuthorization(username, userGroups, request)
next.ServeHTTP(response, request.WithContext(ctx)) //nolint:contextcheck
return
}
}
// next, LDAP if configured (network-based which can lose connectivity)
if ctlr.Config.HTTP.Auth != nil && ctlr.Config.HTTP.Auth.LDAP != nil {
ok, _, ldapgroups, err := ldapClient.Authenticate(username, passphrase)
if ok && err == nil {
// Process request
var userGroups []string
if ctlr.Config.HTTP.AccessControl != nil {
ac := NewAccessController(ctlr.Config)
userGroups = ac.getUserGroups(username)
}
userGroups = append(userGroups, ldapgroups...)
ctx := getReqContextWithAuthorization(username, userGroups, request)
next.ServeHTTP(response, request.WithContext(ctx)) //nolint:contextcheck
return
}
}
authFail(response, realm, delay)
//nolint: contextcheck
amw.sessionAuthn(ctlr, next, response, request, delay)
})
}
}
func bearerAuthHandler(ctlr *Controller) mux.MiddlewareFunc {
authorizer, err := auth.NewAuthorizer(&auth.AuthorizerOptions{
Realm: ctlr.Config.HTTP.Auth.Bearer.Realm,
Service: ctlr.Config.HTTP.Auth.Bearer.Service,
PublicKeyPath: ctlr.Config.HTTP.Auth.Bearer.Cert,
AccessEntryType: bearerAuthDefaultAccessEntryType,
EmptyDefaultNamespace: true,
})
if err != nil {
ctlr.Log.Panic().Err(err).Msg("error creating bearer authorizer")
}
return func(next http.Handler) http.Handler {
return http.HandlerFunc(func(response http.ResponseWriter, request *http.Request) {
if request.Method == http.MethodOptions {
next.ServeHTTP(response, request)
response.WriteHeader(http.StatusNoContent)
return
}
acCtrlr := NewAccessController(ctlr.Config)
vars := mux.Vars(request)
name := vars["name"]
// we want to bypass auth for mgmt route
isMgmtRequested := request.RequestURI == constants.FullMgmtPrefix
header := request.Header.Get("Authorization")
if (header == "" || header == "Basic Og==") && isMgmtRequested {
next.ServeHTTP(response, request)
return
}
action := auth.PullAction
if m := request.Method; m != http.MethodGet && m != http.MethodHead {
action = auth.PushAction
}
permissions, err := authorizer.Authorize(header, action, name)
if err != nil {
ctlr.Log.Error().Err(err).Msg("issue parsing Authorization header")
response.Header().Set("Content-Type", "application/json")
common.WriteJSON(response, http.StatusInternalServerError, apiErr.NewErrorList(apiErr.NewError(apiErr.UNSUPPORTED)))
return
}
if !permissions.Allowed {
response.Header().Set("Content-Type", "application/json")
response.Header().Set("WWW-Authenticate", permissions.WWWAuthenticateHeader)
common.WriteJSON(response, http.StatusUnauthorized,
apiErr.NewErrorList(apiErr.NewError(apiErr.UNAUTHORIZED)))
return
}
amCtx := acCtrlr.getAuthnMiddlewareContext(BEARER, request)
next.ServeHTTP(response, request.WithContext(amCtx)) //nolint:contextcheck
})
}
}
func noPasswdAuth(config *config.Config) mux.MiddlewareFunc {
return func(next http.Handler) http.Handler {
return http.HandlerFunc(func(response http.ResponseWriter, request *http.Request) {
if request.Method == http.MethodOptions {
next.ServeHTTP(response, request)
response.WriteHeader(http.StatusNoContent)
return
}
ctx := getReqContextWithAuthorization("", []string{}, request)
// Process request
next.ServeHTTP(response, request.WithContext(ctx)) //nolint:contextcheck
})
}
}
func (rh *RouteHandler) AuthURLHandler() http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) {
query := r.URL.Query()
callbackUI := query.Get(constants.CallbackUIQueryParam)
provider := query.Get("provider")
client, ok := rh.c.RelyingParties[provider]
if !ok {
http.HandlerFunc(func(response http.ResponseWriter, request *http.Request) {
response.WriteHeader(http.StatusBadRequest)
})(w, r)
}
/* save cookie containing state to later verify it and
callback ui where we will redirect after openid/oauth2 logic is completed*/
session, _ := rh.c.CookieStore.Get(r, "statecookie")
session.Options.Secure = true
session.Options.HttpOnly = true
session.Options.SameSite = http.SameSiteDefaultMode
session.Options.Path = constants.CallbackBasePath
state := uuid.New().String()
session.Values["state"] = state
session.Values["callback"] = callbackUI
// let the session set its own id
err := session.Save(r, w)
if err != nil {
rh.c.Log.Error().Err(err).Msg("unable to save http session")
w.WriteHeader(http.StatusInternalServerError)
return
}
stateFunc := func() string {
return state
}
rp.AuthURLHandler(stateFunc, client)(w, r)
}
}
func NewRelyingPartyOIDC(config *config.Config, provider string) rp.RelyingParty {
issuer, clientID, clientSecret, redirectURI, scopes, options := getRelyingPartyArgs(config, provider)
relyingParty, err := rp.NewRelyingPartyOIDC(issuer, clientID, clientSecret, redirectURI, scopes, options...)
if err != nil {
panic(err)
}
return relyingParty
}
func NewRelyingPartyGithub(config *config.Config, provider string) rp.RelyingParty {
_, clientID, clientSecret, redirectURI, scopes, options := getRelyingPartyArgs(config, provider)
rpConfig := &oauth2.Config{
ClientID: clientID,
ClientSecret: clientSecret,
RedirectURL: redirectURI,
Scopes: scopes,
Endpoint: githubOAuth.Endpoint,
}
relyingParty, err := rp.NewRelyingPartyOAuth(rpConfig, options...)
if err != nil {
panic(err)
}
return relyingParty
}
func getRelyingPartyArgs(config *config.Config, provider string) (
string, string, string, string, []string, []rp.Option,
) {
if _, ok := config.HTTP.Auth.OpenID.Providers[provider]; !ok {
panic(zerr.ErrOpenIDProviderDoesNotExist)
}
scheme := "http"
if config.HTTP.TLS != nil {
scheme = "https"
}
clientID := config.HTTP.Auth.OpenID.Providers[provider].ClientID
clientSecret := config.HTTP.Auth.OpenID.Providers[provider].ClientSecret
scopes := config.HTTP.Auth.OpenID.Providers[provider].Scopes
// openid scope must be the first one in list
if !common.Contains(scopes, oidc.ScopeOpenID) && IsOpenIDSupported(provider) {
scopes = append([]string{oidc.ScopeOpenID}, scopes...)
}
port := config.HTTP.Port
issuer := config.HTTP.Auth.OpenID.Providers[provider].Issuer
keyPath := config.HTTP.Auth.OpenID.Providers[provider].KeyPath
baseURL := net.JoinHostPort(config.HTTP.Address, port)
redirectURI := fmt.Sprintf("%s://%s%s", scheme, baseURL, constants.CallbackBasePath+fmt.Sprintf("/%s", provider))
options := []rp.Option{
rp.WithVerifierOpts(rp.WithIssuedAtOffset(issuedAtOffset)),
}
key := securecookie.GenerateRandomKey(32) //nolint: gomnd
cookieHandler := httphelper.NewCookieHandler(key, key, httphelper.WithMaxAge(relyingPartyCookieMaxAge))
options = append(options, rp.WithCookieHandler(cookieHandler))
if clientSecret == "" {
options = append(options, rp.WithPKCE(cookieHandler))
}
if keyPath != "" {
options = append(options, rp.WithJWTProfile(rp.SignerFromKeyPath(keyPath)))
}
return issuer, clientID, clientSecret, redirectURI, scopes, options
}
func getReqContextWithAuthorization(username string, groups []string, request *http.Request) context.Context {
acCtx := localCtx.AccessControlContext{
Username: username,
@@ -314,9 +652,71 @@ func isBearerAuthEnabled(config *config.Config) bool {
return false
}
func authFail(w http.ResponseWriter, realm string, delay int) {
func isOpenIDAuthEnabled(config *config.Config) bool {
if config.HTTP.Auth != nil &&
config.HTTP.Auth.OpenID != nil {
for provider := range config.HTTP.Auth.OpenID.Providers {
if isOpenIDAuthProviderEnabled(config, provider) {
return true
}
}
}
return false
}
func isAPIKeyEnabled(config *config.Config) bool {
if config.Extensions != nil && config.Extensions.APIKey != nil &&
*config.Extensions.APIKey.Enable {
return true
}
return false
}
func isOpenIDAuthProviderEnabled(config *config.Config, provider string) bool {
if providerConfig, ok := config.HTTP.Auth.OpenID.Providers[provider]; ok {
if IsOpenIDSupported(provider) {
if providerConfig.ClientID != "" || providerConfig.Issuer != "" ||
len(providerConfig.Scopes) > 0 {
return true
}
} else if IsOauth2Supported(provider) {
if providerConfig.ClientID != "" || len(providerConfig.Scopes) > 0 {
return true
}
}
}
return false
}
func IsOpenIDSupported(provider string) bool {
supported := []string{"google", "gitlab", "dex"}
return common.Contains(supported, provider)
}
func IsOauth2Supported(provider string) bool {
supported := []string{"github"}
return common.Contains(supported, provider)
}
func authFail(w http.ResponseWriter, r *http.Request, realm string, delay int) {
time.Sleep(time.Duration(delay) * time.Second)
w.Header().Set("WWW-Authenticate", realm)
// don't send auth headers if request is coming from UI
if r.Header.Get(constants.SessionClientHeaderName) != constants.SessionClientHeaderValue {
if realm == "" {
realm = "Authorization Required"
}
realm = "Basic realm=" + strconv.Quote(realm)
w.Header().Set("WWW-Authenticate", realm)
}
w.Header().Set("Content-Type", "application/json")
common.WriteJSON(w, http.StatusUnauthorized, apiErr.NewErrorList(apiErr.NewError(apiErr.UNAUTHORIZED)))
}
@@ -325,12 +725,12 @@ func getUsernamePasswordBasicAuth(request *http.Request) (string, string, error)
basicAuth := request.Header.Get("Authorization")
if basicAuth == "" {
return "", "", errors.ErrParsingAuthHeader
return "", "", zerr.ErrParsingAuthHeader
}
splitStr := strings.SplitN(basicAuth, " ", 2) //nolint:gomnd
splitStr := strings.SplitN(basicAuth, " ", 2) //nolint: gomnd
if len(splitStr) != 2 || strings.ToLower(splitStr[0]) != "basic" {
return "", "", errors.ErrParsingAuthHeader
return "", "", zerr.ErrParsingAuthHeader
}
decodedStr, err := base64.StdEncoding.DecodeString(splitStr[1])
@@ -338,9 +738,9 @@ func getUsernamePasswordBasicAuth(request *http.Request) (string, string, error)
return "", "", err
}
pair := strings.SplitN(string(decodedStr), ":", 2) //nolint:gomnd
if len(pair) != 2 { //nolint:gomnd
return "", "", errors.ErrParsingAuthHeader
pair := strings.SplitN(string(decodedStr), ":", 2) //nolint: gomnd
if len(pair) != 2 { //nolint: gomnd
return "", "", zerr.ErrParsingAuthHeader
}
username := pair[0]
@@ -348,3 +748,118 @@ func getUsernamePasswordBasicAuth(request *http.Request) (string, string, error)
return username, passphrase, nil
}
func GetGithubUserInfo(ctx context.Context, client *github.Client, log log.Logger) (string, []string, error) {
var primaryEmail string
userEmails, _, err := client.Users.ListEmails(ctx, nil)
if err != nil {
log.Error().Msg("couldn't set user record for empty email value")
return "", []string{}, err
}
if len(userEmails) != 0 {
for _, email := range userEmails { // should have at least one primary email, if any
if email.GetPrimary() { // check if it's primary email
primaryEmail = email.GetEmail()
break
}
}
}
orgs, _, err := client.Organizations.List(ctx, "", nil)
if err != nil {
log.Error().Msg("couldn't set user record for empty email value")
return "", []string{}, err
}
groups := []string{}
for _, org := range orgs {
groups = append(groups, *org.Login)
}
return primaryEmail, groups, nil
}
func saveUserLoggedSession(cookieStore sessions.Store, response http.ResponseWriter,
request *http.Request, identity string, log log.Logger,
) error {
session, _ := cookieStore.Get(request, "session")
session.Options.Secure = true
session.Options.HttpOnly = true
session.Options.SameSite = http.SameSiteDefaultMode
session.Values["authStatus"] = true
session.Values["user"] = identity
// let the session set its own id
err := session.Save(request, response)
if err != nil {
log.Error().Err(err).Str("identity", identity).Msg("unable to save http session")
return err
}
userInfoCookie := sessions.NewCookie("user", identity, &sessions.Options{
Secure: true,
HttpOnly: false,
MaxAge: cookiesMaxAge,
SameSite: http.SameSiteDefaultMode,
Path: "/",
})
http.SetCookie(response, userInfoCookie)
return nil
}
// OAuth2Callback is the callback logic where openid/oauth2 will redirect back to our app.
func OAuth2Callback(ctlr *Controller, w http.ResponseWriter, r *http.Request, state, email string,
groups []string,
) (string, error) {
stateCookie, _ := ctlr.CookieStore.Get(r, "statecookie")
stateOrigin, ok := stateCookie.Values["state"].(string)
if !ok {
ctlr.Log.Error().Err(zerr.ErrInvalidStateCookie).Msg("openID: unable to get 'state' cookie from request")
return "", zerr.ErrInvalidStateCookie
}
if stateOrigin != state {
ctlr.Log.Error().Err(zerr.ErrInvalidStateCookie).Msg("openID: 'state' cookie differs from the actual one")
return "", zerr.ErrInvalidStateCookie
}
ctx := getReqContextWithAuthorization(email, groups, r)
// if this line has been reached, then a new session should be created
// if the `session` key is already on the cookie, it's not a valid one
if err := saveUserLoggedSession(ctlr.CookieStore, w, r, email, ctlr.Log); err != nil {
return "", err
}
if err := ctlr.RepoDB.SetUserGroups(ctx, groups); err != nil {
ctlr.Log.Error().Err(err).Str("identity", email).Msg("couldn't update the user profile")
return "", err
}
ctlr.Log.Info().Msgf("user profile set successfully for email %s", email)
// redirect to UI
callbackUI, _ := stateCookie.Values["callback"].(string)
return callbackUI, nil
}
func hashUUID(uuid string) string {
digester := sha256.New()
digester.Write([]byte(uuid))
return godigest.NewDigestFromEncoded(godigest.SHA256, fmt.Sprintf("%x", digester.Sum(nil))).Encoded()
}
+45 -11
View File
@@ -21,6 +21,9 @@ const (
Delete = "delete"
// behaviour actions.
DetectManifestCollision = "detectManifestCollision"
BASIC = "Basic"
BEARER = "Bearer"
OPENID = "OpenID"
)
// AccessController authorizes users to act on resources.
@@ -29,10 +32,17 @@ type AccessController struct {
Log log.Logger
}
func NewAccessController(config *config.Config) *AccessController {
func NewAccessController(conf *config.Config) *AccessController {
if conf.HTTP.AccessControl == nil {
return &AccessController{
Config: &config.AccessControlConfig{},
Log: log.NewLogger(conf.Log.Level, conf.Log.Output),
}
}
return &AccessController{
Config: config.HTTP.AccessControl,
Log: log.NewLogger(config.Log.Level, config.Log.Output),
Config: conf.HTTP.AccessControl,
Log: log.NewLogger(conf.Log.Level, conf.Log.Output),
}
}
@@ -171,6 +181,18 @@ func (ac *AccessController) getContext(acCtx *localCtx.AccessControlContext, req
return ctx
}
// getAuthnMiddlewareContext builds ac context(allowed to read repos and if user is admin) and returns it.
func (ac *AccessController) getAuthnMiddlewareContext(authnType string, request *http.Request) context.Context {
amwCtx := localCtx.AuthnMiddlewareContext{
AuthnType: authnType,
}
amwCtxKey := localCtx.GetAuthnMiddlewareCtxKey()
ctx := context.WithValue(request.Context(), amwCtxKey, amwCtx)
return ctx
}
// isPermitted returns true if username can do action on a repository policy.
func (ac *AccessController) isPermitted(userGroups []string, username, action string,
policyGroup config.PolicyGroup,
@@ -231,6 +253,14 @@ func BaseAuthzHandler(ctlr *Controller) mux.MiddlewareFunc {
return
}
// request comes from bearer authn, bypass it
authnMwCtx, err := localCtx.GetAuthnMiddlewareContext(request.Context())
if err != nil || (authnMwCtx != nil && authnMwCtx.AuthnType == BEARER) {
next.ServeHTTP(response, request)
return
}
// bypass authz for /v2/ route
if request.RequestURI == "/v2/" {
next.ServeHTTP(response, request)
@@ -242,8 +272,6 @@ func BaseAuthzHandler(ctlr *Controller) mux.MiddlewareFunc {
var identity string
var err error
// anonymous context
acCtx := &localCtx.AccessControlContext{}
@@ -252,7 +280,7 @@ func BaseAuthzHandler(ctlr *Controller) mux.MiddlewareFunc {
// get access control context made in authn.go if authn is enabled
acCtx, err = localCtx.GetAccessControlContext(request.Context())
if err != nil { // should never happen
authFail(response, ctlr.Config.HTTP.Realm, ctlr.Config.HTTP.Auth.FailDelay)
authFail(response, request, ctlr.Config.HTTP.Realm, ctlr.Config.HTTP.Auth.FailDelay)
return
}
@@ -272,7 +300,7 @@ func BaseAuthzHandler(ctlr *Controller) mux.MiddlewareFunc {
// if we still don't have an identity
if identity == "" {
acCtrlr.Log.Info().Msg("couldn't get identity from TLS certificate")
authFail(response, ctlr.Config.HTTP.Realm, ctlr.Config.HTTP.Auth.FailDelay)
authFail(response, request, ctlr.Config.HTTP.Realm, ctlr.Config.HTTP.Auth.FailDelay)
return
}
@@ -298,6 +326,14 @@ func DistSpecAuthzHandler(ctlr *Controller) mux.MiddlewareFunc {
return
}
// request comes from bearer authn, bypass it
authnMwCtx, err := localCtx.GetAuthnMiddlewareContext(request.Context())
if err != nil || (authnMwCtx != nil && authnMwCtx.AuthnType == BEARER) {
next.ServeHTTP(response, request)
return
}
vars := mux.Vars(request)
resource := vars["name"]
reference, ok := vars["reference"]
@@ -306,12 +342,10 @@ func DistSpecAuthzHandler(ctlr *Controller) mux.MiddlewareFunc {
var identity string
var err error
// get acCtx built in authn and previous authz middlewares
acCtx, err := localCtx.GetAccessControlContext(request.Context())
if err != nil { // should never happen
authFail(response, ctlr.Config.HTTP.Realm, ctlr.Config.HTTP.Auth.FailDelay)
authFail(response, request, ctlr.Config.HTTP.Realm, ctlr.Config.HTTP.Auth.FailDelay)
return
}
@@ -344,7 +378,7 @@ func DistSpecAuthzHandler(ctlr *Controller) mux.MiddlewareFunc {
can := acCtrlr.can(request.Context(), identity, action, resource) //nolint:contextcheck
if !can {
common.AuthzFail(response, ctlr.Config.HTTP.Realm, ctlr.Config.HTTP.Auth.FailDelay)
common.AuthzFail(response, request, ctlr.Config.HTTP.Realm, ctlr.Config.HTTP.Auth.FailDelay)
} else {
next.ServeHTTP(response, request) //nolint:contextcheck
}
+14
View File
@@ -45,6 +45,7 @@ type AuthConfig struct {
HTPasswd AuthHTPasswd
LDAP *LDAPConfig
Bearer *BearerConfig
OpenID *OpenIDConfig
}
type BearerConfig struct {
@@ -53,6 +54,18 @@ type BearerConfig struct {
Cert string
}
type OpenIDConfig struct {
Providers map[string]OpenIDProviderConfig
}
type OpenIDProviderConfig struct {
ClientID string
ClientSecret string
KeyPath string
Issuer string
Scopes []string
}
type MethodRatelimitConfig struct {
Method string
Rate int
@@ -63,6 +76,7 @@ type RatelimitConfig struct {
Methods []MethodRatelimitConfig `mapstructure:",omitempty"`
}
//nolint:maligned
type HTTPConfig struct {
Address string
Port string
+7
View File
@@ -12,4 +12,11 @@ const (
DefaultMediaType = "application/json"
BinaryMediaType = "application/octet-stream"
DefaultMetricsExtensionRoute = "/metrics"
CallbackBasePath = "/auth/callback"
LoginPath = "/auth/login"
LogoutPath = "/auth/logout"
SessionClientHeaderName = "X-ZOT-API-CLIENT"
SessionClientHeaderValue = "zot-ui"
APIKeysPrefix = "zak_"
CallbackUIQueryParam = "callback_ui"
)
+3
View File
@@ -18,4 +18,7 @@ const (
ExtUserPreferences = "/userprefs"
ExtUserPreferencesPrefix = ExtPrefix + ExtUserPreferences
FullUserPreferencesPrefix = RoutePrefix + ExtUserPreferencesPrefix
ExtAPIKey = "/apikey"
ExtAPIKeyPrefix = ExtPrefix + ExtAPIKey //nolint: gosec
FullAPIKeyPrefix = RoutePrefix + ExtAPIKeyPrefix
)
+8 -1
View File
@@ -16,6 +16,8 @@ import (
"github.com/gorilla/handlers"
"github.com/gorilla/mux"
"github.com/gorilla/sessions"
"github.com/zitadel/oidc/pkg/client/rp"
"zotregistry.io/zot/errors"
"zotregistry.io/zot/pkg/api/config"
@@ -31,6 +33,7 @@ import (
const (
idleTimeout = 120 * time.Second
readHeaderTimeout = 5 * time.Second
cookiesMaxAge = 86400 // seconds
)
type Controller struct {
@@ -44,6 +47,8 @@ type Controller struct {
Metrics monitoring.MetricServer
CveInfo ext.CveInfo
SyncOnDemand SyncOnDemand
RelyingParties map[string]rp.RelyingParty
CookieStore sessions.Store
// runtime params
chosenPort int // kernel-chosen port
}
@@ -254,7 +259,9 @@ func (c *Controller) InitImageStore() error {
}
func (c *Controller) InitRepoDB(reloadCtx context.Context) error {
if c.Config.Extensions != nil && c.Config.Extensions.Search != nil && *c.Config.Extensions.Search.Enable {
// init repoDB if search is enabled or authn enabled (need to store user profiles) or apikey ext is enabled
if (c.Config.Extensions != nil && c.Config.Extensions.Search != nil && *c.Config.Extensions.Search.Enable) ||
isAuthnEnabled(c.Config) || isOpenIDAuthEnabled(c.Config) || isAPIKeyEnabled(c.Config) {
driver, err := repodbfactory.New(c.Config.Storage.StorageConfig, c.Log) //nolint:contextcheck
if err != nil {
return err
+1213 -26
View File
File diff suppressed because it is too large Load Diff
+164 -18
View File
@@ -20,11 +20,14 @@ import (
"strconv"
"strings"
"github.com/google/go-github/v52/github"
"github.com/gorilla/mux"
"github.com/opencontainers/distribution-spec/specs-go/v1/extensions"
godigest "github.com/opencontainers/go-digest"
ispec "github.com/opencontainers/image-spec/specs-go/v1"
artifactspec "github.com/oras-project/artifacts-spec/specs-go/v1"
"github.com/zitadel/oidc/pkg/client/rp"
"github.com/zitadel/oidc/pkg/oidc"
zerr "zotregistry.io/zot/errors"
"zotregistry.io/zot/pkg/api/constants"
@@ -55,13 +58,38 @@ func NewRouteHandler(c *Controller) *RouteHandler {
}
func (rh *RouteHandler) SetupRoutes() {
// first get Auth middleware in order to first setup openid/ldap/htpasswd, before oidc provider routes are setup
authHandler := AuthHandler(rh.c)
applyCORSHeaders := getCORSHeadersHandler(rh.c.Config.HTTP.AllowOrigin)
if isOpenIDAuthEnabled(rh.c.Config) {
// login path for openID
rh.c.Router.HandleFunc(constants.LoginPath, rh.AuthURLHandler())
// logout path for openID
rh.c.Router.HandleFunc(constants.LogoutPath, applyCORSHeaders(rh.Logout)).
Methods(zcommon.AllowedMethods("POST")...)
// callback path for openID
for provider, relyingParty := range rh.c.RelyingParties {
if IsOauth2Supported(provider) {
rh.c.Router.HandleFunc(constants.CallbackBasePath+fmt.Sprintf("/%s", provider),
rp.CodeExchangeHandler(rh.GithubCodeExchangeCallback(), relyingParty))
} else if IsOpenIDSupported(provider) {
rh.c.Router.HandleFunc(constants.CallbackBasePath+fmt.Sprintf("/%s", provider),
rp.CodeExchangeHandler(rp.UserinfoCallback(rh.OpenIDCodeExchangeCallback()), relyingParty))
}
}
}
prefixedRouter := rh.c.Router.PathPrefix(constants.RoutePrefix).Subrouter()
prefixedRouter.Use(AuthHandler(rh.c))
prefixedRouter.Use(authHandler)
prefixedDistSpecRouter := prefixedRouter.NewRoute().Subrouter()
// authz is being enabled if AccessControl is specified
// if Authn is not present AccessControl will have only default policies
if rh.c.Config.HTTP.AccessControl != nil && !isBearerAuthEnabled(rh.c.Config) {
if rh.c.Config.HTTP.AccessControl != nil {
if isAuthnEnabled(rh.c.Config) {
rh.c.Log.Info().Msg("access control is being enabled")
} else {
@@ -72,8 +100,6 @@ func (rh *RouteHandler) SetupRoutes() {
prefixedDistSpecRouter.Use(DistSpecAuthzHandler(rh.c))
}
applyCORSHeaders := getCORSHeadersHandler(rh.c.Config.HTTP.AllowOrigin)
// https://github.com/opencontainers/distribution-spec/blob/main/spec.md#endpoints
{
prefixedDistSpecRouter.HandleFunc(fmt.Sprintf("/{name:%s}/tags/list", zreg.NameRegexp.String()),
@@ -118,7 +144,7 @@ func (rh *RouteHandler) SetupRoutes() {
constants.ArtifactSpecRoutePrefix, zreg.NameRegexp.String()), rh.GetOrasReferrers).Methods("GET")
// swagger
debug.SetupSwaggerRoutes(rh.c.Config, rh.c.Router, AuthHandler(rh.c), rh.c.Log)
debug.SetupSwaggerRoutes(rh.c.Config, rh.c.Router, authHandler, rh.c.Log)
// Setup Extensions Routes
if rh.c.Config != nil {
@@ -135,8 +161,8 @@ func (rh *RouteHandler) SetupRoutes() {
rh.c.Log)
ext.SetupUserPreferencesRoutes(rh.c.Config, prefixedExtensionsRouter, rh.c.StoreController, rh.c.RepoDB,
rh.c.CveInfo, rh.c.Log)
ext.SetupMetricsRoutes(rh.c.Config, rh.c.Router, rh.c.StoreController, AuthHandler(rh.c), rh.c.Log)
ext.SetupAPIKeyRoutes(rh.c.Config, prefixedExtensionsRouter, rh.c.RepoDB, rh.c.CookieStore, rh.c.Log)
ext.SetupMetricsRoutes(rh.c.Config, rh.c.Router, rh.c.StoreController, authHandler, rh.c.Log)
gqlPlayground.SetupGQLPlaygroundRoutes(rh.c.Config, prefixedRouter, rh.c.StoreController, rh.c.Log)
@@ -185,7 +211,8 @@ func addCORSHeaders(allowOrigin string, response http.ResponseWriter) {
// @Success 200 {string} string "ok".
func (rh *RouteHandler) CheckVersionSupport(response http.ResponseWriter, request *http.Request) {
response.Header().Set("Access-Control-Allow-Methods", "HEAD,GET,POST,OPTIONS")
response.Header().Set("Access-Control-Allow-Headers", "Authorization,content-type")
response.Header().Set("Access-Control-Allow-Headers", "Authorization,content-type,"+constants.SessionClientHeaderName)
response.Header().Set("Access-Control-Allow-Credentials", "true")
if request.Method == http.MethodOptions {
return
@@ -195,10 +222,13 @@ func (rh *RouteHandler) CheckVersionSupport(response http.ResponseWriter, reques
// NOTE: compatibility workaround - return this header in "allowed-read" mode to allow for clients to
// work correctly
if rh.c.Config.HTTP.Auth != nil {
if rh.c.Config.HTTP.Auth.Bearer != nil {
response.Header().Set("WWW-Authenticate", fmt.Sprintf("bearer realm=%s", rh.c.Config.HTTP.Auth.Bearer.Realm))
} else {
response.Header().Set("WWW-Authenticate", fmt.Sprintf("basic realm=%s", rh.c.Config.HTTP.Realm))
// don't send auth headers if request is coming from UI
if request.Header.Get(constants.SessionClientHeaderName) != constants.SessionClientHeaderValue {
if rh.c.Config.HTTP.Auth.Bearer != nil {
response.Header().Set("WWW-Authenticate", fmt.Sprintf("bearer realm=%s", rh.c.Config.HTTP.Auth.Bearer.Realm))
} else {
response.Header().Set("WWW-Authenticate", fmt.Sprintf("basic realm=%s", rh.c.Config.HTTP.Realm))
}
}
}
@@ -224,7 +254,8 @@ type ImageTags struct {
// @Failure 400 {string} string "bad request".
func (rh *RouteHandler) ListTags(response http.ResponseWriter, request *http.Request) {
response.Header().Set("Access-Control-Allow-Methods", "HEAD,GET,POST,OPTIONS")
response.Header().Set("Access-Control-Allow-Headers", "Authorization,content-type")
response.Header().Set("Access-Control-Allow-Headers", "Authorization,content-type,"+constants.SessionClientHeaderName)
response.Header().Set("Access-Control-Allow-Credentials", "true")
if request.Method == http.MethodOptions {
return
@@ -355,7 +386,8 @@ func (rh *RouteHandler) ListTags(response http.ResponseWriter, request *http.Req
// @Failure 500 {string} string "internal server error".
func (rh *RouteHandler) CheckManifest(response http.ResponseWriter, request *http.Request) {
response.Header().Set("Access-Control-Allow-Methods", "HEAD,GET,POST,OPTIONS")
response.Header().Set("Access-Control-Allow-Headers", "Authorization,content-type")
response.Header().Set("Access-Control-Allow-Headers", "Authorization,content-type,"+constants.SessionClientHeaderName)
response.Header().Set("Access-Control-Allow-Credentials", "true")
if request.Method == http.MethodOptions {
return
@@ -427,7 +459,8 @@ type ExtensionList struct {
// @Router /v2/{name}/manifests/{reference} [get].
func (rh *RouteHandler) GetManifest(response http.ResponseWriter, request *http.Request) {
response.Header().Set("Access-Control-Allow-Methods", "HEAD,GET,POST,OPTIONS")
response.Header().Set("Access-Control-Allow-Headers", "Authorization,content-type")
response.Header().Set("Access-Control-Allow-Headers", "Authorization,content-type,"+constants.SessionClientHeaderName)
response.Header().Set("Access-Control-Allow-Credentials", "true")
if request.Method == http.MethodOptions {
return
@@ -527,7 +560,8 @@ func getReferrers(routeHandler *RouteHandler,
// @Router /v2/{name}/referrers/{digest} [get].
func (rh *RouteHandler) GetReferrers(response http.ResponseWriter, request *http.Request) {
response.Header().Set("Access-Control-Allow-Methods", "HEAD,GET,POST,OPTIONS")
response.Header().Set("Access-Control-Allow-Headers", "Authorization,content-type")
response.Header().Set("Access-Control-Allow-Headers", "Authorization,content-type,"+constants.SessionClientHeaderName)
response.Header().Set("Access-Control-Allow-Credentials", "true")
if request.Method == http.MethodOptions {
return
@@ -1576,7 +1610,8 @@ type RepositoryList struct {
// @Router /v2/_catalog [get].
func (rh *RouteHandler) ListRepositories(response http.ResponseWriter, request *http.Request) {
response.Header().Set("Access-Control-Allow-Methods", "HEAD,GET,POST,OPTIONS")
response.Header().Set("Access-Control-Allow-Headers", "Authorization,content-type")
response.Header().Set("Access-Control-Allow-Headers", "Authorization,content-type,"+constants.SessionClientHeaderName)
response.Header().Set("Access-Control-Allow-Credentials", "true")
if request.Method == http.MethodOptions {
return
@@ -1642,7 +1677,8 @@ func (rh *RouteHandler) ListRepositories(response http.ResponseWriter, request *
// @Router /v2/_oci/ext/discover [get].
func (rh *RouteHandler) ListExtensions(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Access-Control-Allow-Methods", "HEAD,GET,POST,OPTIONS")
w.Header().Set("Access-Control-Allow-Headers", "Authorization,content-type")
w.Header().Set("Access-Control-Allow-Headers", "Authorization,content-type,"+constants.SessionClientHeaderName)
w.Header().Set("Access-Control-Allow-Credentials", "true")
if r.Method == http.MethodOptions {
return
@@ -1653,6 +1689,116 @@ func (rh *RouteHandler) ListExtensions(w http.ResponseWriter, r *http.Request) {
zcommon.WriteJSON(w, http.StatusOK, extensionList)
}
// The following routes are specific to zot and NOT part of the OCI dist-spec
// Logout godoc
// @Summary Logout by removing current session
// @Description Logout by removing current session
// @Router /openid/auth/logout [post]
// @Accept json
// @Produce json
// @Success 200 {string} string "ok".
// @Failure 500 {string} string "internal server error".
func (rh *RouteHandler) Logout(response http.ResponseWriter, request *http.Request) {
response.Header().Set("Access-Control-Allow-Methods", "HEAD,GET,POST,OPTIONS")
response.Header().Set("Access-Control-Allow-Headers", "Authorization,content-type,"+constants.SessionClientHeaderName)
response.Header().Set("Access-Control-Allow-Credentials", "true")
if request.Method == http.MethodOptions {
return
}
session, _ := rh.c.CookieStore.Get(request, "session")
session.Options.MaxAge = -1
err := session.Save(request, response)
if err != nil {
response.WriteHeader(http.StatusInternalServerError)
return
}
response.WriteHeader(http.StatusOK)
}
// github Oauth2 CodeExchange callback.
func (rh *RouteHandler) GithubCodeExchangeCallback() rp.CodeExchangeCallback {
return func(w http.ResponseWriter, r *http.Request,
tokens *oidc.Tokens, state string, relyingParty rp.RelyingParty,
) {
ctx := r.Context()
client := github.NewClient(relyingParty.OAuthConfig().Client(ctx, tokens.Token))
email, groups, err := GetGithubUserInfo(ctx, client, rh.c.Log)
if email == "" || err != nil {
w.WriteHeader(http.StatusUnauthorized)
return
}
callbackUI, err := OAuth2Callback(rh.c, w, r, state, email, groups) //nolint: contextcheck
if err != nil {
if errors.Is(err, zerr.ErrInvalidStateCookie) {
w.WriteHeader(http.StatusUnauthorized)
}
w.WriteHeader(http.StatusInternalServerError)
}
if callbackUI != "" {
http.Redirect(w, r, callbackUI, http.StatusFound)
return
}
w.WriteHeader(http.StatusCreated)
}
}
// Openid CodeExchange callback.
func (rh *RouteHandler) OpenIDCodeExchangeCallback() rp.CodeExchangeUserinfoCallback {
return func(w http.ResponseWriter, r *http.Request, tokens *oidc.Tokens, state string,
relyingParty rp.RelyingParty, info oidc.UserInfo,
) {
email := info.GetEmail()
if email == "" {
rh.c.Log.Error().Msg("couldn't set user record for empty email value")
w.WriteHeader(http.StatusUnauthorized)
return
}
var groups []string
val, ok := info.GetClaim("groups").([]interface{})
if !ok {
rh.c.Log.Info().Msgf("couldn't find any 'groups' claim for user %s", email)
}
for _, group := range val {
groups = append(groups, fmt.Sprint(group))
}
callbackUI, err := OAuth2Callback(rh.c, w, r, state, email, groups)
if err != nil {
if errors.Is(err, zerr.ErrInvalidStateCookie) {
w.WriteHeader(http.StatusUnauthorized)
}
w.WriteHeader(http.StatusInternalServerError)
}
if callbackUI != "" {
http.Redirect(w, r, callbackUI, http.StatusFound)
return
}
w.WriteHeader(http.StatusCreated)
}
}
func (rh *RouteHandler) GetMetrics(w http.ResponseWriter, r *http.Request) {
m := rh.c.Metrics.ReceiveMetrics()
zcommon.WriteJSON(w, http.StatusOK, m)
+175 -4
View File
@@ -1,26 +1,36 @@
//go:build sync && scrub && metrics && search && lint
// +build sync,scrub,metrics,search,lint
//go:build sync && scrub && metrics && search && lint && apikey
// +build sync,scrub,metrics,search,lint,apikey
package api_test
import (
"bytes"
"context"
"encoding/json"
"errors"
"io"
"net/http"
"net/http/httptest"
"os"
"testing"
"github.com/google/uuid"
"github.com/gorilla/mux"
godigest "github.com/opencontainers/go-digest"
ispec "github.com/opencontainers/image-spec/specs-go/v1"
"github.com/project-zot/mockoidc"
. "github.com/smartystreets/goconvey/convey"
"github.com/zitadel/oidc/pkg/client/rp"
"github.com/zitadel/oidc/pkg/oidc"
"golang.org/x/oauth2"
zerr "zotregistry.io/zot/errors"
"zotregistry.io/zot/pkg/api"
"zotregistry.io/zot/pkg/api/config"
"zotregistry.io/zot/pkg/api/constants"
"zotregistry.io/zot/pkg/extensions"
extconf "zotregistry.io/zot/pkg/extensions/config"
"zotregistry.io/zot/pkg/meta/repodb"
localCtx "zotregistry.io/zot/pkg/requestcontext"
storageTypes "zotregistry.io/zot/pkg/storage/types"
"zotregistry.io/zot/pkg/test"
@@ -29,6 +39,8 @@ import (
var ErrUnexpectedError = errors.New("error: unexpected error")
const sessionStr = "session"
func TestRoutes(t *testing.T) {
Convey("Make a new controller", t, func() {
port := test.GetFreePort()
@@ -36,6 +48,45 @@ func TestRoutes(t *testing.T) {
conf := config.New()
conf.HTTP.Port = port
htpasswdPath := test.MakeHtpasswdFile()
defer os.Remove(htpasswdPath)
mockOIDCServer, err := mockoidc.Run()
if err != nil {
panic(err)
}
defer func() {
err := mockOIDCServer.Shutdown()
if err != nil {
panic(err)
}
}()
mockOIDCConfig := mockOIDCServer.Config()
conf.HTTP.Auth = &config.AuthConfig{
HTPasswd: config.AuthHTPasswd{
Path: htpasswdPath,
},
OpenID: &config.OpenIDConfig{
Providers: map[string]config.OpenIDProviderConfig{
"dex": {
ClientID: mockOIDCConfig.ClientID,
ClientSecret: mockOIDCConfig.ClientSecret,
KeyPath: "",
Issuer: mockOIDCConfig.Issuer,
Scopes: []string{"openid", "email"},
},
},
},
}
defaultVal := true
apiKeyConfig := &extconf.APIKeyConfig{
BaseConfig: extconf.BaseConfig{Enable: &defaultVal},
}
conf.Extensions = &extconf.ExtensionConfig{
APIKey: apiKeyConfig,
}
ctlr := api.NewController(conf)
ctlr.Config.Storage.RootDirectory = t.TempDir()
@@ -50,6 +101,52 @@ func TestRoutes(t *testing.T) {
// NOTE: the url or method itself doesn't matter below since we are calling the handlers directly,
// so path routing is bypassed
Convey("Test GithubCodeExchangeCallback", func() {
callback := rthdlr.GithubCodeExchangeCallback()
ctx := context.TODO()
request, _ := http.NewRequestWithContext(ctx, http.MethodGet, baseURL, nil)
response := httptest.NewRecorder()
tokens := &oidc.Tokens{}
relyingParty, err := rp.NewRelyingPartyOAuth(&oauth2.Config{})
So(err, ShouldBeNil)
callback(response, request, tokens, "state", relyingParty)
resp := response.Result()
defer resp.Body.Close()
So(resp, ShouldNotBeNil)
So(resp.StatusCode, ShouldEqual, http.StatusUnauthorized)
})
Convey("Test OAuth2Callback errors", func() {
ctx := context.TODO()
request, _ := http.NewRequestWithContext(ctx, http.MethodGet, baseURL, nil)
response := httptest.NewRecorder()
_, err := api.OAuth2Callback(ctlr, response, request, "state", "email", []string{"group"})
So(err, ShouldEqual, zerr.ErrInvalidStateCookie)
session, _ := ctlr.CookieStore.Get(request, "statecookie")
session.Options.Secure = true
session.Options.HttpOnly = true
session.Options.SameSite = http.SameSiteDefaultMode
state := uuid.New().String()
session.Values["state"] = state
// let the session set its own id
err = session.Save(request, response)
So(err, ShouldBeNil)
_, err = api.OAuth2Callback(ctlr, response, request, "state", "email", []string{"group"})
So(err, ShouldEqual, zerr.ErrInvalidStateCookie)
})
Convey("List repositories authz error", func() {
var invalid struct{}
@@ -575,7 +672,7 @@ func TestRoutes(t *testing.T) {
},
&mocks.MockedImageStore{
FullBlobUploadFn: func(repo string, body io.Reader, digest godigest.Digest) (string, int64, error) {
return "session", 0, zerr.ErrBadBlobDigest
return sessionStr, 0, zerr.ErrBadBlobDigest
},
})
So(statusCode, ShouldEqual, http.StatusInternalServerError)
@@ -591,7 +688,7 @@ func TestRoutes(t *testing.T) {
},
&mocks.MockedImageStore{
FullBlobUploadFn: func(repo string, body io.Reader, digest godigest.Digest) (string, int64, error) {
return "session", 20, nil
return sessionStr, 20, nil
},
})
So(statusCode, ShouldEqual, http.StatusInternalServerError)
@@ -1327,6 +1424,80 @@ func TestRoutes(t *testing.T) {
So(resp.StatusCode, ShouldEqual, http.StatusOK)
})
Convey("Test API keys", func() {
var invalid struct{}
ctx := context.TODO()
key := localCtx.GetContextKey()
ctx = context.WithValue(ctx, key, invalid)
request, _ := http.NewRequestWithContext(ctx, http.MethodPost, baseURL, bytes.NewReader([]byte{}))
response := httptest.NewRecorder()
extensions.CreateAPIKey(response, request, ctlr.RepoDB, ctlr.CookieStore, ctlr.Log)
resp := response.Result()
defer resp.Body.Close()
So(resp.StatusCode, ShouldEqual, http.StatusInternalServerError)
acCtx := localCtx.AccessControlContext{
Username: username,
}
ctx = context.TODO()
key = localCtx.GetContextKey()
ctx = context.WithValue(ctx, key, acCtx)
request, _ = http.NewRequestWithContext(ctx, http.MethodPost, baseURL, bytes.NewReader([]byte{}))
response = httptest.NewRecorder()
extensions.CreateAPIKey(response, request, ctlr.RepoDB, ctlr.CookieStore, ctlr.Log)
resp = response.Result()
defer resp.Body.Close()
So(resp.StatusCode, ShouldEqual, http.StatusInternalServerError)
payload := extensions.APIKeyPayload{
Label: "test",
Scopes: []string{"test"},
}
reqBody, err := json.Marshal(payload)
So(err, ShouldBeNil)
request, _ = http.NewRequestWithContext(ctx, http.MethodPost, baseURL, bytes.NewReader(reqBody))
response = httptest.NewRecorder()
extensions.CreateAPIKey(response, request, mocks.RepoDBMock{
AddUserAPIKeyFn: func(ctx context.Context, hashedKey string, apiKeyDetails *repodb.APIKeyDetails) error {
return ErrUnexpectedError
},
}, ctlr.CookieStore, ctlr.Log)
resp = response.Result()
defer resp.Body.Close()
So(resp.StatusCode, ShouldEqual, http.StatusInternalServerError)
request, _ = http.NewRequestWithContext(ctx, http.MethodDelete, baseURL, bytes.NewReader([]byte{}))
response = httptest.NewRecorder()
q := request.URL.Query()
q.Add("id", "apikeyid")
request.URL.RawQuery = q.Encode()
extensions.RevokeAPIKey(response, request, mocks.RepoDBMock{
DeleteUserAPIKeyFn: func(ctx context.Context, id string) error {
return ErrUnexpectedError
},
}, ctlr.CookieStore, ctlr.Log)
resp = response.Result()
defer resp.Body.Close()
So(resp.StatusCode, ShouldEqual, http.StatusInternalServerError)
})
Convey("Helper functions", func() {
testUpdateBlobUpload := func(
query []struct{ k, v string },