mirror of
https://github.com/project-zot/zot.git
synced 2026-06-16 04:17:55 +08:00
feat: integrate openID auth logic and user profile management (#1381)
This change introduces OpenID authn by using providers such as Github, Gitlab, Google and Dex. User sessions are now used for web clients to identify and persist an authenticated users session, thus not requiring every request to use credentials. Another change is apikey feature, users can create/revoke their api keys and use them to authenticate when using cli clients such as skopeo. eg: login: /auth/login?provider=github /auth/login?provider=gitlab and so on logout: /auth/logout redirectURL: /auth/callback/github /auth/callback/gitlab and so on If network policy doesn't allow inbound connections, this callback wont work! for more info read documentation added in this commit. Signed-off-by: Alex Stan <alexandrustan96@yahoo.ro> Signed-off-by: Petu Eusebiu <peusebiu@cisco.com> Co-authored-by: Alex Stan <alexandrustan96@yahoo.ro>
This commit is contained in:
+678
-163
@@ -3,139 +3,315 @@ package api
|
||||
import (
|
||||
"bufio"
|
||||
"context"
|
||||
"crypto/sha256"
|
||||
"crypto/x509"
|
||||
"encoding/base64"
|
||||
"encoding/gob"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net"
|
||||
"net/http"
|
||||
"os"
|
||||
"path"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/chartmuseum/auth"
|
||||
"github.com/google/go-github/v52/github"
|
||||
"github.com/google/uuid"
|
||||
"github.com/gorilla/mux"
|
||||
"github.com/gorilla/securecookie"
|
||||
"github.com/gorilla/sessions"
|
||||
godigest "github.com/opencontainers/go-digest"
|
||||
"github.com/zitadel/oidc/pkg/client/rp"
|
||||
httphelper "github.com/zitadel/oidc/pkg/http"
|
||||
"github.com/zitadel/oidc/pkg/oidc"
|
||||
"golang.org/x/crypto/bcrypt"
|
||||
"golang.org/x/oauth2"
|
||||
githubOAuth "golang.org/x/oauth2/github"
|
||||
|
||||
"zotregistry.io/zot/errors"
|
||||
zerr "zotregistry.io/zot/errors"
|
||||
"zotregistry.io/zot/pkg/api/config"
|
||||
"zotregistry.io/zot/pkg/api/constants"
|
||||
apiErr "zotregistry.io/zot/pkg/api/errors"
|
||||
"zotregistry.io/zot/pkg/common"
|
||||
"zotregistry.io/zot/pkg/log"
|
||||
localCtx "zotregistry.io/zot/pkg/requestcontext"
|
||||
storageConstants "zotregistry.io/zot/pkg/storage/constants"
|
||||
)
|
||||
|
||||
const (
|
||||
bearerAuthDefaultAccessEntryType = "repository"
|
||||
issuedAtOffset = 5 * time.Second
|
||||
relyingPartyCookieMaxAge = 120
|
||||
)
|
||||
|
||||
func AuthHandler(c *Controller) mux.MiddlewareFunc {
|
||||
if isBearerAuthEnabled(c.Config) {
|
||||
return bearerAuthHandler(c)
|
||||
}
|
||||
|
||||
return basicAuthHandler(c)
|
||||
type AuthnMiddleware struct {
|
||||
credMap map[string]string
|
||||
ldapClient *LDAPClient
|
||||
}
|
||||
|
||||
func bearerAuthHandler(ctlr *Controller) mux.MiddlewareFunc {
|
||||
authorizer, err := auth.NewAuthorizer(&auth.AuthorizerOptions{
|
||||
Realm: ctlr.Config.HTTP.Auth.Bearer.Realm,
|
||||
Service: ctlr.Config.HTTP.Auth.Bearer.Service,
|
||||
PublicKeyPath: ctlr.Config.HTTP.Auth.Bearer.Cert,
|
||||
AccessEntryType: bearerAuthDefaultAccessEntryType,
|
||||
EmptyDefaultNamespace: true,
|
||||
})
|
||||
func AuthHandler(ctlr *Controller) mux.MiddlewareFunc {
|
||||
authnMiddleware := &AuthnMiddleware{}
|
||||
|
||||
if isBearerAuthEnabled(ctlr.Config) {
|
||||
return bearerAuthHandler(ctlr)
|
||||
}
|
||||
|
||||
return authnMiddleware.TryAuthnHandlers(ctlr)
|
||||
}
|
||||
|
||||
func (amw *AuthnMiddleware) sessionAuthn(ctlr *Controller, next http.Handler, response http.ResponseWriter,
|
||||
request *http.Request, delay int,
|
||||
) {
|
||||
clientHeader := request.Header.Get(constants.SessionClientHeaderName)
|
||||
if clientHeader != constants.SessionClientHeaderValue {
|
||||
authFail(response, request, ctlr.Config.HTTP.Realm, delay)
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
identity, ok := common.GetAuthUserFromRequestSession(ctlr.CookieStore, request, ctlr.Log)
|
||||
if !ok {
|
||||
// let the client know that this session is invalid/expired
|
||||
cookie := &http.Cookie{
|
||||
Name: "session",
|
||||
Value: "",
|
||||
Path: "/",
|
||||
Expires: time.Unix(0, 0),
|
||||
|
||||
HttpOnly: true,
|
||||
}
|
||||
|
||||
http.SetCookie(response, cookie)
|
||||
|
||||
authFail(response, request, ctlr.Config.HTTP.Realm, delay)
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
ctx := getReqContextWithAuthorization(identity, []string{}, request)
|
||||
|
||||
groups, err := ctlr.RepoDB.GetUserGroups(ctx)
|
||||
if err != nil {
|
||||
ctlr.Log.Panic().Err(err).Msg("error creating bearer authorizer")
|
||||
if errors.Is(err, zerr.ErrUserDataNotFound) {
|
||||
ctlr.Log.Err(err).Str("identity", identity).Msg("can not find user profile in DB")
|
||||
|
||||
authFail(response, request, ctlr.Config.HTTP.Realm, delay)
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
ctlr.Log.Err(err).Str("identity", identity).Msg("can not get user profile in DB")
|
||||
|
||||
response.WriteHeader(http.StatusInternalServerError)
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
return func(next http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(response http.ResponseWriter, request *http.Request) {
|
||||
if request.Method == http.MethodOptions {
|
||||
next.ServeHTTP(response, request)
|
||||
response.WriteHeader(http.StatusNoContent)
|
||||
ctx = getReqContextWithAuthorization(identity, groups, request)
|
||||
|
||||
return
|
||||
}
|
||||
vars := mux.Vars(request)
|
||||
name := vars["name"]
|
||||
|
||||
// we want to bypass auth for mgmt route
|
||||
isMgmtRequested := request.RequestURI == constants.FullMgmtPrefix
|
||||
|
||||
header := request.Header.Get("Authorization")
|
||||
|
||||
if (header == "" || header == "Basic Og==") && isMgmtRequested {
|
||||
next.ServeHTTP(response, request)
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
action := auth.PullAction
|
||||
if m := request.Method; m != http.MethodGet && m != http.MethodHead {
|
||||
action = auth.PushAction
|
||||
}
|
||||
permissions, err := authorizer.Authorize(header, action, name)
|
||||
if err != nil {
|
||||
ctlr.Log.Error().Err(err).Msg("issue parsing Authorization header")
|
||||
response.Header().Set("Content-Type", "application/json")
|
||||
common.WriteJSON(response, http.StatusInternalServerError, apiErr.NewErrorList(apiErr.NewError(apiErr.UNSUPPORTED)))
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
if !permissions.Allowed {
|
||||
authFail(response, permissions.WWWAuthenticateHeader, 0)
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
next.ServeHTTP(response, request)
|
||||
})
|
||||
}
|
||||
next.ServeHTTP(response, request.WithContext(ctx))
|
||||
}
|
||||
|
||||
func noPasswdAuth(realm string, config *config.Config) mux.MiddlewareFunc {
|
||||
return func(next http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(response http.ResponseWriter, request *http.Request) {
|
||||
if request.Method == http.MethodOptions {
|
||||
next.ServeHTTP(response, request)
|
||||
response.WriteHeader(http.StatusNoContent)
|
||||
func (amw *AuthnMiddleware) basicAuthn(ctlr *Controller, response http.ResponseWriter,
|
||||
request *http.Request,
|
||||
) (bool, http.ResponseWriter, *http.Request, error) {
|
||||
cookieStore := ctlr.CookieStore
|
||||
|
||||
return
|
||||
}
|
||||
// we want to bypass auth for mgmt route
|
||||
isMgmtRequested := request.RequestURI == constants.FullMgmtPrefix
|
||||
|
||||
// Process request
|
||||
if request.Header.Get("Authorization") == "" {
|
||||
if ctlr.Config.HTTP.AccessControl.AnonymousPolicyExists() || isMgmtRequested {
|
||||
ctx := getReqContextWithAuthorization("", []string{}, request)
|
||||
next.ServeHTTP(response, request.WithContext(ctx)) //nolint:contextcheck
|
||||
})
|
||||
// Process request
|
||||
|
||||
return true, response, request.WithContext(ctx), nil
|
||||
}
|
||||
}
|
||||
|
||||
identity, passphrase, err := getUsernamePasswordBasicAuth(request)
|
||||
if err != nil {
|
||||
ctlr.Log.Error().Err(err).Msg("failed to parse authorization header")
|
||||
|
||||
return false, nil, nil, nil
|
||||
}
|
||||
|
||||
// some client tools might send Authorization: Basic Og== (decoded into ":")
|
||||
// empty username and password
|
||||
if identity == "" && passphrase == "" {
|
||||
if ctlr.Config.HTTP.AccessControl.AnonymousPolicyExists() || isMgmtRequested {
|
||||
ctx := getReqContextWithAuthorization("", []string{}, request)
|
||||
|
||||
return true, response, request.WithContext(ctx), nil
|
||||
}
|
||||
}
|
||||
|
||||
passphraseHash, ok := amw.credMap[identity]
|
||||
if ok {
|
||||
// first, HTTPPassword authN (which is local)
|
||||
if err := bcrypt.CompareHashAndPassword([]byte(passphraseHash), []byte(passphrase)); err == nil {
|
||||
// Process request
|
||||
var groups []string
|
||||
|
||||
if ctlr.Config.HTTP.AccessControl != nil {
|
||||
ac := NewAccessController(ctlr.Config)
|
||||
groups = ac.getUserGroups(identity)
|
||||
}
|
||||
|
||||
ctx := getReqContextWithAuthorization(identity, groups, request)
|
||||
|
||||
// saved logged session
|
||||
if err := saveUserLoggedSession(cookieStore, response, request, identity, ctlr.Log); err != nil {
|
||||
return false, response, request, err
|
||||
}
|
||||
|
||||
if err := ctlr.RepoDB.SetUserGroups(ctx, groups); err != nil {
|
||||
ctlr.Log.Error().Err(err).Str("identity", identity).Msg("couldn't update user profile")
|
||||
|
||||
return false, response, request, err
|
||||
}
|
||||
|
||||
ctlr.Log.Info().Str("identity", identity).Msgf("user profile successfully set")
|
||||
|
||||
return true, response, request.WithContext(ctx), nil
|
||||
}
|
||||
}
|
||||
|
||||
// next, LDAP if configured (network-based which can lose connectivity)
|
||||
if ctlr.Config.HTTP.Auth != nil && ctlr.Config.HTTP.Auth.LDAP != nil {
|
||||
ok, _, ldapgroups, err := amw.ldapClient.Authenticate(identity, passphrase)
|
||||
if ok && err == nil {
|
||||
// Process request
|
||||
var groups []string
|
||||
|
||||
if ctlr.Config.HTTP.AccessControl != nil {
|
||||
ac := NewAccessController(ctlr.Config)
|
||||
groups = ac.getUserGroups(identity)
|
||||
}
|
||||
|
||||
groups = append(groups, ldapgroups...)
|
||||
|
||||
ctx := getReqContextWithAuthorization(identity, groups, request)
|
||||
|
||||
if err := saveUserLoggedSession(cookieStore, response, request, identity, ctlr.Log); err != nil {
|
||||
return false, response, request, err
|
||||
}
|
||||
|
||||
if err := ctlr.RepoDB.SetUserGroups(ctx, groups); err != nil {
|
||||
ctlr.Log.Error().Err(err).Str("identity", identity).Msg("couldn't update user profile")
|
||||
|
||||
return false, response, request, err
|
||||
}
|
||||
|
||||
return true, response, request.WithContext(ctx), nil
|
||||
}
|
||||
}
|
||||
|
||||
// last try API keys
|
||||
if isAPIKeyEnabled(ctlr.Config) {
|
||||
apiKey := passphrase
|
||||
|
||||
if !strings.HasPrefix(apiKey, constants.APIKeysPrefix) {
|
||||
ctlr.Log.Error().Msg("api token has invalid format")
|
||||
|
||||
return false, nil, nil, nil
|
||||
}
|
||||
|
||||
trimmedAPIKey := strings.TrimPrefix(apiKey, constants.APIKeysPrefix)
|
||||
|
||||
hashedKey := hashUUID(trimmedAPIKey)
|
||||
|
||||
storedIdentity, err := ctlr.RepoDB.GetUserAPIKeyInfo(hashedKey)
|
||||
if err != nil {
|
||||
if errors.Is(err, zerr.ErrUserAPIKeyNotFound) {
|
||||
ctlr.Log.Info().Err(err).Msgf("can not find any user info for hashed key %s in DB", hashedKey)
|
||||
|
||||
return false, nil, nil, nil
|
||||
}
|
||||
|
||||
ctlr.Log.Error().Err(err).Msgf("can not get user info for hashed key %s in DB", hashedKey)
|
||||
|
||||
return false, nil, nil, err
|
||||
}
|
||||
|
||||
if storedIdentity == identity {
|
||||
ctx := getReqContextWithAuthorization(identity, []string{}, request)
|
||||
|
||||
err := ctlr.RepoDB.UpdateUserAPIKeyLastUsed(ctx, hashedKey)
|
||||
if err != nil {
|
||||
ctlr.Log.Err(err).Str("identity", identity).Msg("can not update user profile in DB")
|
||||
|
||||
return false, nil, nil, err
|
||||
}
|
||||
|
||||
groups, err := ctlr.RepoDB.GetUserGroups(ctx)
|
||||
if err != nil {
|
||||
ctlr.Log.Err(err).Str("identity", identity).Msg("can not get user's groups in DB")
|
||||
|
||||
return false, nil, nil, err
|
||||
}
|
||||
|
||||
ctx = getReqContextWithAuthorization(identity, groups, request)
|
||||
|
||||
return true, response, request.WithContext(ctx), nil
|
||||
}
|
||||
}
|
||||
|
||||
return false, nil, nil, nil
|
||||
}
|
||||
|
||||
//nolint:gocyclo // we use closure making this a complex subroutine
|
||||
func basicAuthHandler(ctlr *Controller) mux.MiddlewareFunc {
|
||||
realm := ctlr.Config.HTTP.Realm
|
||||
if realm == "" {
|
||||
realm = "Authorization Required"
|
||||
}
|
||||
|
||||
realm = "Basic realm=" + strconv.Quote(realm)
|
||||
|
||||
func (amw *AuthnMiddleware) TryAuthnHandlers(ctlr *Controller) mux.MiddlewareFunc { //nolint: gocyclo
|
||||
// no password based authN, if neither LDAP nor HTTP BASIC is enabled
|
||||
if ctlr.Config.HTTP.Auth == nil ||
|
||||
(ctlr.Config.HTTP.Auth.HTPasswd.Path == "" && ctlr.Config.HTTP.Auth.LDAP == nil) {
|
||||
return noPasswdAuth(realm, ctlr.Config)
|
||||
(ctlr.Config.HTTP.Auth.HTPasswd.Path == "" && ctlr.Config.HTTP.Auth.LDAP == nil &&
|
||||
ctlr.Config.HTTP.Auth.OpenID == nil) {
|
||||
return noPasswdAuth(ctlr.Config)
|
||||
}
|
||||
|
||||
credMap := make(map[string]string)
|
||||
amw.credMap = make(map[string]string)
|
||||
|
||||
delay := ctlr.Config.HTTP.Auth.FailDelay
|
||||
|
||||
var ldapClient *LDAPClient
|
||||
// setup sessions cookie store used to preserve logged in user in web sessions
|
||||
if isAuthnEnabled(ctlr.Config) || isOpenIDAuthEnabled(ctlr.Config) {
|
||||
// To store custom types in our cookies,
|
||||
// we must first register them using gob.Register
|
||||
gob.Register(map[string]interface{}{})
|
||||
|
||||
cookieStoreHashKey := securecookie.GenerateRandomKey(64)
|
||||
if cookieStoreHashKey == nil {
|
||||
panic(zerr.ErrHashKeyNotCreated)
|
||||
}
|
||||
|
||||
// if storage is filesystem then use zot's rootDir to store sessions
|
||||
if ctlr.Config.Storage.StorageDriver == nil {
|
||||
sessionsDir := path.Join(ctlr.Config.Storage.RootDirectory, "_sessions")
|
||||
if err := os.MkdirAll(sessionsDir, storageConstants.DefaultDirPerms); err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
cookieStore := sessions.NewFilesystemStore(sessionsDir, cookieStoreHashKey)
|
||||
|
||||
cookieStore.MaxAge(cookiesMaxAge)
|
||||
|
||||
ctlr.CookieStore = cookieStore
|
||||
} else {
|
||||
cookieStore := sessions.NewCookieStore(cookieStoreHashKey)
|
||||
|
||||
cookieStore.MaxAge(cookiesMaxAge)
|
||||
|
||||
ctlr.CookieStore = cookieStore
|
||||
}
|
||||
}
|
||||
|
||||
// ldap and htpasswd based authN
|
||||
if ctlr.Config.HTTP.Auth != nil {
|
||||
if ctlr.Config.HTTP.Auth.LDAP != nil {
|
||||
ldapConfig := ctlr.Config.HTTP.Auth.LDAP
|
||||
ldapClient = &LDAPClient{
|
||||
amw.ldapClient = &LDAPClient{
|
||||
Host: ldapConfig.Address,
|
||||
Port: ldapConfig.Port,
|
||||
UseSSL: !ldapConfig.Insecure,
|
||||
@@ -160,18 +336,18 @@ func basicAuthHandler(ctlr *Controller) mux.MiddlewareFunc {
|
||||
caCertPool := x509.NewCertPool()
|
||||
|
||||
if !caCertPool.AppendCertsFromPEM(caCert) {
|
||||
panic(errors.ErrBadCACert)
|
||||
panic(zerr.ErrBadCACert)
|
||||
}
|
||||
|
||||
ldapClient.ClientCAs = caCertPool
|
||||
amw.ldapClient.ClientCAs = caCertPool
|
||||
} else {
|
||||
// default to system cert pool
|
||||
caCertPool, err := x509.SystemCertPool()
|
||||
if err != nil {
|
||||
panic(errors.ErrBadCACert)
|
||||
panic(zerr.ErrBadCACert)
|
||||
}
|
||||
|
||||
ldapClient.ClientCAs = caCertPool
|
||||
amw.ldapClient.ClientCAs = caCertPool
|
||||
}
|
||||
}
|
||||
|
||||
@@ -188,12 +364,27 @@ func basicAuthHandler(ctlr *Controller) mux.MiddlewareFunc {
|
||||
line := scanner.Text()
|
||||
if strings.Contains(line, ":") {
|
||||
tokens := strings.Split(scanner.Text(), ":")
|
||||
credMap[tokens[0]] = tokens[1]
|
||||
amw.credMap[tokens[0]] = tokens[1]
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// openid based authN
|
||||
if ctlr.Config.HTTP.Auth.OpenID != nil {
|
||||
ctlr.RelyingParties = make(map[string]rp.RelyingParty)
|
||||
|
||||
for provider := range ctlr.Config.HTTP.Auth.OpenID.Providers {
|
||||
if IsOpenIDSupported(provider) {
|
||||
rp := NewRelyingPartyOIDC(ctlr.Config, provider)
|
||||
ctlr.RelyingParties[provider] = rp
|
||||
} else if IsOauth2Supported(provider) {
|
||||
rp := NewRelyingPartyGithub(ctlr.Config, provider)
|
||||
ctlr.RelyingParties[provider] = rp
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return func(next http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(response http.ResponseWriter, request *http.Request) {
|
||||
if request.Method == http.MethodOptions {
|
||||
@@ -203,84 +394,231 @@ func basicAuthHandler(ctlr *Controller) mux.MiddlewareFunc {
|
||||
return
|
||||
}
|
||||
|
||||
// we want to bypass auth for mgmt route
|
||||
isMgmtRequested := request.RequestURI == constants.FullMgmtPrefix
|
||||
|
||||
if request.Header.Get("Authorization") == "" {
|
||||
if ctlr.Config.HTTP.AccessControl.AnonymousPolicyExists() || isMgmtRequested {
|
||||
// Process request
|
||||
ctx := getReqContextWithAuthorization("", []string{}, request)
|
||||
next.ServeHTTP(response, request.WithContext(ctx)) //nolint:contextcheck
|
||||
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
username, passphrase, err := getUsernamePasswordBasicAuth(request)
|
||||
//nolint: contextcheck
|
||||
authenticated, cloneResp, cloneReq, err := amw.basicAuthn(ctlr, response, request)
|
||||
if err != nil {
|
||||
ctlr.Log.Error().Err(err).Msg("failed to parse authorization header")
|
||||
authFail(response, realm, delay)
|
||||
response.WriteHeader(http.StatusInternalServerError)
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
// some client tools might send Authorization: Basic Og== (decoded into ":")
|
||||
// empty username and password
|
||||
if username == "" && passphrase == "" {
|
||||
if ctlr.Config.HTTP.AccessControl.AnonymousPolicyExists() || isMgmtRequested {
|
||||
// Process request
|
||||
ctx := getReqContextWithAuthorization("", []string{}, request)
|
||||
next.ServeHTTP(response, request.WithContext(ctx)) //nolint:contextcheck
|
||||
if authenticated && cloneResp != nil && cloneReq != nil {
|
||||
next.ServeHTTP(cloneResp, cloneReq)
|
||||
|
||||
return
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// first, HTTPPassword authN (which is local)
|
||||
passphraseHash, ok := credMap[username]
|
||||
if ok {
|
||||
if err := bcrypt.CompareHashAndPassword([]byte(passphraseHash), []byte(passphrase)); err == nil {
|
||||
// Process request
|
||||
var userGroups []string
|
||||
|
||||
if ctlr.Config.HTTP.AccessControl != nil {
|
||||
ac := NewAccessController(ctlr.Config)
|
||||
userGroups = ac.getUserGroups(username)
|
||||
}
|
||||
|
||||
ctx := getReqContextWithAuthorization(username, userGroups, request)
|
||||
next.ServeHTTP(response, request.WithContext(ctx)) //nolint:contextcheck
|
||||
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
// next, LDAP if configured (network-based which can lose connectivity)
|
||||
if ctlr.Config.HTTP.Auth != nil && ctlr.Config.HTTP.Auth.LDAP != nil {
|
||||
ok, _, ldapgroups, err := ldapClient.Authenticate(username, passphrase)
|
||||
if ok && err == nil {
|
||||
// Process request
|
||||
var userGroups []string
|
||||
|
||||
if ctlr.Config.HTTP.AccessControl != nil {
|
||||
ac := NewAccessController(ctlr.Config)
|
||||
userGroups = ac.getUserGroups(username)
|
||||
}
|
||||
|
||||
userGroups = append(userGroups, ldapgroups...)
|
||||
|
||||
ctx := getReqContextWithAuthorization(username, userGroups, request)
|
||||
next.ServeHTTP(response, request.WithContext(ctx)) //nolint:contextcheck
|
||||
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
authFail(response, realm, delay)
|
||||
//nolint: contextcheck
|
||||
amw.sessionAuthn(ctlr, next, response, request, delay)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func bearerAuthHandler(ctlr *Controller) mux.MiddlewareFunc {
|
||||
authorizer, err := auth.NewAuthorizer(&auth.AuthorizerOptions{
|
||||
Realm: ctlr.Config.HTTP.Auth.Bearer.Realm,
|
||||
Service: ctlr.Config.HTTP.Auth.Bearer.Service,
|
||||
PublicKeyPath: ctlr.Config.HTTP.Auth.Bearer.Cert,
|
||||
AccessEntryType: bearerAuthDefaultAccessEntryType,
|
||||
EmptyDefaultNamespace: true,
|
||||
})
|
||||
if err != nil {
|
||||
ctlr.Log.Panic().Err(err).Msg("error creating bearer authorizer")
|
||||
}
|
||||
|
||||
return func(next http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(response http.ResponseWriter, request *http.Request) {
|
||||
if request.Method == http.MethodOptions {
|
||||
next.ServeHTTP(response, request)
|
||||
response.WriteHeader(http.StatusNoContent)
|
||||
|
||||
return
|
||||
}
|
||||
acCtrlr := NewAccessController(ctlr.Config)
|
||||
vars := mux.Vars(request)
|
||||
name := vars["name"]
|
||||
|
||||
// we want to bypass auth for mgmt route
|
||||
isMgmtRequested := request.RequestURI == constants.FullMgmtPrefix
|
||||
|
||||
header := request.Header.Get("Authorization")
|
||||
|
||||
if (header == "" || header == "Basic Og==") && isMgmtRequested {
|
||||
next.ServeHTTP(response, request)
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
action := auth.PullAction
|
||||
if m := request.Method; m != http.MethodGet && m != http.MethodHead {
|
||||
action = auth.PushAction
|
||||
}
|
||||
|
||||
permissions, err := authorizer.Authorize(header, action, name)
|
||||
if err != nil {
|
||||
ctlr.Log.Error().Err(err).Msg("issue parsing Authorization header")
|
||||
response.Header().Set("Content-Type", "application/json")
|
||||
common.WriteJSON(response, http.StatusInternalServerError, apiErr.NewErrorList(apiErr.NewError(apiErr.UNSUPPORTED)))
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
if !permissions.Allowed {
|
||||
response.Header().Set("Content-Type", "application/json")
|
||||
response.Header().Set("WWW-Authenticate", permissions.WWWAuthenticateHeader)
|
||||
|
||||
common.WriteJSON(response, http.StatusUnauthorized,
|
||||
apiErr.NewErrorList(apiErr.NewError(apiErr.UNAUTHORIZED)))
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
amCtx := acCtrlr.getAuthnMiddlewareContext(BEARER, request)
|
||||
next.ServeHTTP(response, request.WithContext(amCtx)) //nolint:contextcheck
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func noPasswdAuth(config *config.Config) mux.MiddlewareFunc {
|
||||
return func(next http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(response http.ResponseWriter, request *http.Request) {
|
||||
if request.Method == http.MethodOptions {
|
||||
next.ServeHTTP(response, request)
|
||||
response.WriteHeader(http.StatusNoContent)
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
ctx := getReqContextWithAuthorization("", []string{}, request)
|
||||
// Process request
|
||||
next.ServeHTTP(response, request.WithContext(ctx)) //nolint:contextcheck
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func (rh *RouteHandler) AuthURLHandler() http.HandlerFunc {
|
||||
return func(w http.ResponseWriter, r *http.Request) {
|
||||
query := r.URL.Query()
|
||||
callbackUI := query.Get(constants.CallbackUIQueryParam)
|
||||
|
||||
provider := query.Get("provider")
|
||||
|
||||
client, ok := rh.c.RelyingParties[provider]
|
||||
if !ok {
|
||||
http.HandlerFunc(func(response http.ResponseWriter, request *http.Request) {
|
||||
response.WriteHeader(http.StatusBadRequest)
|
||||
})(w, r)
|
||||
}
|
||||
|
||||
/* save cookie containing state to later verify it and
|
||||
callback ui where we will redirect after openid/oauth2 logic is completed*/
|
||||
session, _ := rh.c.CookieStore.Get(r, "statecookie")
|
||||
|
||||
session.Options.Secure = true
|
||||
session.Options.HttpOnly = true
|
||||
session.Options.SameSite = http.SameSiteDefaultMode
|
||||
session.Options.Path = constants.CallbackBasePath
|
||||
|
||||
state := uuid.New().String()
|
||||
|
||||
session.Values["state"] = state
|
||||
session.Values["callback"] = callbackUI
|
||||
|
||||
// let the session set its own id
|
||||
err := session.Save(r, w)
|
||||
if err != nil {
|
||||
rh.c.Log.Error().Err(err).Msg("unable to save http session")
|
||||
|
||||
w.WriteHeader(http.StatusInternalServerError)
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
stateFunc := func() string {
|
||||
return state
|
||||
}
|
||||
|
||||
rp.AuthURLHandler(stateFunc, client)(w, r)
|
||||
}
|
||||
}
|
||||
|
||||
func NewRelyingPartyOIDC(config *config.Config, provider string) rp.RelyingParty {
|
||||
issuer, clientID, clientSecret, redirectURI, scopes, options := getRelyingPartyArgs(config, provider)
|
||||
|
||||
relyingParty, err := rp.NewRelyingPartyOIDC(issuer, clientID, clientSecret, redirectURI, scopes, options...)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
return relyingParty
|
||||
}
|
||||
|
||||
func NewRelyingPartyGithub(config *config.Config, provider string) rp.RelyingParty {
|
||||
_, clientID, clientSecret, redirectURI, scopes, options := getRelyingPartyArgs(config, provider)
|
||||
|
||||
rpConfig := &oauth2.Config{
|
||||
ClientID: clientID,
|
||||
ClientSecret: clientSecret,
|
||||
RedirectURL: redirectURI,
|
||||
Scopes: scopes,
|
||||
Endpoint: githubOAuth.Endpoint,
|
||||
}
|
||||
|
||||
relyingParty, err := rp.NewRelyingPartyOAuth(rpConfig, options...)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
return relyingParty
|
||||
}
|
||||
|
||||
func getRelyingPartyArgs(config *config.Config, provider string) (
|
||||
string, string, string, string, []string, []rp.Option,
|
||||
) {
|
||||
if _, ok := config.HTTP.Auth.OpenID.Providers[provider]; !ok {
|
||||
panic(zerr.ErrOpenIDProviderDoesNotExist)
|
||||
}
|
||||
|
||||
scheme := "http"
|
||||
if config.HTTP.TLS != nil {
|
||||
scheme = "https"
|
||||
}
|
||||
|
||||
clientID := config.HTTP.Auth.OpenID.Providers[provider].ClientID
|
||||
clientSecret := config.HTTP.Auth.OpenID.Providers[provider].ClientSecret
|
||||
|
||||
scopes := config.HTTP.Auth.OpenID.Providers[provider].Scopes
|
||||
// openid scope must be the first one in list
|
||||
if !common.Contains(scopes, oidc.ScopeOpenID) && IsOpenIDSupported(provider) {
|
||||
scopes = append([]string{oidc.ScopeOpenID}, scopes...)
|
||||
}
|
||||
|
||||
port := config.HTTP.Port
|
||||
issuer := config.HTTP.Auth.OpenID.Providers[provider].Issuer
|
||||
keyPath := config.HTTP.Auth.OpenID.Providers[provider].KeyPath
|
||||
baseURL := net.JoinHostPort(config.HTTP.Address, port)
|
||||
redirectURI := fmt.Sprintf("%s://%s%s", scheme, baseURL, constants.CallbackBasePath+fmt.Sprintf("/%s", provider))
|
||||
|
||||
options := []rp.Option{
|
||||
rp.WithVerifierOpts(rp.WithIssuedAtOffset(issuedAtOffset)),
|
||||
}
|
||||
|
||||
key := securecookie.GenerateRandomKey(32) //nolint: gomnd
|
||||
|
||||
cookieHandler := httphelper.NewCookieHandler(key, key, httphelper.WithMaxAge(relyingPartyCookieMaxAge))
|
||||
options = append(options, rp.WithCookieHandler(cookieHandler))
|
||||
|
||||
if clientSecret == "" {
|
||||
options = append(options, rp.WithPKCE(cookieHandler))
|
||||
}
|
||||
|
||||
if keyPath != "" {
|
||||
options = append(options, rp.WithJWTProfile(rp.SignerFromKeyPath(keyPath)))
|
||||
}
|
||||
|
||||
return issuer, clientID, clientSecret, redirectURI, scopes, options
|
||||
}
|
||||
|
||||
func getReqContextWithAuthorization(username string, groups []string, request *http.Request) context.Context {
|
||||
acCtx := localCtx.AccessControlContext{
|
||||
Username: username,
|
||||
@@ -314,9 +652,71 @@ func isBearerAuthEnabled(config *config.Config) bool {
|
||||
return false
|
||||
}
|
||||
|
||||
func authFail(w http.ResponseWriter, realm string, delay int) {
|
||||
func isOpenIDAuthEnabled(config *config.Config) bool {
|
||||
if config.HTTP.Auth != nil &&
|
||||
config.HTTP.Auth.OpenID != nil {
|
||||
for provider := range config.HTTP.Auth.OpenID.Providers {
|
||||
if isOpenIDAuthProviderEnabled(config, provider) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
func isAPIKeyEnabled(config *config.Config) bool {
|
||||
if config.Extensions != nil && config.Extensions.APIKey != nil &&
|
||||
*config.Extensions.APIKey.Enable {
|
||||
return true
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
func isOpenIDAuthProviderEnabled(config *config.Config, provider string) bool {
|
||||
if providerConfig, ok := config.HTTP.Auth.OpenID.Providers[provider]; ok {
|
||||
if IsOpenIDSupported(provider) {
|
||||
if providerConfig.ClientID != "" || providerConfig.Issuer != "" ||
|
||||
len(providerConfig.Scopes) > 0 {
|
||||
return true
|
||||
}
|
||||
} else if IsOauth2Supported(provider) {
|
||||
if providerConfig.ClientID != "" || len(providerConfig.Scopes) > 0 {
|
||||
return true
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
func IsOpenIDSupported(provider string) bool {
|
||||
supported := []string{"google", "gitlab", "dex"}
|
||||
|
||||
return common.Contains(supported, provider)
|
||||
}
|
||||
|
||||
func IsOauth2Supported(provider string) bool {
|
||||
supported := []string{"github"}
|
||||
|
||||
return common.Contains(supported, provider)
|
||||
}
|
||||
|
||||
func authFail(w http.ResponseWriter, r *http.Request, realm string, delay int) {
|
||||
time.Sleep(time.Duration(delay) * time.Second)
|
||||
w.Header().Set("WWW-Authenticate", realm)
|
||||
|
||||
// don't send auth headers if request is coming from UI
|
||||
if r.Header.Get(constants.SessionClientHeaderName) != constants.SessionClientHeaderValue {
|
||||
if realm == "" {
|
||||
realm = "Authorization Required"
|
||||
}
|
||||
|
||||
realm = "Basic realm=" + strconv.Quote(realm)
|
||||
|
||||
w.Header().Set("WWW-Authenticate", realm)
|
||||
}
|
||||
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
common.WriteJSON(w, http.StatusUnauthorized, apiErr.NewErrorList(apiErr.NewError(apiErr.UNAUTHORIZED)))
|
||||
}
|
||||
@@ -325,12 +725,12 @@ func getUsernamePasswordBasicAuth(request *http.Request) (string, string, error)
|
||||
basicAuth := request.Header.Get("Authorization")
|
||||
|
||||
if basicAuth == "" {
|
||||
return "", "", errors.ErrParsingAuthHeader
|
||||
return "", "", zerr.ErrParsingAuthHeader
|
||||
}
|
||||
|
||||
splitStr := strings.SplitN(basicAuth, " ", 2) //nolint:gomnd
|
||||
splitStr := strings.SplitN(basicAuth, " ", 2) //nolint: gomnd
|
||||
if len(splitStr) != 2 || strings.ToLower(splitStr[0]) != "basic" {
|
||||
return "", "", errors.ErrParsingAuthHeader
|
||||
return "", "", zerr.ErrParsingAuthHeader
|
||||
}
|
||||
|
||||
decodedStr, err := base64.StdEncoding.DecodeString(splitStr[1])
|
||||
@@ -338,9 +738,9 @@ func getUsernamePasswordBasicAuth(request *http.Request) (string, string, error)
|
||||
return "", "", err
|
||||
}
|
||||
|
||||
pair := strings.SplitN(string(decodedStr), ":", 2) //nolint:gomnd
|
||||
if len(pair) != 2 { //nolint:gomnd
|
||||
return "", "", errors.ErrParsingAuthHeader
|
||||
pair := strings.SplitN(string(decodedStr), ":", 2) //nolint: gomnd
|
||||
if len(pair) != 2 { //nolint: gomnd
|
||||
return "", "", zerr.ErrParsingAuthHeader
|
||||
}
|
||||
|
||||
username := pair[0]
|
||||
@@ -348,3 +748,118 @@ func getUsernamePasswordBasicAuth(request *http.Request) (string, string, error)
|
||||
|
||||
return username, passphrase, nil
|
||||
}
|
||||
|
||||
func GetGithubUserInfo(ctx context.Context, client *github.Client, log log.Logger) (string, []string, error) {
|
||||
var primaryEmail string
|
||||
|
||||
userEmails, _, err := client.Users.ListEmails(ctx, nil)
|
||||
if err != nil {
|
||||
log.Error().Msg("couldn't set user record for empty email value")
|
||||
|
||||
return "", []string{}, err
|
||||
}
|
||||
|
||||
if len(userEmails) != 0 {
|
||||
for _, email := range userEmails { // should have at least one primary email, if any
|
||||
if email.GetPrimary() { // check if it's primary email
|
||||
primaryEmail = email.GetEmail()
|
||||
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
orgs, _, err := client.Organizations.List(ctx, "", nil)
|
||||
if err != nil {
|
||||
log.Error().Msg("couldn't set user record for empty email value")
|
||||
|
||||
return "", []string{}, err
|
||||
}
|
||||
|
||||
groups := []string{}
|
||||
for _, org := range orgs {
|
||||
groups = append(groups, *org.Login)
|
||||
}
|
||||
|
||||
return primaryEmail, groups, nil
|
||||
}
|
||||
|
||||
func saveUserLoggedSession(cookieStore sessions.Store, response http.ResponseWriter,
|
||||
request *http.Request, identity string, log log.Logger,
|
||||
) error {
|
||||
session, _ := cookieStore.Get(request, "session")
|
||||
|
||||
session.Options.Secure = true
|
||||
session.Options.HttpOnly = true
|
||||
session.Options.SameSite = http.SameSiteDefaultMode
|
||||
session.Values["authStatus"] = true
|
||||
session.Values["user"] = identity
|
||||
|
||||
// let the session set its own id
|
||||
err := session.Save(request, response)
|
||||
if err != nil {
|
||||
log.Error().Err(err).Str("identity", identity).Msg("unable to save http session")
|
||||
|
||||
return err
|
||||
}
|
||||
|
||||
userInfoCookie := sessions.NewCookie("user", identity, &sessions.Options{
|
||||
Secure: true,
|
||||
HttpOnly: false,
|
||||
MaxAge: cookiesMaxAge,
|
||||
SameSite: http.SameSiteDefaultMode,
|
||||
Path: "/",
|
||||
})
|
||||
|
||||
http.SetCookie(response, userInfoCookie)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// OAuth2Callback is the callback logic where openid/oauth2 will redirect back to our app.
|
||||
func OAuth2Callback(ctlr *Controller, w http.ResponseWriter, r *http.Request, state, email string,
|
||||
groups []string,
|
||||
) (string, error) {
|
||||
stateCookie, _ := ctlr.CookieStore.Get(r, "statecookie")
|
||||
|
||||
stateOrigin, ok := stateCookie.Values["state"].(string)
|
||||
if !ok {
|
||||
ctlr.Log.Error().Err(zerr.ErrInvalidStateCookie).Msg("openID: unable to get 'state' cookie from request")
|
||||
|
||||
return "", zerr.ErrInvalidStateCookie
|
||||
}
|
||||
|
||||
if stateOrigin != state {
|
||||
ctlr.Log.Error().Err(zerr.ErrInvalidStateCookie).Msg("openID: 'state' cookie differs from the actual one")
|
||||
|
||||
return "", zerr.ErrInvalidStateCookie
|
||||
}
|
||||
|
||||
ctx := getReqContextWithAuthorization(email, groups, r)
|
||||
|
||||
// if this line has been reached, then a new session should be created
|
||||
// if the `session` key is already on the cookie, it's not a valid one
|
||||
if err := saveUserLoggedSession(ctlr.CookieStore, w, r, email, ctlr.Log); err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
if err := ctlr.RepoDB.SetUserGroups(ctx, groups); err != nil {
|
||||
ctlr.Log.Error().Err(err).Str("identity", email).Msg("couldn't update the user profile")
|
||||
|
||||
return "", err
|
||||
}
|
||||
|
||||
ctlr.Log.Info().Msgf("user profile set successfully for email %s", email)
|
||||
|
||||
// redirect to UI
|
||||
callbackUI, _ := stateCookie.Values["callback"].(string)
|
||||
|
||||
return callbackUI, nil
|
||||
}
|
||||
|
||||
func hashUUID(uuid string) string {
|
||||
digester := sha256.New()
|
||||
digester.Write([]byte(uuid))
|
||||
|
||||
return godigest.NewDigestFromEncoded(godigest.SHA256, fmt.Sprintf("%x", digester.Sum(nil))).Encoded()
|
||||
}
|
||||
|
||||
+45
-11
@@ -21,6 +21,9 @@ const (
|
||||
Delete = "delete"
|
||||
// behaviour actions.
|
||||
DetectManifestCollision = "detectManifestCollision"
|
||||
BASIC = "Basic"
|
||||
BEARER = "Bearer"
|
||||
OPENID = "OpenID"
|
||||
)
|
||||
|
||||
// AccessController authorizes users to act on resources.
|
||||
@@ -29,10 +32,17 @@ type AccessController struct {
|
||||
Log log.Logger
|
||||
}
|
||||
|
||||
func NewAccessController(config *config.Config) *AccessController {
|
||||
func NewAccessController(conf *config.Config) *AccessController {
|
||||
if conf.HTTP.AccessControl == nil {
|
||||
return &AccessController{
|
||||
Config: &config.AccessControlConfig{},
|
||||
Log: log.NewLogger(conf.Log.Level, conf.Log.Output),
|
||||
}
|
||||
}
|
||||
|
||||
return &AccessController{
|
||||
Config: config.HTTP.AccessControl,
|
||||
Log: log.NewLogger(config.Log.Level, config.Log.Output),
|
||||
Config: conf.HTTP.AccessControl,
|
||||
Log: log.NewLogger(conf.Log.Level, conf.Log.Output),
|
||||
}
|
||||
}
|
||||
|
||||
@@ -171,6 +181,18 @@ func (ac *AccessController) getContext(acCtx *localCtx.AccessControlContext, req
|
||||
return ctx
|
||||
}
|
||||
|
||||
// getAuthnMiddlewareContext builds ac context(allowed to read repos and if user is admin) and returns it.
|
||||
func (ac *AccessController) getAuthnMiddlewareContext(authnType string, request *http.Request) context.Context {
|
||||
amwCtx := localCtx.AuthnMiddlewareContext{
|
||||
AuthnType: authnType,
|
||||
}
|
||||
|
||||
amwCtxKey := localCtx.GetAuthnMiddlewareCtxKey()
|
||||
ctx := context.WithValue(request.Context(), amwCtxKey, amwCtx)
|
||||
|
||||
return ctx
|
||||
}
|
||||
|
||||
// isPermitted returns true if username can do action on a repository policy.
|
||||
func (ac *AccessController) isPermitted(userGroups []string, username, action string,
|
||||
policyGroup config.PolicyGroup,
|
||||
@@ -231,6 +253,14 @@ func BaseAuthzHandler(ctlr *Controller) mux.MiddlewareFunc {
|
||||
return
|
||||
}
|
||||
|
||||
// request comes from bearer authn, bypass it
|
||||
authnMwCtx, err := localCtx.GetAuthnMiddlewareContext(request.Context())
|
||||
if err != nil || (authnMwCtx != nil && authnMwCtx.AuthnType == BEARER) {
|
||||
next.ServeHTTP(response, request)
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
// bypass authz for /v2/ route
|
||||
if request.RequestURI == "/v2/" {
|
||||
next.ServeHTTP(response, request)
|
||||
@@ -242,8 +272,6 @@ func BaseAuthzHandler(ctlr *Controller) mux.MiddlewareFunc {
|
||||
|
||||
var identity string
|
||||
|
||||
var err error
|
||||
|
||||
// anonymous context
|
||||
acCtx := &localCtx.AccessControlContext{}
|
||||
|
||||
@@ -252,7 +280,7 @@ func BaseAuthzHandler(ctlr *Controller) mux.MiddlewareFunc {
|
||||
// get access control context made in authn.go if authn is enabled
|
||||
acCtx, err = localCtx.GetAccessControlContext(request.Context())
|
||||
if err != nil { // should never happen
|
||||
authFail(response, ctlr.Config.HTTP.Realm, ctlr.Config.HTTP.Auth.FailDelay)
|
||||
authFail(response, request, ctlr.Config.HTTP.Realm, ctlr.Config.HTTP.Auth.FailDelay)
|
||||
|
||||
return
|
||||
}
|
||||
@@ -272,7 +300,7 @@ func BaseAuthzHandler(ctlr *Controller) mux.MiddlewareFunc {
|
||||
// if we still don't have an identity
|
||||
if identity == "" {
|
||||
acCtrlr.Log.Info().Msg("couldn't get identity from TLS certificate")
|
||||
authFail(response, ctlr.Config.HTTP.Realm, ctlr.Config.HTTP.Auth.FailDelay)
|
||||
authFail(response, request, ctlr.Config.HTTP.Realm, ctlr.Config.HTTP.Auth.FailDelay)
|
||||
|
||||
return
|
||||
}
|
||||
@@ -298,6 +326,14 @@ func DistSpecAuthzHandler(ctlr *Controller) mux.MiddlewareFunc {
|
||||
return
|
||||
}
|
||||
|
||||
// request comes from bearer authn, bypass it
|
||||
authnMwCtx, err := localCtx.GetAuthnMiddlewareContext(request.Context())
|
||||
if err != nil || (authnMwCtx != nil && authnMwCtx.AuthnType == BEARER) {
|
||||
next.ServeHTTP(response, request)
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
vars := mux.Vars(request)
|
||||
resource := vars["name"]
|
||||
reference, ok := vars["reference"]
|
||||
@@ -306,12 +342,10 @@ func DistSpecAuthzHandler(ctlr *Controller) mux.MiddlewareFunc {
|
||||
|
||||
var identity string
|
||||
|
||||
var err error
|
||||
|
||||
// get acCtx built in authn and previous authz middlewares
|
||||
acCtx, err := localCtx.GetAccessControlContext(request.Context())
|
||||
if err != nil { // should never happen
|
||||
authFail(response, ctlr.Config.HTTP.Realm, ctlr.Config.HTTP.Auth.FailDelay)
|
||||
authFail(response, request, ctlr.Config.HTTP.Realm, ctlr.Config.HTTP.Auth.FailDelay)
|
||||
|
||||
return
|
||||
}
|
||||
@@ -344,7 +378,7 @@ func DistSpecAuthzHandler(ctlr *Controller) mux.MiddlewareFunc {
|
||||
|
||||
can := acCtrlr.can(request.Context(), identity, action, resource) //nolint:contextcheck
|
||||
if !can {
|
||||
common.AuthzFail(response, ctlr.Config.HTTP.Realm, ctlr.Config.HTTP.Auth.FailDelay)
|
||||
common.AuthzFail(response, request, ctlr.Config.HTTP.Realm, ctlr.Config.HTTP.Auth.FailDelay)
|
||||
} else {
|
||||
next.ServeHTTP(response, request) //nolint:contextcheck
|
||||
}
|
||||
|
||||
@@ -45,6 +45,7 @@ type AuthConfig struct {
|
||||
HTPasswd AuthHTPasswd
|
||||
LDAP *LDAPConfig
|
||||
Bearer *BearerConfig
|
||||
OpenID *OpenIDConfig
|
||||
}
|
||||
|
||||
type BearerConfig struct {
|
||||
@@ -53,6 +54,18 @@ type BearerConfig struct {
|
||||
Cert string
|
||||
}
|
||||
|
||||
type OpenIDConfig struct {
|
||||
Providers map[string]OpenIDProviderConfig
|
||||
}
|
||||
|
||||
type OpenIDProviderConfig struct {
|
||||
ClientID string
|
||||
ClientSecret string
|
||||
KeyPath string
|
||||
Issuer string
|
||||
Scopes []string
|
||||
}
|
||||
|
||||
type MethodRatelimitConfig struct {
|
||||
Method string
|
||||
Rate int
|
||||
@@ -63,6 +76,7 @@ type RatelimitConfig struct {
|
||||
Methods []MethodRatelimitConfig `mapstructure:",omitempty"`
|
||||
}
|
||||
|
||||
//nolint:maligned
|
||||
type HTTPConfig struct {
|
||||
Address string
|
||||
Port string
|
||||
|
||||
@@ -12,4 +12,11 @@ const (
|
||||
DefaultMediaType = "application/json"
|
||||
BinaryMediaType = "application/octet-stream"
|
||||
DefaultMetricsExtensionRoute = "/metrics"
|
||||
CallbackBasePath = "/auth/callback"
|
||||
LoginPath = "/auth/login"
|
||||
LogoutPath = "/auth/logout"
|
||||
SessionClientHeaderName = "X-ZOT-API-CLIENT"
|
||||
SessionClientHeaderValue = "zot-ui"
|
||||
APIKeysPrefix = "zak_"
|
||||
CallbackUIQueryParam = "callback_ui"
|
||||
)
|
||||
|
||||
@@ -18,4 +18,7 @@ const (
|
||||
ExtUserPreferences = "/userprefs"
|
||||
ExtUserPreferencesPrefix = ExtPrefix + ExtUserPreferences
|
||||
FullUserPreferencesPrefix = RoutePrefix + ExtUserPreferencesPrefix
|
||||
ExtAPIKey = "/apikey"
|
||||
ExtAPIKeyPrefix = ExtPrefix + ExtAPIKey //nolint: gosec
|
||||
FullAPIKeyPrefix = RoutePrefix + ExtAPIKeyPrefix
|
||||
)
|
||||
|
||||
@@ -16,6 +16,8 @@ import (
|
||||
|
||||
"github.com/gorilla/handlers"
|
||||
"github.com/gorilla/mux"
|
||||
"github.com/gorilla/sessions"
|
||||
"github.com/zitadel/oidc/pkg/client/rp"
|
||||
|
||||
"zotregistry.io/zot/errors"
|
||||
"zotregistry.io/zot/pkg/api/config"
|
||||
@@ -31,6 +33,7 @@ import (
|
||||
const (
|
||||
idleTimeout = 120 * time.Second
|
||||
readHeaderTimeout = 5 * time.Second
|
||||
cookiesMaxAge = 86400 // seconds
|
||||
)
|
||||
|
||||
type Controller struct {
|
||||
@@ -44,6 +47,8 @@ type Controller struct {
|
||||
Metrics monitoring.MetricServer
|
||||
CveInfo ext.CveInfo
|
||||
SyncOnDemand SyncOnDemand
|
||||
RelyingParties map[string]rp.RelyingParty
|
||||
CookieStore sessions.Store
|
||||
// runtime params
|
||||
chosenPort int // kernel-chosen port
|
||||
}
|
||||
@@ -254,7 +259,9 @@ func (c *Controller) InitImageStore() error {
|
||||
}
|
||||
|
||||
func (c *Controller) InitRepoDB(reloadCtx context.Context) error {
|
||||
if c.Config.Extensions != nil && c.Config.Extensions.Search != nil && *c.Config.Extensions.Search.Enable {
|
||||
// init repoDB if search is enabled or authn enabled (need to store user profiles) or apikey ext is enabled
|
||||
if (c.Config.Extensions != nil && c.Config.Extensions.Search != nil && *c.Config.Extensions.Search.Enable) ||
|
||||
isAuthnEnabled(c.Config) || isOpenIDAuthEnabled(c.Config) || isAPIKeyEnabled(c.Config) {
|
||||
driver, err := repodbfactory.New(c.Config.Storage.StorageConfig, c.Log) //nolint:contextcheck
|
||||
if err != nil {
|
||||
return err
|
||||
|
||||
+1213
-26
File diff suppressed because it is too large
Load Diff
+164
-18
@@ -20,11 +20,14 @@ import (
|
||||
"strconv"
|
||||
"strings"
|
||||
|
||||
"github.com/google/go-github/v52/github"
|
||||
"github.com/gorilla/mux"
|
||||
"github.com/opencontainers/distribution-spec/specs-go/v1/extensions"
|
||||
godigest "github.com/opencontainers/go-digest"
|
||||
ispec "github.com/opencontainers/image-spec/specs-go/v1"
|
||||
artifactspec "github.com/oras-project/artifacts-spec/specs-go/v1"
|
||||
"github.com/zitadel/oidc/pkg/client/rp"
|
||||
"github.com/zitadel/oidc/pkg/oidc"
|
||||
|
||||
zerr "zotregistry.io/zot/errors"
|
||||
"zotregistry.io/zot/pkg/api/constants"
|
||||
@@ -55,13 +58,38 @@ func NewRouteHandler(c *Controller) *RouteHandler {
|
||||
}
|
||||
|
||||
func (rh *RouteHandler) SetupRoutes() {
|
||||
// first get Auth middleware in order to first setup openid/ldap/htpasswd, before oidc provider routes are setup
|
||||
authHandler := AuthHandler(rh.c)
|
||||
|
||||
applyCORSHeaders := getCORSHeadersHandler(rh.c.Config.HTTP.AllowOrigin)
|
||||
|
||||
if isOpenIDAuthEnabled(rh.c.Config) {
|
||||
// login path for openID
|
||||
rh.c.Router.HandleFunc(constants.LoginPath, rh.AuthURLHandler())
|
||||
|
||||
// logout path for openID
|
||||
rh.c.Router.HandleFunc(constants.LogoutPath, applyCORSHeaders(rh.Logout)).
|
||||
Methods(zcommon.AllowedMethods("POST")...)
|
||||
|
||||
// callback path for openID
|
||||
for provider, relyingParty := range rh.c.RelyingParties {
|
||||
if IsOauth2Supported(provider) {
|
||||
rh.c.Router.HandleFunc(constants.CallbackBasePath+fmt.Sprintf("/%s", provider),
|
||||
rp.CodeExchangeHandler(rh.GithubCodeExchangeCallback(), relyingParty))
|
||||
} else if IsOpenIDSupported(provider) {
|
||||
rh.c.Router.HandleFunc(constants.CallbackBasePath+fmt.Sprintf("/%s", provider),
|
||||
rp.CodeExchangeHandler(rp.UserinfoCallback(rh.OpenIDCodeExchangeCallback()), relyingParty))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
prefixedRouter := rh.c.Router.PathPrefix(constants.RoutePrefix).Subrouter()
|
||||
prefixedRouter.Use(AuthHandler(rh.c))
|
||||
prefixedRouter.Use(authHandler)
|
||||
|
||||
prefixedDistSpecRouter := prefixedRouter.NewRoute().Subrouter()
|
||||
// authz is being enabled if AccessControl is specified
|
||||
// if Authn is not present AccessControl will have only default policies
|
||||
if rh.c.Config.HTTP.AccessControl != nil && !isBearerAuthEnabled(rh.c.Config) {
|
||||
if rh.c.Config.HTTP.AccessControl != nil {
|
||||
if isAuthnEnabled(rh.c.Config) {
|
||||
rh.c.Log.Info().Msg("access control is being enabled")
|
||||
} else {
|
||||
@@ -72,8 +100,6 @@ func (rh *RouteHandler) SetupRoutes() {
|
||||
prefixedDistSpecRouter.Use(DistSpecAuthzHandler(rh.c))
|
||||
}
|
||||
|
||||
applyCORSHeaders := getCORSHeadersHandler(rh.c.Config.HTTP.AllowOrigin)
|
||||
|
||||
// https://github.com/opencontainers/distribution-spec/blob/main/spec.md#endpoints
|
||||
{
|
||||
prefixedDistSpecRouter.HandleFunc(fmt.Sprintf("/{name:%s}/tags/list", zreg.NameRegexp.String()),
|
||||
@@ -118,7 +144,7 @@ func (rh *RouteHandler) SetupRoutes() {
|
||||
constants.ArtifactSpecRoutePrefix, zreg.NameRegexp.String()), rh.GetOrasReferrers).Methods("GET")
|
||||
|
||||
// swagger
|
||||
debug.SetupSwaggerRoutes(rh.c.Config, rh.c.Router, AuthHandler(rh.c), rh.c.Log)
|
||||
debug.SetupSwaggerRoutes(rh.c.Config, rh.c.Router, authHandler, rh.c.Log)
|
||||
|
||||
// Setup Extensions Routes
|
||||
if rh.c.Config != nil {
|
||||
@@ -135,8 +161,8 @@ func (rh *RouteHandler) SetupRoutes() {
|
||||
rh.c.Log)
|
||||
ext.SetupUserPreferencesRoutes(rh.c.Config, prefixedExtensionsRouter, rh.c.StoreController, rh.c.RepoDB,
|
||||
rh.c.CveInfo, rh.c.Log)
|
||||
|
||||
ext.SetupMetricsRoutes(rh.c.Config, rh.c.Router, rh.c.StoreController, AuthHandler(rh.c), rh.c.Log)
|
||||
ext.SetupAPIKeyRoutes(rh.c.Config, prefixedExtensionsRouter, rh.c.RepoDB, rh.c.CookieStore, rh.c.Log)
|
||||
ext.SetupMetricsRoutes(rh.c.Config, rh.c.Router, rh.c.StoreController, authHandler, rh.c.Log)
|
||||
|
||||
gqlPlayground.SetupGQLPlaygroundRoutes(rh.c.Config, prefixedRouter, rh.c.StoreController, rh.c.Log)
|
||||
|
||||
@@ -185,7 +211,8 @@ func addCORSHeaders(allowOrigin string, response http.ResponseWriter) {
|
||||
// @Success 200 {string} string "ok".
|
||||
func (rh *RouteHandler) CheckVersionSupport(response http.ResponseWriter, request *http.Request) {
|
||||
response.Header().Set("Access-Control-Allow-Methods", "HEAD,GET,POST,OPTIONS")
|
||||
response.Header().Set("Access-Control-Allow-Headers", "Authorization,content-type")
|
||||
response.Header().Set("Access-Control-Allow-Headers", "Authorization,content-type,"+constants.SessionClientHeaderName)
|
||||
response.Header().Set("Access-Control-Allow-Credentials", "true")
|
||||
|
||||
if request.Method == http.MethodOptions {
|
||||
return
|
||||
@@ -195,10 +222,13 @@ func (rh *RouteHandler) CheckVersionSupport(response http.ResponseWriter, reques
|
||||
// NOTE: compatibility workaround - return this header in "allowed-read" mode to allow for clients to
|
||||
// work correctly
|
||||
if rh.c.Config.HTTP.Auth != nil {
|
||||
if rh.c.Config.HTTP.Auth.Bearer != nil {
|
||||
response.Header().Set("WWW-Authenticate", fmt.Sprintf("bearer realm=%s", rh.c.Config.HTTP.Auth.Bearer.Realm))
|
||||
} else {
|
||||
response.Header().Set("WWW-Authenticate", fmt.Sprintf("basic realm=%s", rh.c.Config.HTTP.Realm))
|
||||
// don't send auth headers if request is coming from UI
|
||||
if request.Header.Get(constants.SessionClientHeaderName) != constants.SessionClientHeaderValue {
|
||||
if rh.c.Config.HTTP.Auth.Bearer != nil {
|
||||
response.Header().Set("WWW-Authenticate", fmt.Sprintf("bearer realm=%s", rh.c.Config.HTTP.Auth.Bearer.Realm))
|
||||
} else {
|
||||
response.Header().Set("WWW-Authenticate", fmt.Sprintf("basic realm=%s", rh.c.Config.HTTP.Realm))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -224,7 +254,8 @@ type ImageTags struct {
|
||||
// @Failure 400 {string} string "bad request".
|
||||
func (rh *RouteHandler) ListTags(response http.ResponseWriter, request *http.Request) {
|
||||
response.Header().Set("Access-Control-Allow-Methods", "HEAD,GET,POST,OPTIONS")
|
||||
response.Header().Set("Access-Control-Allow-Headers", "Authorization,content-type")
|
||||
response.Header().Set("Access-Control-Allow-Headers", "Authorization,content-type,"+constants.SessionClientHeaderName)
|
||||
response.Header().Set("Access-Control-Allow-Credentials", "true")
|
||||
|
||||
if request.Method == http.MethodOptions {
|
||||
return
|
||||
@@ -355,7 +386,8 @@ func (rh *RouteHandler) ListTags(response http.ResponseWriter, request *http.Req
|
||||
// @Failure 500 {string} string "internal server error".
|
||||
func (rh *RouteHandler) CheckManifest(response http.ResponseWriter, request *http.Request) {
|
||||
response.Header().Set("Access-Control-Allow-Methods", "HEAD,GET,POST,OPTIONS")
|
||||
response.Header().Set("Access-Control-Allow-Headers", "Authorization,content-type")
|
||||
response.Header().Set("Access-Control-Allow-Headers", "Authorization,content-type,"+constants.SessionClientHeaderName)
|
||||
response.Header().Set("Access-Control-Allow-Credentials", "true")
|
||||
|
||||
if request.Method == http.MethodOptions {
|
||||
return
|
||||
@@ -427,7 +459,8 @@ type ExtensionList struct {
|
||||
// @Router /v2/{name}/manifests/{reference} [get].
|
||||
func (rh *RouteHandler) GetManifest(response http.ResponseWriter, request *http.Request) {
|
||||
response.Header().Set("Access-Control-Allow-Methods", "HEAD,GET,POST,OPTIONS")
|
||||
response.Header().Set("Access-Control-Allow-Headers", "Authorization,content-type")
|
||||
response.Header().Set("Access-Control-Allow-Headers", "Authorization,content-type,"+constants.SessionClientHeaderName)
|
||||
response.Header().Set("Access-Control-Allow-Credentials", "true")
|
||||
|
||||
if request.Method == http.MethodOptions {
|
||||
return
|
||||
@@ -527,7 +560,8 @@ func getReferrers(routeHandler *RouteHandler,
|
||||
// @Router /v2/{name}/referrers/{digest} [get].
|
||||
func (rh *RouteHandler) GetReferrers(response http.ResponseWriter, request *http.Request) {
|
||||
response.Header().Set("Access-Control-Allow-Methods", "HEAD,GET,POST,OPTIONS")
|
||||
response.Header().Set("Access-Control-Allow-Headers", "Authorization,content-type")
|
||||
response.Header().Set("Access-Control-Allow-Headers", "Authorization,content-type,"+constants.SessionClientHeaderName)
|
||||
response.Header().Set("Access-Control-Allow-Credentials", "true")
|
||||
|
||||
if request.Method == http.MethodOptions {
|
||||
return
|
||||
@@ -1576,7 +1610,8 @@ type RepositoryList struct {
|
||||
// @Router /v2/_catalog [get].
|
||||
func (rh *RouteHandler) ListRepositories(response http.ResponseWriter, request *http.Request) {
|
||||
response.Header().Set("Access-Control-Allow-Methods", "HEAD,GET,POST,OPTIONS")
|
||||
response.Header().Set("Access-Control-Allow-Headers", "Authorization,content-type")
|
||||
response.Header().Set("Access-Control-Allow-Headers", "Authorization,content-type,"+constants.SessionClientHeaderName)
|
||||
response.Header().Set("Access-Control-Allow-Credentials", "true")
|
||||
|
||||
if request.Method == http.MethodOptions {
|
||||
return
|
||||
@@ -1642,7 +1677,8 @@ func (rh *RouteHandler) ListRepositories(response http.ResponseWriter, request *
|
||||
// @Router /v2/_oci/ext/discover [get].
|
||||
func (rh *RouteHandler) ListExtensions(w http.ResponseWriter, r *http.Request) {
|
||||
w.Header().Set("Access-Control-Allow-Methods", "HEAD,GET,POST,OPTIONS")
|
||||
w.Header().Set("Access-Control-Allow-Headers", "Authorization,content-type")
|
||||
w.Header().Set("Access-Control-Allow-Headers", "Authorization,content-type,"+constants.SessionClientHeaderName)
|
||||
w.Header().Set("Access-Control-Allow-Credentials", "true")
|
||||
|
||||
if r.Method == http.MethodOptions {
|
||||
return
|
||||
@@ -1653,6 +1689,116 @@ func (rh *RouteHandler) ListExtensions(w http.ResponseWriter, r *http.Request) {
|
||||
zcommon.WriteJSON(w, http.StatusOK, extensionList)
|
||||
}
|
||||
|
||||
// The following routes are specific to zot and NOT part of the OCI dist-spec
|
||||
|
||||
// Logout godoc
|
||||
// @Summary Logout by removing current session
|
||||
// @Description Logout by removing current session
|
||||
// @Router /openid/auth/logout [post]
|
||||
// @Accept json
|
||||
// @Produce json
|
||||
// @Success 200 {string} string "ok".
|
||||
// @Failure 500 {string} string "internal server error".
|
||||
func (rh *RouteHandler) Logout(response http.ResponseWriter, request *http.Request) {
|
||||
response.Header().Set("Access-Control-Allow-Methods", "HEAD,GET,POST,OPTIONS")
|
||||
response.Header().Set("Access-Control-Allow-Headers", "Authorization,content-type,"+constants.SessionClientHeaderName)
|
||||
response.Header().Set("Access-Control-Allow-Credentials", "true")
|
||||
|
||||
if request.Method == http.MethodOptions {
|
||||
return
|
||||
}
|
||||
|
||||
session, _ := rh.c.CookieStore.Get(request, "session")
|
||||
session.Options.MaxAge = -1
|
||||
|
||||
err := session.Save(request, response)
|
||||
if err != nil {
|
||||
response.WriteHeader(http.StatusInternalServerError)
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
response.WriteHeader(http.StatusOK)
|
||||
}
|
||||
|
||||
// github Oauth2 CodeExchange callback.
|
||||
func (rh *RouteHandler) GithubCodeExchangeCallback() rp.CodeExchangeCallback {
|
||||
return func(w http.ResponseWriter, r *http.Request,
|
||||
tokens *oidc.Tokens, state string, relyingParty rp.RelyingParty,
|
||||
) {
|
||||
ctx := r.Context()
|
||||
|
||||
client := github.NewClient(relyingParty.OAuthConfig().Client(ctx, tokens.Token))
|
||||
|
||||
email, groups, err := GetGithubUserInfo(ctx, client, rh.c.Log)
|
||||
if email == "" || err != nil {
|
||||
w.WriteHeader(http.StatusUnauthorized)
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
callbackUI, err := OAuth2Callback(rh.c, w, r, state, email, groups) //nolint: contextcheck
|
||||
if err != nil {
|
||||
if errors.Is(err, zerr.ErrInvalidStateCookie) {
|
||||
w.WriteHeader(http.StatusUnauthorized)
|
||||
}
|
||||
|
||||
w.WriteHeader(http.StatusInternalServerError)
|
||||
}
|
||||
|
||||
if callbackUI != "" {
|
||||
http.Redirect(w, r, callbackUI, http.StatusFound)
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
w.WriteHeader(http.StatusCreated)
|
||||
}
|
||||
}
|
||||
|
||||
// Openid CodeExchange callback.
|
||||
func (rh *RouteHandler) OpenIDCodeExchangeCallback() rp.CodeExchangeUserinfoCallback {
|
||||
return func(w http.ResponseWriter, r *http.Request, tokens *oidc.Tokens, state string,
|
||||
relyingParty rp.RelyingParty, info oidc.UserInfo,
|
||||
) {
|
||||
email := info.GetEmail()
|
||||
if email == "" {
|
||||
rh.c.Log.Error().Msg("couldn't set user record for empty email value")
|
||||
w.WriteHeader(http.StatusUnauthorized)
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
var groups []string
|
||||
|
||||
val, ok := info.GetClaim("groups").([]interface{})
|
||||
if !ok {
|
||||
rh.c.Log.Info().Msgf("couldn't find any 'groups' claim for user %s", email)
|
||||
}
|
||||
|
||||
for _, group := range val {
|
||||
groups = append(groups, fmt.Sprint(group))
|
||||
}
|
||||
|
||||
callbackUI, err := OAuth2Callback(rh.c, w, r, state, email, groups)
|
||||
if err != nil {
|
||||
if errors.Is(err, zerr.ErrInvalidStateCookie) {
|
||||
w.WriteHeader(http.StatusUnauthorized)
|
||||
}
|
||||
|
||||
w.WriteHeader(http.StatusInternalServerError)
|
||||
}
|
||||
|
||||
if callbackUI != "" {
|
||||
http.Redirect(w, r, callbackUI, http.StatusFound)
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
w.WriteHeader(http.StatusCreated)
|
||||
}
|
||||
}
|
||||
|
||||
func (rh *RouteHandler) GetMetrics(w http.ResponseWriter, r *http.Request) {
|
||||
m := rh.c.Metrics.ReceiveMetrics()
|
||||
zcommon.WriteJSON(w, http.StatusOK, m)
|
||||
|
||||
+175
-4
@@ -1,26 +1,36 @@
|
||||
//go:build sync && scrub && metrics && search && lint
|
||||
// +build sync,scrub,metrics,search,lint
|
||||
//go:build sync && scrub && metrics && search && lint && apikey
|
||||
// +build sync,scrub,metrics,search,lint,apikey
|
||||
|
||||
package api_test
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"os"
|
||||
"testing"
|
||||
|
||||
"github.com/google/uuid"
|
||||
"github.com/gorilla/mux"
|
||||
godigest "github.com/opencontainers/go-digest"
|
||||
ispec "github.com/opencontainers/image-spec/specs-go/v1"
|
||||
"github.com/project-zot/mockoidc"
|
||||
. "github.com/smartystreets/goconvey/convey"
|
||||
"github.com/zitadel/oidc/pkg/client/rp"
|
||||
"github.com/zitadel/oidc/pkg/oidc"
|
||||
"golang.org/x/oauth2"
|
||||
|
||||
zerr "zotregistry.io/zot/errors"
|
||||
"zotregistry.io/zot/pkg/api"
|
||||
"zotregistry.io/zot/pkg/api/config"
|
||||
"zotregistry.io/zot/pkg/api/constants"
|
||||
"zotregistry.io/zot/pkg/extensions"
|
||||
extconf "zotregistry.io/zot/pkg/extensions/config"
|
||||
"zotregistry.io/zot/pkg/meta/repodb"
|
||||
localCtx "zotregistry.io/zot/pkg/requestcontext"
|
||||
storageTypes "zotregistry.io/zot/pkg/storage/types"
|
||||
"zotregistry.io/zot/pkg/test"
|
||||
@@ -29,6 +39,8 @@ import (
|
||||
|
||||
var ErrUnexpectedError = errors.New("error: unexpected error")
|
||||
|
||||
const sessionStr = "session"
|
||||
|
||||
func TestRoutes(t *testing.T) {
|
||||
Convey("Make a new controller", t, func() {
|
||||
port := test.GetFreePort()
|
||||
@@ -36,6 +48,45 @@ func TestRoutes(t *testing.T) {
|
||||
conf := config.New()
|
||||
conf.HTTP.Port = port
|
||||
|
||||
htpasswdPath := test.MakeHtpasswdFile()
|
||||
defer os.Remove(htpasswdPath)
|
||||
mockOIDCServer, err := mockoidc.Run()
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
defer func() {
|
||||
err := mockOIDCServer.Shutdown()
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
}()
|
||||
|
||||
mockOIDCConfig := mockOIDCServer.Config()
|
||||
conf.HTTP.Auth = &config.AuthConfig{
|
||||
HTPasswd: config.AuthHTPasswd{
|
||||
Path: htpasswdPath,
|
||||
},
|
||||
OpenID: &config.OpenIDConfig{
|
||||
Providers: map[string]config.OpenIDProviderConfig{
|
||||
"dex": {
|
||||
ClientID: mockOIDCConfig.ClientID,
|
||||
ClientSecret: mockOIDCConfig.ClientSecret,
|
||||
KeyPath: "",
|
||||
Issuer: mockOIDCConfig.Issuer,
|
||||
Scopes: []string{"openid", "email"},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
defaultVal := true
|
||||
apiKeyConfig := &extconf.APIKeyConfig{
|
||||
BaseConfig: extconf.BaseConfig{Enable: &defaultVal},
|
||||
}
|
||||
conf.Extensions = &extconf.ExtensionConfig{
|
||||
APIKey: apiKeyConfig,
|
||||
}
|
||||
|
||||
ctlr := api.NewController(conf)
|
||||
|
||||
ctlr.Config.Storage.RootDirectory = t.TempDir()
|
||||
@@ -50,6 +101,52 @@ func TestRoutes(t *testing.T) {
|
||||
// NOTE: the url or method itself doesn't matter below since we are calling the handlers directly,
|
||||
// so path routing is bypassed
|
||||
|
||||
Convey("Test GithubCodeExchangeCallback", func() {
|
||||
callback := rthdlr.GithubCodeExchangeCallback()
|
||||
ctx := context.TODO()
|
||||
|
||||
request, _ := http.NewRequestWithContext(ctx, http.MethodGet, baseURL, nil)
|
||||
response := httptest.NewRecorder()
|
||||
|
||||
tokens := &oidc.Tokens{}
|
||||
relyingParty, err := rp.NewRelyingPartyOAuth(&oauth2.Config{})
|
||||
So(err, ShouldBeNil)
|
||||
|
||||
callback(response, request, tokens, "state", relyingParty)
|
||||
|
||||
resp := response.Result()
|
||||
defer resp.Body.Close()
|
||||
So(resp, ShouldNotBeNil)
|
||||
So(resp.StatusCode, ShouldEqual, http.StatusUnauthorized)
|
||||
})
|
||||
|
||||
Convey("Test OAuth2Callback errors", func() {
|
||||
ctx := context.TODO()
|
||||
|
||||
request, _ := http.NewRequestWithContext(ctx, http.MethodGet, baseURL, nil)
|
||||
response := httptest.NewRecorder()
|
||||
|
||||
_, err := api.OAuth2Callback(ctlr, response, request, "state", "email", []string{"group"})
|
||||
So(err, ShouldEqual, zerr.ErrInvalidStateCookie)
|
||||
|
||||
session, _ := ctlr.CookieStore.Get(request, "statecookie")
|
||||
|
||||
session.Options.Secure = true
|
||||
session.Options.HttpOnly = true
|
||||
session.Options.SameSite = http.SameSiteDefaultMode
|
||||
|
||||
state := uuid.New().String()
|
||||
|
||||
session.Values["state"] = state
|
||||
|
||||
// let the session set its own id
|
||||
err = session.Save(request, response)
|
||||
So(err, ShouldBeNil)
|
||||
|
||||
_, err = api.OAuth2Callback(ctlr, response, request, "state", "email", []string{"group"})
|
||||
So(err, ShouldEqual, zerr.ErrInvalidStateCookie)
|
||||
})
|
||||
|
||||
Convey("List repositories authz error", func() {
|
||||
var invalid struct{}
|
||||
|
||||
@@ -575,7 +672,7 @@ func TestRoutes(t *testing.T) {
|
||||
},
|
||||
&mocks.MockedImageStore{
|
||||
FullBlobUploadFn: func(repo string, body io.Reader, digest godigest.Digest) (string, int64, error) {
|
||||
return "session", 0, zerr.ErrBadBlobDigest
|
||||
return sessionStr, 0, zerr.ErrBadBlobDigest
|
||||
},
|
||||
})
|
||||
So(statusCode, ShouldEqual, http.StatusInternalServerError)
|
||||
@@ -591,7 +688,7 @@ func TestRoutes(t *testing.T) {
|
||||
},
|
||||
&mocks.MockedImageStore{
|
||||
FullBlobUploadFn: func(repo string, body io.Reader, digest godigest.Digest) (string, int64, error) {
|
||||
return "session", 20, nil
|
||||
return sessionStr, 20, nil
|
||||
},
|
||||
})
|
||||
So(statusCode, ShouldEqual, http.StatusInternalServerError)
|
||||
@@ -1327,6 +1424,80 @@ func TestRoutes(t *testing.T) {
|
||||
So(resp.StatusCode, ShouldEqual, http.StatusOK)
|
||||
})
|
||||
|
||||
Convey("Test API keys", func() {
|
||||
var invalid struct{}
|
||||
|
||||
ctx := context.TODO()
|
||||
key := localCtx.GetContextKey()
|
||||
ctx = context.WithValue(ctx, key, invalid)
|
||||
|
||||
request, _ := http.NewRequestWithContext(ctx, http.MethodPost, baseURL, bytes.NewReader([]byte{}))
|
||||
response := httptest.NewRecorder()
|
||||
|
||||
extensions.CreateAPIKey(response, request, ctlr.RepoDB, ctlr.CookieStore, ctlr.Log)
|
||||
|
||||
resp := response.Result()
|
||||
defer resp.Body.Close()
|
||||
So(resp.StatusCode, ShouldEqual, http.StatusInternalServerError)
|
||||
|
||||
acCtx := localCtx.AccessControlContext{
|
||||
Username: username,
|
||||
}
|
||||
|
||||
ctx = context.TODO()
|
||||
key = localCtx.GetContextKey()
|
||||
ctx = context.WithValue(ctx, key, acCtx)
|
||||
|
||||
request, _ = http.NewRequestWithContext(ctx, http.MethodPost, baseURL, bytes.NewReader([]byte{}))
|
||||
response = httptest.NewRecorder()
|
||||
|
||||
extensions.CreateAPIKey(response, request, ctlr.RepoDB, ctlr.CookieStore, ctlr.Log)
|
||||
|
||||
resp = response.Result()
|
||||
defer resp.Body.Close()
|
||||
|
||||
So(resp.StatusCode, ShouldEqual, http.StatusInternalServerError)
|
||||
|
||||
payload := extensions.APIKeyPayload{
|
||||
Label: "test",
|
||||
Scopes: []string{"test"},
|
||||
}
|
||||
reqBody, err := json.Marshal(payload)
|
||||
So(err, ShouldBeNil)
|
||||
|
||||
request, _ = http.NewRequestWithContext(ctx, http.MethodPost, baseURL, bytes.NewReader(reqBody))
|
||||
response = httptest.NewRecorder()
|
||||
|
||||
extensions.CreateAPIKey(response, request, mocks.RepoDBMock{
|
||||
AddUserAPIKeyFn: func(ctx context.Context, hashedKey string, apiKeyDetails *repodb.APIKeyDetails) error {
|
||||
return ErrUnexpectedError
|
||||
},
|
||||
}, ctlr.CookieStore, ctlr.Log)
|
||||
|
||||
resp = response.Result()
|
||||
defer resp.Body.Close()
|
||||
|
||||
So(resp.StatusCode, ShouldEqual, http.StatusInternalServerError)
|
||||
|
||||
request, _ = http.NewRequestWithContext(ctx, http.MethodDelete, baseURL, bytes.NewReader([]byte{}))
|
||||
response = httptest.NewRecorder()
|
||||
|
||||
q := request.URL.Query()
|
||||
q.Add("id", "apikeyid")
|
||||
request.URL.RawQuery = q.Encode()
|
||||
|
||||
extensions.RevokeAPIKey(response, request, mocks.RepoDBMock{
|
||||
DeleteUserAPIKeyFn: func(ctx context.Context, id string) error {
|
||||
return ErrUnexpectedError
|
||||
},
|
||||
}, ctlr.CookieStore, ctlr.Log)
|
||||
|
||||
resp = response.Result()
|
||||
defer resp.Body.Close()
|
||||
|
||||
So(resp.StatusCode, ShouldEqual, http.StatusInternalServerError)
|
||||
})
|
||||
|
||||
Convey("Helper functions", func() {
|
||||
testUpdateBlobUpload := func(
|
||||
query []struct{ k, v string },
|
||||
|
||||
Reference in New Issue
Block a user