mirror of
https://github.com/project-zot/zot.git
synced 2026-06-16 20:38:08 +08:00
feat: integrate openID auth logic and user profile management (#1381)
This change introduces OpenID authn by using providers such as Github, Gitlab, Google and Dex. User sessions are now used for web clients to identify and persist an authenticated users session, thus not requiring every request to use credentials. Another change is apikey feature, users can create/revoke their api keys and use them to authenticate when using cli clients such as skopeo. eg: login: /auth/login?provider=github /auth/login?provider=gitlab and so on logout: /auth/logout redirectURL: /auth/callback/github /auth/callback/gitlab and so on If network policy doesn't allow inbound connections, this callback wont work! for more info read documentation added in this commit. Signed-off-by: Alex Stan <alexandrustan96@yahoo.ro> Signed-off-by: Petu Eusebiu <peusebiu@cisco.com> Co-authored-by: Alex Stan <alexandrustan96@yahoo.ro>
This commit is contained in:
+678
-163
@@ -3,139 +3,315 @@ package api
|
||||
import (
|
||||
"bufio"
|
||||
"context"
|
||||
"crypto/sha256"
|
||||
"crypto/x509"
|
||||
"encoding/base64"
|
||||
"encoding/gob"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net"
|
||||
"net/http"
|
||||
"os"
|
||||
"path"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/chartmuseum/auth"
|
||||
"github.com/google/go-github/v52/github"
|
||||
"github.com/google/uuid"
|
||||
"github.com/gorilla/mux"
|
||||
"github.com/gorilla/securecookie"
|
||||
"github.com/gorilla/sessions"
|
||||
godigest "github.com/opencontainers/go-digest"
|
||||
"github.com/zitadel/oidc/pkg/client/rp"
|
||||
httphelper "github.com/zitadel/oidc/pkg/http"
|
||||
"github.com/zitadel/oidc/pkg/oidc"
|
||||
"golang.org/x/crypto/bcrypt"
|
||||
"golang.org/x/oauth2"
|
||||
githubOAuth "golang.org/x/oauth2/github"
|
||||
|
||||
"zotregistry.io/zot/errors"
|
||||
zerr "zotregistry.io/zot/errors"
|
||||
"zotregistry.io/zot/pkg/api/config"
|
||||
"zotregistry.io/zot/pkg/api/constants"
|
||||
apiErr "zotregistry.io/zot/pkg/api/errors"
|
||||
"zotregistry.io/zot/pkg/common"
|
||||
"zotregistry.io/zot/pkg/log"
|
||||
localCtx "zotregistry.io/zot/pkg/requestcontext"
|
||||
storageConstants "zotregistry.io/zot/pkg/storage/constants"
|
||||
)
|
||||
|
||||
const (
|
||||
bearerAuthDefaultAccessEntryType = "repository"
|
||||
issuedAtOffset = 5 * time.Second
|
||||
relyingPartyCookieMaxAge = 120
|
||||
)
|
||||
|
||||
func AuthHandler(c *Controller) mux.MiddlewareFunc {
|
||||
if isBearerAuthEnabled(c.Config) {
|
||||
return bearerAuthHandler(c)
|
||||
}
|
||||
|
||||
return basicAuthHandler(c)
|
||||
type AuthnMiddleware struct {
|
||||
credMap map[string]string
|
||||
ldapClient *LDAPClient
|
||||
}
|
||||
|
||||
func bearerAuthHandler(ctlr *Controller) mux.MiddlewareFunc {
|
||||
authorizer, err := auth.NewAuthorizer(&auth.AuthorizerOptions{
|
||||
Realm: ctlr.Config.HTTP.Auth.Bearer.Realm,
|
||||
Service: ctlr.Config.HTTP.Auth.Bearer.Service,
|
||||
PublicKeyPath: ctlr.Config.HTTP.Auth.Bearer.Cert,
|
||||
AccessEntryType: bearerAuthDefaultAccessEntryType,
|
||||
EmptyDefaultNamespace: true,
|
||||
})
|
||||
func AuthHandler(ctlr *Controller) mux.MiddlewareFunc {
|
||||
authnMiddleware := &AuthnMiddleware{}
|
||||
|
||||
if isBearerAuthEnabled(ctlr.Config) {
|
||||
return bearerAuthHandler(ctlr)
|
||||
}
|
||||
|
||||
return authnMiddleware.TryAuthnHandlers(ctlr)
|
||||
}
|
||||
|
||||
func (amw *AuthnMiddleware) sessionAuthn(ctlr *Controller, next http.Handler, response http.ResponseWriter,
|
||||
request *http.Request, delay int,
|
||||
) {
|
||||
clientHeader := request.Header.Get(constants.SessionClientHeaderName)
|
||||
if clientHeader != constants.SessionClientHeaderValue {
|
||||
authFail(response, request, ctlr.Config.HTTP.Realm, delay)
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
identity, ok := common.GetAuthUserFromRequestSession(ctlr.CookieStore, request, ctlr.Log)
|
||||
if !ok {
|
||||
// let the client know that this session is invalid/expired
|
||||
cookie := &http.Cookie{
|
||||
Name: "session",
|
||||
Value: "",
|
||||
Path: "/",
|
||||
Expires: time.Unix(0, 0),
|
||||
|
||||
HttpOnly: true,
|
||||
}
|
||||
|
||||
http.SetCookie(response, cookie)
|
||||
|
||||
authFail(response, request, ctlr.Config.HTTP.Realm, delay)
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
ctx := getReqContextWithAuthorization(identity, []string{}, request)
|
||||
|
||||
groups, err := ctlr.RepoDB.GetUserGroups(ctx)
|
||||
if err != nil {
|
||||
ctlr.Log.Panic().Err(err).Msg("error creating bearer authorizer")
|
||||
if errors.Is(err, zerr.ErrUserDataNotFound) {
|
||||
ctlr.Log.Err(err).Str("identity", identity).Msg("can not find user profile in DB")
|
||||
|
||||
authFail(response, request, ctlr.Config.HTTP.Realm, delay)
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
ctlr.Log.Err(err).Str("identity", identity).Msg("can not get user profile in DB")
|
||||
|
||||
response.WriteHeader(http.StatusInternalServerError)
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
return func(next http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(response http.ResponseWriter, request *http.Request) {
|
||||
if request.Method == http.MethodOptions {
|
||||
next.ServeHTTP(response, request)
|
||||
response.WriteHeader(http.StatusNoContent)
|
||||
ctx = getReqContextWithAuthorization(identity, groups, request)
|
||||
|
||||
return
|
||||
}
|
||||
vars := mux.Vars(request)
|
||||
name := vars["name"]
|
||||
|
||||
// we want to bypass auth for mgmt route
|
||||
isMgmtRequested := request.RequestURI == constants.FullMgmtPrefix
|
||||
|
||||
header := request.Header.Get("Authorization")
|
||||
|
||||
if (header == "" || header == "Basic Og==") && isMgmtRequested {
|
||||
next.ServeHTTP(response, request)
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
action := auth.PullAction
|
||||
if m := request.Method; m != http.MethodGet && m != http.MethodHead {
|
||||
action = auth.PushAction
|
||||
}
|
||||
permissions, err := authorizer.Authorize(header, action, name)
|
||||
if err != nil {
|
||||
ctlr.Log.Error().Err(err).Msg("issue parsing Authorization header")
|
||||
response.Header().Set("Content-Type", "application/json")
|
||||
common.WriteJSON(response, http.StatusInternalServerError, apiErr.NewErrorList(apiErr.NewError(apiErr.UNSUPPORTED)))
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
if !permissions.Allowed {
|
||||
authFail(response, permissions.WWWAuthenticateHeader, 0)
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
next.ServeHTTP(response, request)
|
||||
})
|
||||
}
|
||||
next.ServeHTTP(response, request.WithContext(ctx))
|
||||
}
|
||||
|
||||
func noPasswdAuth(realm string, config *config.Config) mux.MiddlewareFunc {
|
||||
return func(next http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(response http.ResponseWriter, request *http.Request) {
|
||||
if request.Method == http.MethodOptions {
|
||||
next.ServeHTTP(response, request)
|
||||
response.WriteHeader(http.StatusNoContent)
|
||||
func (amw *AuthnMiddleware) basicAuthn(ctlr *Controller, response http.ResponseWriter,
|
||||
request *http.Request,
|
||||
) (bool, http.ResponseWriter, *http.Request, error) {
|
||||
cookieStore := ctlr.CookieStore
|
||||
|
||||
return
|
||||
}
|
||||
// we want to bypass auth for mgmt route
|
||||
isMgmtRequested := request.RequestURI == constants.FullMgmtPrefix
|
||||
|
||||
// Process request
|
||||
if request.Header.Get("Authorization") == "" {
|
||||
if ctlr.Config.HTTP.AccessControl.AnonymousPolicyExists() || isMgmtRequested {
|
||||
ctx := getReqContextWithAuthorization("", []string{}, request)
|
||||
next.ServeHTTP(response, request.WithContext(ctx)) //nolint:contextcheck
|
||||
})
|
||||
// Process request
|
||||
|
||||
return true, response, request.WithContext(ctx), nil
|
||||
}
|
||||
}
|
||||
|
||||
identity, passphrase, err := getUsernamePasswordBasicAuth(request)
|
||||
if err != nil {
|
||||
ctlr.Log.Error().Err(err).Msg("failed to parse authorization header")
|
||||
|
||||
return false, nil, nil, nil
|
||||
}
|
||||
|
||||
// some client tools might send Authorization: Basic Og== (decoded into ":")
|
||||
// empty username and password
|
||||
if identity == "" && passphrase == "" {
|
||||
if ctlr.Config.HTTP.AccessControl.AnonymousPolicyExists() || isMgmtRequested {
|
||||
ctx := getReqContextWithAuthorization("", []string{}, request)
|
||||
|
||||
return true, response, request.WithContext(ctx), nil
|
||||
}
|
||||
}
|
||||
|
||||
passphraseHash, ok := amw.credMap[identity]
|
||||
if ok {
|
||||
// first, HTTPPassword authN (which is local)
|
||||
if err := bcrypt.CompareHashAndPassword([]byte(passphraseHash), []byte(passphrase)); err == nil {
|
||||
// Process request
|
||||
var groups []string
|
||||
|
||||
if ctlr.Config.HTTP.AccessControl != nil {
|
||||
ac := NewAccessController(ctlr.Config)
|
||||
groups = ac.getUserGroups(identity)
|
||||
}
|
||||
|
||||
ctx := getReqContextWithAuthorization(identity, groups, request)
|
||||
|
||||
// saved logged session
|
||||
if err := saveUserLoggedSession(cookieStore, response, request, identity, ctlr.Log); err != nil {
|
||||
return false, response, request, err
|
||||
}
|
||||
|
||||
if err := ctlr.RepoDB.SetUserGroups(ctx, groups); err != nil {
|
||||
ctlr.Log.Error().Err(err).Str("identity", identity).Msg("couldn't update user profile")
|
||||
|
||||
return false, response, request, err
|
||||
}
|
||||
|
||||
ctlr.Log.Info().Str("identity", identity).Msgf("user profile successfully set")
|
||||
|
||||
return true, response, request.WithContext(ctx), nil
|
||||
}
|
||||
}
|
||||
|
||||
// next, LDAP if configured (network-based which can lose connectivity)
|
||||
if ctlr.Config.HTTP.Auth != nil && ctlr.Config.HTTP.Auth.LDAP != nil {
|
||||
ok, _, ldapgroups, err := amw.ldapClient.Authenticate(identity, passphrase)
|
||||
if ok && err == nil {
|
||||
// Process request
|
||||
var groups []string
|
||||
|
||||
if ctlr.Config.HTTP.AccessControl != nil {
|
||||
ac := NewAccessController(ctlr.Config)
|
||||
groups = ac.getUserGroups(identity)
|
||||
}
|
||||
|
||||
groups = append(groups, ldapgroups...)
|
||||
|
||||
ctx := getReqContextWithAuthorization(identity, groups, request)
|
||||
|
||||
if err := saveUserLoggedSession(cookieStore, response, request, identity, ctlr.Log); err != nil {
|
||||
return false, response, request, err
|
||||
}
|
||||
|
||||
if err := ctlr.RepoDB.SetUserGroups(ctx, groups); err != nil {
|
||||
ctlr.Log.Error().Err(err).Str("identity", identity).Msg("couldn't update user profile")
|
||||
|
||||
return false, response, request, err
|
||||
}
|
||||
|
||||
return true, response, request.WithContext(ctx), nil
|
||||
}
|
||||
}
|
||||
|
||||
// last try API keys
|
||||
if isAPIKeyEnabled(ctlr.Config) {
|
||||
apiKey := passphrase
|
||||
|
||||
if !strings.HasPrefix(apiKey, constants.APIKeysPrefix) {
|
||||
ctlr.Log.Error().Msg("api token has invalid format")
|
||||
|
||||
return false, nil, nil, nil
|
||||
}
|
||||
|
||||
trimmedAPIKey := strings.TrimPrefix(apiKey, constants.APIKeysPrefix)
|
||||
|
||||
hashedKey := hashUUID(trimmedAPIKey)
|
||||
|
||||
storedIdentity, err := ctlr.RepoDB.GetUserAPIKeyInfo(hashedKey)
|
||||
if err != nil {
|
||||
if errors.Is(err, zerr.ErrUserAPIKeyNotFound) {
|
||||
ctlr.Log.Info().Err(err).Msgf("can not find any user info for hashed key %s in DB", hashedKey)
|
||||
|
||||
return false, nil, nil, nil
|
||||
}
|
||||
|
||||
ctlr.Log.Error().Err(err).Msgf("can not get user info for hashed key %s in DB", hashedKey)
|
||||
|
||||
return false, nil, nil, err
|
||||
}
|
||||
|
||||
if storedIdentity == identity {
|
||||
ctx := getReqContextWithAuthorization(identity, []string{}, request)
|
||||
|
||||
err := ctlr.RepoDB.UpdateUserAPIKeyLastUsed(ctx, hashedKey)
|
||||
if err != nil {
|
||||
ctlr.Log.Err(err).Str("identity", identity).Msg("can not update user profile in DB")
|
||||
|
||||
return false, nil, nil, err
|
||||
}
|
||||
|
||||
groups, err := ctlr.RepoDB.GetUserGroups(ctx)
|
||||
if err != nil {
|
||||
ctlr.Log.Err(err).Str("identity", identity).Msg("can not get user's groups in DB")
|
||||
|
||||
return false, nil, nil, err
|
||||
}
|
||||
|
||||
ctx = getReqContextWithAuthorization(identity, groups, request)
|
||||
|
||||
return true, response, request.WithContext(ctx), nil
|
||||
}
|
||||
}
|
||||
|
||||
return false, nil, nil, nil
|
||||
}
|
||||
|
||||
//nolint:gocyclo // we use closure making this a complex subroutine
|
||||
func basicAuthHandler(ctlr *Controller) mux.MiddlewareFunc {
|
||||
realm := ctlr.Config.HTTP.Realm
|
||||
if realm == "" {
|
||||
realm = "Authorization Required"
|
||||
}
|
||||
|
||||
realm = "Basic realm=" + strconv.Quote(realm)
|
||||
|
||||
func (amw *AuthnMiddleware) TryAuthnHandlers(ctlr *Controller) mux.MiddlewareFunc { //nolint: gocyclo
|
||||
// no password based authN, if neither LDAP nor HTTP BASIC is enabled
|
||||
if ctlr.Config.HTTP.Auth == nil ||
|
||||
(ctlr.Config.HTTP.Auth.HTPasswd.Path == "" && ctlr.Config.HTTP.Auth.LDAP == nil) {
|
||||
return noPasswdAuth(realm, ctlr.Config)
|
||||
(ctlr.Config.HTTP.Auth.HTPasswd.Path == "" && ctlr.Config.HTTP.Auth.LDAP == nil &&
|
||||
ctlr.Config.HTTP.Auth.OpenID == nil) {
|
||||
return noPasswdAuth(ctlr.Config)
|
||||
}
|
||||
|
||||
credMap := make(map[string]string)
|
||||
amw.credMap = make(map[string]string)
|
||||
|
||||
delay := ctlr.Config.HTTP.Auth.FailDelay
|
||||
|
||||
var ldapClient *LDAPClient
|
||||
// setup sessions cookie store used to preserve logged in user in web sessions
|
||||
if isAuthnEnabled(ctlr.Config) || isOpenIDAuthEnabled(ctlr.Config) {
|
||||
// To store custom types in our cookies,
|
||||
// we must first register them using gob.Register
|
||||
gob.Register(map[string]interface{}{})
|
||||
|
||||
cookieStoreHashKey := securecookie.GenerateRandomKey(64)
|
||||
if cookieStoreHashKey == nil {
|
||||
panic(zerr.ErrHashKeyNotCreated)
|
||||
}
|
||||
|
||||
// if storage is filesystem then use zot's rootDir to store sessions
|
||||
if ctlr.Config.Storage.StorageDriver == nil {
|
||||
sessionsDir := path.Join(ctlr.Config.Storage.RootDirectory, "_sessions")
|
||||
if err := os.MkdirAll(sessionsDir, storageConstants.DefaultDirPerms); err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
cookieStore := sessions.NewFilesystemStore(sessionsDir, cookieStoreHashKey)
|
||||
|
||||
cookieStore.MaxAge(cookiesMaxAge)
|
||||
|
||||
ctlr.CookieStore = cookieStore
|
||||
} else {
|
||||
cookieStore := sessions.NewCookieStore(cookieStoreHashKey)
|
||||
|
||||
cookieStore.MaxAge(cookiesMaxAge)
|
||||
|
||||
ctlr.CookieStore = cookieStore
|
||||
}
|
||||
}
|
||||
|
||||
// ldap and htpasswd based authN
|
||||
if ctlr.Config.HTTP.Auth != nil {
|
||||
if ctlr.Config.HTTP.Auth.LDAP != nil {
|
||||
ldapConfig := ctlr.Config.HTTP.Auth.LDAP
|
||||
ldapClient = &LDAPClient{
|
||||
amw.ldapClient = &LDAPClient{
|
||||
Host: ldapConfig.Address,
|
||||
Port: ldapConfig.Port,
|
||||
UseSSL: !ldapConfig.Insecure,
|
||||
@@ -160,18 +336,18 @@ func basicAuthHandler(ctlr *Controller) mux.MiddlewareFunc {
|
||||
caCertPool := x509.NewCertPool()
|
||||
|
||||
if !caCertPool.AppendCertsFromPEM(caCert) {
|
||||
panic(errors.ErrBadCACert)
|
||||
panic(zerr.ErrBadCACert)
|
||||
}
|
||||
|
||||
ldapClient.ClientCAs = caCertPool
|
||||
amw.ldapClient.ClientCAs = caCertPool
|
||||
} else {
|
||||
// default to system cert pool
|
||||
caCertPool, err := x509.SystemCertPool()
|
||||
if err != nil {
|
||||
panic(errors.ErrBadCACert)
|
||||
panic(zerr.ErrBadCACert)
|
||||
}
|
||||
|
||||
ldapClient.ClientCAs = caCertPool
|
||||
amw.ldapClient.ClientCAs = caCertPool
|
||||
}
|
||||
}
|
||||
|
||||
@@ -188,12 +364,27 @@ func basicAuthHandler(ctlr *Controller) mux.MiddlewareFunc {
|
||||
line := scanner.Text()
|
||||
if strings.Contains(line, ":") {
|
||||
tokens := strings.Split(scanner.Text(), ":")
|
||||
credMap[tokens[0]] = tokens[1]
|
||||
amw.credMap[tokens[0]] = tokens[1]
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// openid based authN
|
||||
if ctlr.Config.HTTP.Auth.OpenID != nil {
|
||||
ctlr.RelyingParties = make(map[string]rp.RelyingParty)
|
||||
|
||||
for provider := range ctlr.Config.HTTP.Auth.OpenID.Providers {
|
||||
if IsOpenIDSupported(provider) {
|
||||
rp := NewRelyingPartyOIDC(ctlr.Config, provider)
|
||||
ctlr.RelyingParties[provider] = rp
|
||||
} else if IsOauth2Supported(provider) {
|
||||
rp := NewRelyingPartyGithub(ctlr.Config, provider)
|
||||
ctlr.RelyingParties[provider] = rp
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return func(next http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(response http.ResponseWriter, request *http.Request) {
|
||||
if request.Method == http.MethodOptions {
|
||||
@@ -203,84 +394,231 @@ func basicAuthHandler(ctlr *Controller) mux.MiddlewareFunc {
|
||||
return
|
||||
}
|
||||
|
||||
// we want to bypass auth for mgmt route
|
||||
isMgmtRequested := request.RequestURI == constants.FullMgmtPrefix
|
||||
|
||||
if request.Header.Get("Authorization") == "" {
|
||||
if ctlr.Config.HTTP.AccessControl.AnonymousPolicyExists() || isMgmtRequested {
|
||||
// Process request
|
||||
ctx := getReqContextWithAuthorization("", []string{}, request)
|
||||
next.ServeHTTP(response, request.WithContext(ctx)) //nolint:contextcheck
|
||||
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
username, passphrase, err := getUsernamePasswordBasicAuth(request)
|
||||
//nolint: contextcheck
|
||||
authenticated, cloneResp, cloneReq, err := amw.basicAuthn(ctlr, response, request)
|
||||
if err != nil {
|
||||
ctlr.Log.Error().Err(err).Msg("failed to parse authorization header")
|
||||
authFail(response, realm, delay)
|
||||
response.WriteHeader(http.StatusInternalServerError)
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
// some client tools might send Authorization: Basic Og== (decoded into ":")
|
||||
// empty username and password
|
||||
if username == "" && passphrase == "" {
|
||||
if ctlr.Config.HTTP.AccessControl.AnonymousPolicyExists() || isMgmtRequested {
|
||||
// Process request
|
||||
ctx := getReqContextWithAuthorization("", []string{}, request)
|
||||
next.ServeHTTP(response, request.WithContext(ctx)) //nolint:contextcheck
|
||||
if authenticated && cloneResp != nil && cloneReq != nil {
|
||||
next.ServeHTTP(cloneResp, cloneReq)
|
||||
|
||||
return
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// first, HTTPPassword authN (which is local)
|
||||
passphraseHash, ok := credMap[username]
|
||||
if ok {
|
||||
if err := bcrypt.CompareHashAndPassword([]byte(passphraseHash), []byte(passphrase)); err == nil {
|
||||
// Process request
|
||||
var userGroups []string
|
||||
|
||||
if ctlr.Config.HTTP.AccessControl != nil {
|
||||
ac := NewAccessController(ctlr.Config)
|
||||
userGroups = ac.getUserGroups(username)
|
||||
}
|
||||
|
||||
ctx := getReqContextWithAuthorization(username, userGroups, request)
|
||||
next.ServeHTTP(response, request.WithContext(ctx)) //nolint:contextcheck
|
||||
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
// next, LDAP if configured (network-based which can lose connectivity)
|
||||
if ctlr.Config.HTTP.Auth != nil && ctlr.Config.HTTP.Auth.LDAP != nil {
|
||||
ok, _, ldapgroups, err := ldapClient.Authenticate(username, passphrase)
|
||||
if ok && err == nil {
|
||||
// Process request
|
||||
var userGroups []string
|
||||
|
||||
if ctlr.Config.HTTP.AccessControl != nil {
|
||||
ac := NewAccessController(ctlr.Config)
|
||||
userGroups = ac.getUserGroups(username)
|
||||
}
|
||||
|
||||
userGroups = append(userGroups, ldapgroups...)
|
||||
|
||||
ctx := getReqContextWithAuthorization(username, userGroups, request)
|
||||
next.ServeHTTP(response, request.WithContext(ctx)) //nolint:contextcheck
|
||||
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
authFail(response, realm, delay)
|
||||
//nolint: contextcheck
|
||||
amw.sessionAuthn(ctlr, next, response, request, delay)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func bearerAuthHandler(ctlr *Controller) mux.MiddlewareFunc {
|
||||
authorizer, err := auth.NewAuthorizer(&auth.AuthorizerOptions{
|
||||
Realm: ctlr.Config.HTTP.Auth.Bearer.Realm,
|
||||
Service: ctlr.Config.HTTP.Auth.Bearer.Service,
|
||||
PublicKeyPath: ctlr.Config.HTTP.Auth.Bearer.Cert,
|
||||
AccessEntryType: bearerAuthDefaultAccessEntryType,
|
||||
EmptyDefaultNamespace: true,
|
||||
})
|
||||
if err != nil {
|
||||
ctlr.Log.Panic().Err(err).Msg("error creating bearer authorizer")
|
||||
}
|
||||
|
||||
return func(next http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(response http.ResponseWriter, request *http.Request) {
|
||||
if request.Method == http.MethodOptions {
|
||||
next.ServeHTTP(response, request)
|
||||
response.WriteHeader(http.StatusNoContent)
|
||||
|
||||
return
|
||||
}
|
||||
acCtrlr := NewAccessController(ctlr.Config)
|
||||
vars := mux.Vars(request)
|
||||
name := vars["name"]
|
||||
|
||||
// we want to bypass auth for mgmt route
|
||||
isMgmtRequested := request.RequestURI == constants.FullMgmtPrefix
|
||||
|
||||
header := request.Header.Get("Authorization")
|
||||
|
||||
if (header == "" || header == "Basic Og==") && isMgmtRequested {
|
||||
next.ServeHTTP(response, request)
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
action := auth.PullAction
|
||||
if m := request.Method; m != http.MethodGet && m != http.MethodHead {
|
||||
action = auth.PushAction
|
||||
}
|
||||
|
||||
permissions, err := authorizer.Authorize(header, action, name)
|
||||
if err != nil {
|
||||
ctlr.Log.Error().Err(err).Msg("issue parsing Authorization header")
|
||||
response.Header().Set("Content-Type", "application/json")
|
||||
common.WriteJSON(response, http.StatusInternalServerError, apiErr.NewErrorList(apiErr.NewError(apiErr.UNSUPPORTED)))
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
if !permissions.Allowed {
|
||||
response.Header().Set("Content-Type", "application/json")
|
||||
response.Header().Set("WWW-Authenticate", permissions.WWWAuthenticateHeader)
|
||||
|
||||
common.WriteJSON(response, http.StatusUnauthorized,
|
||||
apiErr.NewErrorList(apiErr.NewError(apiErr.UNAUTHORIZED)))
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
amCtx := acCtrlr.getAuthnMiddlewareContext(BEARER, request)
|
||||
next.ServeHTTP(response, request.WithContext(amCtx)) //nolint:contextcheck
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func noPasswdAuth(config *config.Config) mux.MiddlewareFunc {
|
||||
return func(next http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(response http.ResponseWriter, request *http.Request) {
|
||||
if request.Method == http.MethodOptions {
|
||||
next.ServeHTTP(response, request)
|
||||
response.WriteHeader(http.StatusNoContent)
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
ctx := getReqContextWithAuthorization("", []string{}, request)
|
||||
// Process request
|
||||
next.ServeHTTP(response, request.WithContext(ctx)) //nolint:contextcheck
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func (rh *RouteHandler) AuthURLHandler() http.HandlerFunc {
|
||||
return func(w http.ResponseWriter, r *http.Request) {
|
||||
query := r.URL.Query()
|
||||
callbackUI := query.Get(constants.CallbackUIQueryParam)
|
||||
|
||||
provider := query.Get("provider")
|
||||
|
||||
client, ok := rh.c.RelyingParties[provider]
|
||||
if !ok {
|
||||
http.HandlerFunc(func(response http.ResponseWriter, request *http.Request) {
|
||||
response.WriteHeader(http.StatusBadRequest)
|
||||
})(w, r)
|
||||
}
|
||||
|
||||
/* save cookie containing state to later verify it and
|
||||
callback ui where we will redirect after openid/oauth2 logic is completed*/
|
||||
session, _ := rh.c.CookieStore.Get(r, "statecookie")
|
||||
|
||||
session.Options.Secure = true
|
||||
session.Options.HttpOnly = true
|
||||
session.Options.SameSite = http.SameSiteDefaultMode
|
||||
session.Options.Path = constants.CallbackBasePath
|
||||
|
||||
state := uuid.New().String()
|
||||
|
||||
session.Values["state"] = state
|
||||
session.Values["callback"] = callbackUI
|
||||
|
||||
// let the session set its own id
|
||||
err := session.Save(r, w)
|
||||
if err != nil {
|
||||
rh.c.Log.Error().Err(err).Msg("unable to save http session")
|
||||
|
||||
w.WriteHeader(http.StatusInternalServerError)
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
stateFunc := func() string {
|
||||
return state
|
||||
}
|
||||
|
||||
rp.AuthURLHandler(stateFunc, client)(w, r)
|
||||
}
|
||||
}
|
||||
|
||||
func NewRelyingPartyOIDC(config *config.Config, provider string) rp.RelyingParty {
|
||||
issuer, clientID, clientSecret, redirectURI, scopes, options := getRelyingPartyArgs(config, provider)
|
||||
|
||||
relyingParty, err := rp.NewRelyingPartyOIDC(issuer, clientID, clientSecret, redirectURI, scopes, options...)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
return relyingParty
|
||||
}
|
||||
|
||||
func NewRelyingPartyGithub(config *config.Config, provider string) rp.RelyingParty {
|
||||
_, clientID, clientSecret, redirectURI, scopes, options := getRelyingPartyArgs(config, provider)
|
||||
|
||||
rpConfig := &oauth2.Config{
|
||||
ClientID: clientID,
|
||||
ClientSecret: clientSecret,
|
||||
RedirectURL: redirectURI,
|
||||
Scopes: scopes,
|
||||
Endpoint: githubOAuth.Endpoint,
|
||||
}
|
||||
|
||||
relyingParty, err := rp.NewRelyingPartyOAuth(rpConfig, options...)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
return relyingParty
|
||||
}
|
||||
|
||||
func getRelyingPartyArgs(config *config.Config, provider string) (
|
||||
string, string, string, string, []string, []rp.Option,
|
||||
) {
|
||||
if _, ok := config.HTTP.Auth.OpenID.Providers[provider]; !ok {
|
||||
panic(zerr.ErrOpenIDProviderDoesNotExist)
|
||||
}
|
||||
|
||||
scheme := "http"
|
||||
if config.HTTP.TLS != nil {
|
||||
scheme = "https"
|
||||
}
|
||||
|
||||
clientID := config.HTTP.Auth.OpenID.Providers[provider].ClientID
|
||||
clientSecret := config.HTTP.Auth.OpenID.Providers[provider].ClientSecret
|
||||
|
||||
scopes := config.HTTP.Auth.OpenID.Providers[provider].Scopes
|
||||
// openid scope must be the first one in list
|
||||
if !common.Contains(scopes, oidc.ScopeOpenID) && IsOpenIDSupported(provider) {
|
||||
scopes = append([]string{oidc.ScopeOpenID}, scopes...)
|
||||
}
|
||||
|
||||
port := config.HTTP.Port
|
||||
issuer := config.HTTP.Auth.OpenID.Providers[provider].Issuer
|
||||
keyPath := config.HTTP.Auth.OpenID.Providers[provider].KeyPath
|
||||
baseURL := net.JoinHostPort(config.HTTP.Address, port)
|
||||
redirectURI := fmt.Sprintf("%s://%s%s", scheme, baseURL, constants.CallbackBasePath+fmt.Sprintf("/%s", provider))
|
||||
|
||||
options := []rp.Option{
|
||||
rp.WithVerifierOpts(rp.WithIssuedAtOffset(issuedAtOffset)),
|
||||
}
|
||||
|
||||
key := securecookie.GenerateRandomKey(32) //nolint: gomnd
|
||||
|
||||
cookieHandler := httphelper.NewCookieHandler(key, key, httphelper.WithMaxAge(relyingPartyCookieMaxAge))
|
||||
options = append(options, rp.WithCookieHandler(cookieHandler))
|
||||
|
||||
if clientSecret == "" {
|
||||
options = append(options, rp.WithPKCE(cookieHandler))
|
||||
}
|
||||
|
||||
if keyPath != "" {
|
||||
options = append(options, rp.WithJWTProfile(rp.SignerFromKeyPath(keyPath)))
|
||||
}
|
||||
|
||||
return issuer, clientID, clientSecret, redirectURI, scopes, options
|
||||
}
|
||||
|
||||
func getReqContextWithAuthorization(username string, groups []string, request *http.Request) context.Context {
|
||||
acCtx := localCtx.AccessControlContext{
|
||||
Username: username,
|
||||
@@ -314,9 +652,71 @@ func isBearerAuthEnabled(config *config.Config) bool {
|
||||
return false
|
||||
}
|
||||
|
||||
func authFail(w http.ResponseWriter, realm string, delay int) {
|
||||
func isOpenIDAuthEnabled(config *config.Config) bool {
|
||||
if config.HTTP.Auth != nil &&
|
||||
config.HTTP.Auth.OpenID != nil {
|
||||
for provider := range config.HTTP.Auth.OpenID.Providers {
|
||||
if isOpenIDAuthProviderEnabled(config, provider) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
func isAPIKeyEnabled(config *config.Config) bool {
|
||||
if config.Extensions != nil && config.Extensions.APIKey != nil &&
|
||||
*config.Extensions.APIKey.Enable {
|
||||
return true
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
func isOpenIDAuthProviderEnabled(config *config.Config, provider string) bool {
|
||||
if providerConfig, ok := config.HTTP.Auth.OpenID.Providers[provider]; ok {
|
||||
if IsOpenIDSupported(provider) {
|
||||
if providerConfig.ClientID != "" || providerConfig.Issuer != "" ||
|
||||
len(providerConfig.Scopes) > 0 {
|
||||
return true
|
||||
}
|
||||
} else if IsOauth2Supported(provider) {
|
||||
if providerConfig.ClientID != "" || len(providerConfig.Scopes) > 0 {
|
||||
return true
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
func IsOpenIDSupported(provider string) bool {
|
||||
supported := []string{"google", "gitlab", "dex"}
|
||||
|
||||
return common.Contains(supported, provider)
|
||||
}
|
||||
|
||||
func IsOauth2Supported(provider string) bool {
|
||||
supported := []string{"github"}
|
||||
|
||||
return common.Contains(supported, provider)
|
||||
}
|
||||
|
||||
func authFail(w http.ResponseWriter, r *http.Request, realm string, delay int) {
|
||||
time.Sleep(time.Duration(delay) * time.Second)
|
||||
w.Header().Set("WWW-Authenticate", realm)
|
||||
|
||||
// don't send auth headers if request is coming from UI
|
||||
if r.Header.Get(constants.SessionClientHeaderName) != constants.SessionClientHeaderValue {
|
||||
if realm == "" {
|
||||
realm = "Authorization Required"
|
||||
}
|
||||
|
||||
realm = "Basic realm=" + strconv.Quote(realm)
|
||||
|
||||
w.Header().Set("WWW-Authenticate", realm)
|
||||
}
|
||||
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
common.WriteJSON(w, http.StatusUnauthorized, apiErr.NewErrorList(apiErr.NewError(apiErr.UNAUTHORIZED)))
|
||||
}
|
||||
@@ -325,12 +725,12 @@ func getUsernamePasswordBasicAuth(request *http.Request) (string, string, error)
|
||||
basicAuth := request.Header.Get("Authorization")
|
||||
|
||||
if basicAuth == "" {
|
||||
return "", "", errors.ErrParsingAuthHeader
|
||||
return "", "", zerr.ErrParsingAuthHeader
|
||||
}
|
||||
|
||||
splitStr := strings.SplitN(basicAuth, " ", 2) //nolint:gomnd
|
||||
splitStr := strings.SplitN(basicAuth, " ", 2) //nolint: gomnd
|
||||
if len(splitStr) != 2 || strings.ToLower(splitStr[0]) != "basic" {
|
||||
return "", "", errors.ErrParsingAuthHeader
|
||||
return "", "", zerr.ErrParsingAuthHeader
|
||||
}
|
||||
|
||||
decodedStr, err := base64.StdEncoding.DecodeString(splitStr[1])
|
||||
@@ -338,9 +738,9 @@ func getUsernamePasswordBasicAuth(request *http.Request) (string, string, error)
|
||||
return "", "", err
|
||||
}
|
||||
|
||||
pair := strings.SplitN(string(decodedStr), ":", 2) //nolint:gomnd
|
||||
if len(pair) != 2 { //nolint:gomnd
|
||||
return "", "", errors.ErrParsingAuthHeader
|
||||
pair := strings.SplitN(string(decodedStr), ":", 2) //nolint: gomnd
|
||||
if len(pair) != 2 { //nolint: gomnd
|
||||
return "", "", zerr.ErrParsingAuthHeader
|
||||
}
|
||||
|
||||
username := pair[0]
|
||||
@@ -348,3 +748,118 @@ func getUsernamePasswordBasicAuth(request *http.Request) (string, string, error)
|
||||
|
||||
return username, passphrase, nil
|
||||
}
|
||||
|
||||
func GetGithubUserInfo(ctx context.Context, client *github.Client, log log.Logger) (string, []string, error) {
|
||||
var primaryEmail string
|
||||
|
||||
userEmails, _, err := client.Users.ListEmails(ctx, nil)
|
||||
if err != nil {
|
||||
log.Error().Msg("couldn't set user record for empty email value")
|
||||
|
||||
return "", []string{}, err
|
||||
}
|
||||
|
||||
if len(userEmails) != 0 {
|
||||
for _, email := range userEmails { // should have at least one primary email, if any
|
||||
if email.GetPrimary() { // check if it's primary email
|
||||
primaryEmail = email.GetEmail()
|
||||
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
orgs, _, err := client.Organizations.List(ctx, "", nil)
|
||||
if err != nil {
|
||||
log.Error().Msg("couldn't set user record for empty email value")
|
||||
|
||||
return "", []string{}, err
|
||||
}
|
||||
|
||||
groups := []string{}
|
||||
for _, org := range orgs {
|
||||
groups = append(groups, *org.Login)
|
||||
}
|
||||
|
||||
return primaryEmail, groups, nil
|
||||
}
|
||||
|
||||
func saveUserLoggedSession(cookieStore sessions.Store, response http.ResponseWriter,
|
||||
request *http.Request, identity string, log log.Logger,
|
||||
) error {
|
||||
session, _ := cookieStore.Get(request, "session")
|
||||
|
||||
session.Options.Secure = true
|
||||
session.Options.HttpOnly = true
|
||||
session.Options.SameSite = http.SameSiteDefaultMode
|
||||
session.Values["authStatus"] = true
|
||||
session.Values["user"] = identity
|
||||
|
||||
// let the session set its own id
|
||||
err := session.Save(request, response)
|
||||
if err != nil {
|
||||
log.Error().Err(err).Str("identity", identity).Msg("unable to save http session")
|
||||
|
||||
return err
|
||||
}
|
||||
|
||||
userInfoCookie := sessions.NewCookie("user", identity, &sessions.Options{
|
||||
Secure: true,
|
||||
HttpOnly: false,
|
||||
MaxAge: cookiesMaxAge,
|
||||
SameSite: http.SameSiteDefaultMode,
|
||||
Path: "/",
|
||||
})
|
||||
|
||||
http.SetCookie(response, userInfoCookie)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// OAuth2Callback is the callback logic where openid/oauth2 will redirect back to our app.
|
||||
func OAuth2Callback(ctlr *Controller, w http.ResponseWriter, r *http.Request, state, email string,
|
||||
groups []string,
|
||||
) (string, error) {
|
||||
stateCookie, _ := ctlr.CookieStore.Get(r, "statecookie")
|
||||
|
||||
stateOrigin, ok := stateCookie.Values["state"].(string)
|
||||
if !ok {
|
||||
ctlr.Log.Error().Err(zerr.ErrInvalidStateCookie).Msg("openID: unable to get 'state' cookie from request")
|
||||
|
||||
return "", zerr.ErrInvalidStateCookie
|
||||
}
|
||||
|
||||
if stateOrigin != state {
|
||||
ctlr.Log.Error().Err(zerr.ErrInvalidStateCookie).Msg("openID: 'state' cookie differs from the actual one")
|
||||
|
||||
return "", zerr.ErrInvalidStateCookie
|
||||
}
|
||||
|
||||
ctx := getReqContextWithAuthorization(email, groups, r)
|
||||
|
||||
// if this line has been reached, then a new session should be created
|
||||
// if the `session` key is already on the cookie, it's not a valid one
|
||||
if err := saveUserLoggedSession(ctlr.CookieStore, w, r, email, ctlr.Log); err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
if err := ctlr.RepoDB.SetUserGroups(ctx, groups); err != nil {
|
||||
ctlr.Log.Error().Err(err).Str("identity", email).Msg("couldn't update the user profile")
|
||||
|
||||
return "", err
|
||||
}
|
||||
|
||||
ctlr.Log.Info().Msgf("user profile set successfully for email %s", email)
|
||||
|
||||
// redirect to UI
|
||||
callbackUI, _ := stateCookie.Values["callback"].(string)
|
||||
|
||||
return callbackUI, nil
|
||||
}
|
||||
|
||||
func hashUUID(uuid string) string {
|
||||
digester := sha256.New()
|
||||
digester.Write([]byte(uuid))
|
||||
|
||||
return godigest.NewDigestFromEncoded(godigest.SHA256, fmt.Sprintf("%x", digester.Sum(nil))).Encoded()
|
||||
}
|
||||
|
||||
+45
-11
@@ -21,6 +21,9 @@ const (
|
||||
Delete = "delete"
|
||||
// behaviour actions.
|
||||
DetectManifestCollision = "detectManifestCollision"
|
||||
BASIC = "Basic"
|
||||
BEARER = "Bearer"
|
||||
OPENID = "OpenID"
|
||||
)
|
||||
|
||||
// AccessController authorizes users to act on resources.
|
||||
@@ -29,10 +32,17 @@ type AccessController struct {
|
||||
Log log.Logger
|
||||
}
|
||||
|
||||
func NewAccessController(config *config.Config) *AccessController {
|
||||
func NewAccessController(conf *config.Config) *AccessController {
|
||||
if conf.HTTP.AccessControl == nil {
|
||||
return &AccessController{
|
||||
Config: &config.AccessControlConfig{},
|
||||
Log: log.NewLogger(conf.Log.Level, conf.Log.Output),
|
||||
}
|
||||
}
|
||||
|
||||
return &AccessController{
|
||||
Config: config.HTTP.AccessControl,
|
||||
Log: log.NewLogger(config.Log.Level, config.Log.Output),
|
||||
Config: conf.HTTP.AccessControl,
|
||||
Log: log.NewLogger(conf.Log.Level, conf.Log.Output),
|
||||
}
|
||||
}
|
||||
|
||||
@@ -171,6 +181,18 @@ func (ac *AccessController) getContext(acCtx *localCtx.AccessControlContext, req
|
||||
return ctx
|
||||
}
|
||||
|
||||
// getAuthnMiddlewareContext builds ac context(allowed to read repos and if user is admin) and returns it.
|
||||
func (ac *AccessController) getAuthnMiddlewareContext(authnType string, request *http.Request) context.Context {
|
||||
amwCtx := localCtx.AuthnMiddlewareContext{
|
||||
AuthnType: authnType,
|
||||
}
|
||||
|
||||
amwCtxKey := localCtx.GetAuthnMiddlewareCtxKey()
|
||||
ctx := context.WithValue(request.Context(), amwCtxKey, amwCtx)
|
||||
|
||||
return ctx
|
||||
}
|
||||
|
||||
// isPermitted returns true if username can do action on a repository policy.
|
||||
func (ac *AccessController) isPermitted(userGroups []string, username, action string,
|
||||
policyGroup config.PolicyGroup,
|
||||
@@ -231,6 +253,14 @@ func BaseAuthzHandler(ctlr *Controller) mux.MiddlewareFunc {
|
||||
return
|
||||
}
|
||||
|
||||
// request comes from bearer authn, bypass it
|
||||
authnMwCtx, err := localCtx.GetAuthnMiddlewareContext(request.Context())
|
||||
if err != nil || (authnMwCtx != nil && authnMwCtx.AuthnType == BEARER) {
|
||||
next.ServeHTTP(response, request)
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
// bypass authz for /v2/ route
|
||||
if request.RequestURI == "/v2/" {
|
||||
next.ServeHTTP(response, request)
|
||||
@@ -242,8 +272,6 @@ func BaseAuthzHandler(ctlr *Controller) mux.MiddlewareFunc {
|
||||
|
||||
var identity string
|
||||
|
||||
var err error
|
||||
|
||||
// anonymous context
|
||||
acCtx := &localCtx.AccessControlContext{}
|
||||
|
||||
@@ -252,7 +280,7 @@ func BaseAuthzHandler(ctlr *Controller) mux.MiddlewareFunc {
|
||||
// get access control context made in authn.go if authn is enabled
|
||||
acCtx, err = localCtx.GetAccessControlContext(request.Context())
|
||||
if err != nil { // should never happen
|
||||
authFail(response, ctlr.Config.HTTP.Realm, ctlr.Config.HTTP.Auth.FailDelay)
|
||||
authFail(response, request, ctlr.Config.HTTP.Realm, ctlr.Config.HTTP.Auth.FailDelay)
|
||||
|
||||
return
|
||||
}
|
||||
@@ -272,7 +300,7 @@ func BaseAuthzHandler(ctlr *Controller) mux.MiddlewareFunc {
|
||||
// if we still don't have an identity
|
||||
if identity == "" {
|
||||
acCtrlr.Log.Info().Msg("couldn't get identity from TLS certificate")
|
||||
authFail(response, ctlr.Config.HTTP.Realm, ctlr.Config.HTTP.Auth.FailDelay)
|
||||
authFail(response, request, ctlr.Config.HTTP.Realm, ctlr.Config.HTTP.Auth.FailDelay)
|
||||
|
||||
return
|
||||
}
|
||||
@@ -298,6 +326,14 @@ func DistSpecAuthzHandler(ctlr *Controller) mux.MiddlewareFunc {
|
||||
return
|
||||
}
|
||||
|
||||
// request comes from bearer authn, bypass it
|
||||
authnMwCtx, err := localCtx.GetAuthnMiddlewareContext(request.Context())
|
||||
if err != nil || (authnMwCtx != nil && authnMwCtx.AuthnType == BEARER) {
|
||||
next.ServeHTTP(response, request)
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
vars := mux.Vars(request)
|
||||
resource := vars["name"]
|
||||
reference, ok := vars["reference"]
|
||||
@@ -306,12 +342,10 @@ func DistSpecAuthzHandler(ctlr *Controller) mux.MiddlewareFunc {
|
||||
|
||||
var identity string
|
||||
|
||||
var err error
|
||||
|
||||
// get acCtx built in authn and previous authz middlewares
|
||||
acCtx, err := localCtx.GetAccessControlContext(request.Context())
|
||||
if err != nil { // should never happen
|
||||
authFail(response, ctlr.Config.HTTP.Realm, ctlr.Config.HTTP.Auth.FailDelay)
|
||||
authFail(response, request, ctlr.Config.HTTP.Realm, ctlr.Config.HTTP.Auth.FailDelay)
|
||||
|
||||
return
|
||||
}
|
||||
@@ -344,7 +378,7 @@ func DistSpecAuthzHandler(ctlr *Controller) mux.MiddlewareFunc {
|
||||
|
||||
can := acCtrlr.can(request.Context(), identity, action, resource) //nolint:contextcheck
|
||||
if !can {
|
||||
common.AuthzFail(response, ctlr.Config.HTTP.Realm, ctlr.Config.HTTP.Auth.FailDelay)
|
||||
common.AuthzFail(response, request, ctlr.Config.HTTP.Realm, ctlr.Config.HTTP.Auth.FailDelay)
|
||||
} else {
|
||||
next.ServeHTTP(response, request) //nolint:contextcheck
|
||||
}
|
||||
|
||||
@@ -45,6 +45,7 @@ type AuthConfig struct {
|
||||
HTPasswd AuthHTPasswd
|
||||
LDAP *LDAPConfig
|
||||
Bearer *BearerConfig
|
||||
OpenID *OpenIDConfig
|
||||
}
|
||||
|
||||
type BearerConfig struct {
|
||||
@@ -53,6 +54,18 @@ type BearerConfig struct {
|
||||
Cert string
|
||||
}
|
||||
|
||||
type OpenIDConfig struct {
|
||||
Providers map[string]OpenIDProviderConfig
|
||||
}
|
||||
|
||||
type OpenIDProviderConfig struct {
|
||||
ClientID string
|
||||
ClientSecret string
|
||||
KeyPath string
|
||||
Issuer string
|
||||
Scopes []string
|
||||
}
|
||||
|
||||
type MethodRatelimitConfig struct {
|
||||
Method string
|
||||
Rate int
|
||||
@@ -63,6 +76,7 @@ type RatelimitConfig struct {
|
||||
Methods []MethodRatelimitConfig `mapstructure:",omitempty"`
|
||||
}
|
||||
|
||||
//nolint:maligned
|
||||
type HTTPConfig struct {
|
||||
Address string
|
||||
Port string
|
||||
|
||||
@@ -12,4 +12,11 @@ const (
|
||||
DefaultMediaType = "application/json"
|
||||
BinaryMediaType = "application/octet-stream"
|
||||
DefaultMetricsExtensionRoute = "/metrics"
|
||||
CallbackBasePath = "/auth/callback"
|
||||
LoginPath = "/auth/login"
|
||||
LogoutPath = "/auth/logout"
|
||||
SessionClientHeaderName = "X-ZOT-API-CLIENT"
|
||||
SessionClientHeaderValue = "zot-ui"
|
||||
APIKeysPrefix = "zak_"
|
||||
CallbackUIQueryParam = "callback_ui"
|
||||
)
|
||||
|
||||
@@ -18,4 +18,7 @@ const (
|
||||
ExtUserPreferences = "/userprefs"
|
||||
ExtUserPreferencesPrefix = ExtPrefix + ExtUserPreferences
|
||||
FullUserPreferencesPrefix = RoutePrefix + ExtUserPreferencesPrefix
|
||||
ExtAPIKey = "/apikey"
|
||||
ExtAPIKeyPrefix = ExtPrefix + ExtAPIKey //nolint: gosec
|
||||
FullAPIKeyPrefix = RoutePrefix + ExtAPIKeyPrefix
|
||||
)
|
||||
|
||||
@@ -16,6 +16,8 @@ import (
|
||||
|
||||
"github.com/gorilla/handlers"
|
||||
"github.com/gorilla/mux"
|
||||
"github.com/gorilla/sessions"
|
||||
"github.com/zitadel/oidc/pkg/client/rp"
|
||||
|
||||
"zotregistry.io/zot/errors"
|
||||
"zotregistry.io/zot/pkg/api/config"
|
||||
@@ -31,6 +33,7 @@ import (
|
||||
const (
|
||||
idleTimeout = 120 * time.Second
|
||||
readHeaderTimeout = 5 * time.Second
|
||||
cookiesMaxAge = 86400 // seconds
|
||||
)
|
||||
|
||||
type Controller struct {
|
||||
@@ -44,6 +47,8 @@ type Controller struct {
|
||||
Metrics monitoring.MetricServer
|
||||
CveInfo ext.CveInfo
|
||||
SyncOnDemand SyncOnDemand
|
||||
RelyingParties map[string]rp.RelyingParty
|
||||
CookieStore sessions.Store
|
||||
// runtime params
|
||||
chosenPort int // kernel-chosen port
|
||||
}
|
||||
@@ -254,7 +259,9 @@ func (c *Controller) InitImageStore() error {
|
||||
}
|
||||
|
||||
func (c *Controller) InitRepoDB(reloadCtx context.Context) error {
|
||||
if c.Config.Extensions != nil && c.Config.Extensions.Search != nil && *c.Config.Extensions.Search.Enable {
|
||||
// init repoDB if search is enabled or authn enabled (need to store user profiles) or apikey ext is enabled
|
||||
if (c.Config.Extensions != nil && c.Config.Extensions.Search != nil && *c.Config.Extensions.Search.Enable) ||
|
||||
isAuthnEnabled(c.Config) || isOpenIDAuthEnabled(c.Config) || isAPIKeyEnabled(c.Config) {
|
||||
driver, err := repodbfactory.New(c.Config.Storage.StorageConfig, c.Log) //nolint:contextcheck
|
||||
if err != nil {
|
||||
return err
|
||||
|
||||
+1213
-26
File diff suppressed because it is too large
Load Diff
+164
-18
@@ -20,11 +20,14 @@ import (
|
||||
"strconv"
|
||||
"strings"
|
||||
|
||||
"github.com/google/go-github/v52/github"
|
||||
"github.com/gorilla/mux"
|
||||
"github.com/opencontainers/distribution-spec/specs-go/v1/extensions"
|
||||
godigest "github.com/opencontainers/go-digest"
|
||||
ispec "github.com/opencontainers/image-spec/specs-go/v1"
|
||||
artifactspec "github.com/oras-project/artifacts-spec/specs-go/v1"
|
||||
"github.com/zitadel/oidc/pkg/client/rp"
|
||||
"github.com/zitadel/oidc/pkg/oidc"
|
||||
|
||||
zerr "zotregistry.io/zot/errors"
|
||||
"zotregistry.io/zot/pkg/api/constants"
|
||||
@@ -55,13 +58,38 @@ func NewRouteHandler(c *Controller) *RouteHandler {
|
||||
}
|
||||
|
||||
func (rh *RouteHandler) SetupRoutes() {
|
||||
// first get Auth middleware in order to first setup openid/ldap/htpasswd, before oidc provider routes are setup
|
||||
authHandler := AuthHandler(rh.c)
|
||||
|
||||
applyCORSHeaders := getCORSHeadersHandler(rh.c.Config.HTTP.AllowOrigin)
|
||||
|
||||
if isOpenIDAuthEnabled(rh.c.Config) {
|
||||
// login path for openID
|
||||
rh.c.Router.HandleFunc(constants.LoginPath, rh.AuthURLHandler())
|
||||
|
||||
// logout path for openID
|
||||
rh.c.Router.HandleFunc(constants.LogoutPath, applyCORSHeaders(rh.Logout)).
|
||||
Methods(zcommon.AllowedMethods("POST")...)
|
||||
|
||||
// callback path for openID
|
||||
for provider, relyingParty := range rh.c.RelyingParties {
|
||||
if IsOauth2Supported(provider) {
|
||||
rh.c.Router.HandleFunc(constants.CallbackBasePath+fmt.Sprintf("/%s", provider),
|
||||
rp.CodeExchangeHandler(rh.GithubCodeExchangeCallback(), relyingParty))
|
||||
} else if IsOpenIDSupported(provider) {
|
||||
rh.c.Router.HandleFunc(constants.CallbackBasePath+fmt.Sprintf("/%s", provider),
|
||||
rp.CodeExchangeHandler(rp.UserinfoCallback(rh.OpenIDCodeExchangeCallback()), relyingParty))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
prefixedRouter := rh.c.Router.PathPrefix(constants.RoutePrefix).Subrouter()
|
||||
prefixedRouter.Use(AuthHandler(rh.c))
|
||||
prefixedRouter.Use(authHandler)
|
||||
|
||||
prefixedDistSpecRouter := prefixedRouter.NewRoute().Subrouter()
|
||||
// authz is being enabled if AccessControl is specified
|
||||
// if Authn is not present AccessControl will have only default policies
|
||||
if rh.c.Config.HTTP.AccessControl != nil && !isBearerAuthEnabled(rh.c.Config) {
|
||||
if rh.c.Config.HTTP.AccessControl != nil {
|
||||
if isAuthnEnabled(rh.c.Config) {
|
||||
rh.c.Log.Info().Msg("access control is being enabled")
|
||||
} else {
|
||||
@@ -72,8 +100,6 @@ func (rh *RouteHandler) SetupRoutes() {
|
||||
prefixedDistSpecRouter.Use(DistSpecAuthzHandler(rh.c))
|
||||
}
|
||||
|
||||
applyCORSHeaders := getCORSHeadersHandler(rh.c.Config.HTTP.AllowOrigin)
|
||||
|
||||
// https://github.com/opencontainers/distribution-spec/blob/main/spec.md#endpoints
|
||||
{
|
||||
prefixedDistSpecRouter.HandleFunc(fmt.Sprintf("/{name:%s}/tags/list", zreg.NameRegexp.String()),
|
||||
@@ -118,7 +144,7 @@ func (rh *RouteHandler) SetupRoutes() {
|
||||
constants.ArtifactSpecRoutePrefix, zreg.NameRegexp.String()), rh.GetOrasReferrers).Methods("GET")
|
||||
|
||||
// swagger
|
||||
debug.SetupSwaggerRoutes(rh.c.Config, rh.c.Router, AuthHandler(rh.c), rh.c.Log)
|
||||
debug.SetupSwaggerRoutes(rh.c.Config, rh.c.Router, authHandler, rh.c.Log)
|
||||
|
||||
// Setup Extensions Routes
|
||||
if rh.c.Config != nil {
|
||||
@@ -135,8 +161,8 @@ func (rh *RouteHandler) SetupRoutes() {
|
||||
rh.c.Log)
|
||||
ext.SetupUserPreferencesRoutes(rh.c.Config, prefixedExtensionsRouter, rh.c.StoreController, rh.c.RepoDB,
|
||||
rh.c.CveInfo, rh.c.Log)
|
||||
|
||||
ext.SetupMetricsRoutes(rh.c.Config, rh.c.Router, rh.c.StoreController, AuthHandler(rh.c), rh.c.Log)
|
||||
ext.SetupAPIKeyRoutes(rh.c.Config, prefixedExtensionsRouter, rh.c.RepoDB, rh.c.CookieStore, rh.c.Log)
|
||||
ext.SetupMetricsRoutes(rh.c.Config, rh.c.Router, rh.c.StoreController, authHandler, rh.c.Log)
|
||||
|
||||
gqlPlayground.SetupGQLPlaygroundRoutes(rh.c.Config, prefixedRouter, rh.c.StoreController, rh.c.Log)
|
||||
|
||||
@@ -185,7 +211,8 @@ func addCORSHeaders(allowOrigin string, response http.ResponseWriter) {
|
||||
// @Success 200 {string} string "ok".
|
||||
func (rh *RouteHandler) CheckVersionSupport(response http.ResponseWriter, request *http.Request) {
|
||||
response.Header().Set("Access-Control-Allow-Methods", "HEAD,GET,POST,OPTIONS")
|
||||
response.Header().Set("Access-Control-Allow-Headers", "Authorization,content-type")
|
||||
response.Header().Set("Access-Control-Allow-Headers", "Authorization,content-type,"+constants.SessionClientHeaderName)
|
||||
response.Header().Set("Access-Control-Allow-Credentials", "true")
|
||||
|
||||
if request.Method == http.MethodOptions {
|
||||
return
|
||||
@@ -195,10 +222,13 @@ func (rh *RouteHandler) CheckVersionSupport(response http.ResponseWriter, reques
|
||||
// NOTE: compatibility workaround - return this header in "allowed-read" mode to allow for clients to
|
||||
// work correctly
|
||||
if rh.c.Config.HTTP.Auth != nil {
|
||||
if rh.c.Config.HTTP.Auth.Bearer != nil {
|
||||
response.Header().Set("WWW-Authenticate", fmt.Sprintf("bearer realm=%s", rh.c.Config.HTTP.Auth.Bearer.Realm))
|
||||
} else {
|
||||
response.Header().Set("WWW-Authenticate", fmt.Sprintf("basic realm=%s", rh.c.Config.HTTP.Realm))
|
||||
// don't send auth headers if request is coming from UI
|
||||
if request.Header.Get(constants.SessionClientHeaderName) != constants.SessionClientHeaderValue {
|
||||
if rh.c.Config.HTTP.Auth.Bearer != nil {
|
||||
response.Header().Set("WWW-Authenticate", fmt.Sprintf("bearer realm=%s", rh.c.Config.HTTP.Auth.Bearer.Realm))
|
||||
} else {
|
||||
response.Header().Set("WWW-Authenticate", fmt.Sprintf("basic realm=%s", rh.c.Config.HTTP.Realm))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -224,7 +254,8 @@ type ImageTags struct {
|
||||
// @Failure 400 {string} string "bad request".
|
||||
func (rh *RouteHandler) ListTags(response http.ResponseWriter, request *http.Request) {
|
||||
response.Header().Set("Access-Control-Allow-Methods", "HEAD,GET,POST,OPTIONS")
|
||||
response.Header().Set("Access-Control-Allow-Headers", "Authorization,content-type")
|
||||
response.Header().Set("Access-Control-Allow-Headers", "Authorization,content-type,"+constants.SessionClientHeaderName)
|
||||
response.Header().Set("Access-Control-Allow-Credentials", "true")
|
||||
|
||||
if request.Method == http.MethodOptions {
|
||||
return
|
||||
@@ -355,7 +386,8 @@ func (rh *RouteHandler) ListTags(response http.ResponseWriter, request *http.Req
|
||||
// @Failure 500 {string} string "internal server error".
|
||||
func (rh *RouteHandler) CheckManifest(response http.ResponseWriter, request *http.Request) {
|
||||
response.Header().Set("Access-Control-Allow-Methods", "HEAD,GET,POST,OPTIONS")
|
||||
response.Header().Set("Access-Control-Allow-Headers", "Authorization,content-type")
|
||||
response.Header().Set("Access-Control-Allow-Headers", "Authorization,content-type,"+constants.SessionClientHeaderName)
|
||||
response.Header().Set("Access-Control-Allow-Credentials", "true")
|
||||
|
||||
if request.Method == http.MethodOptions {
|
||||
return
|
||||
@@ -427,7 +459,8 @@ type ExtensionList struct {
|
||||
// @Router /v2/{name}/manifests/{reference} [get].
|
||||
func (rh *RouteHandler) GetManifest(response http.ResponseWriter, request *http.Request) {
|
||||
response.Header().Set("Access-Control-Allow-Methods", "HEAD,GET,POST,OPTIONS")
|
||||
response.Header().Set("Access-Control-Allow-Headers", "Authorization,content-type")
|
||||
response.Header().Set("Access-Control-Allow-Headers", "Authorization,content-type,"+constants.SessionClientHeaderName)
|
||||
response.Header().Set("Access-Control-Allow-Credentials", "true")
|
||||
|
||||
if request.Method == http.MethodOptions {
|
||||
return
|
||||
@@ -527,7 +560,8 @@ func getReferrers(routeHandler *RouteHandler,
|
||||
// @Router /v2/{name}/referrers/{digest} [get].
|
||||
func (rh *RouteHandler) GetReferrers(response http.ResponseWriter, request *http.Request) {
|
||||
response.Header().Set("Access-Control-Allow-Methods", "HEAD,GET,POST,OPTIONS")
|
||||
response.Header().Set("Access-Control-Allow-Headers", "Authorization,content-type")
|
||||
response.Header().Set("Access-Control-Allow-Headers", "Authorization,content-type,"+constants.SessionClientHeaderName)
|
||||
response.Header().Set("Access-Control-Allow-Credentials", "true")
|
||||
|
||||
if request.Method == http.MethodOptions {
|
||||
return
|
||||
@@ -1576,7 +1610,8 @@ type RepositoryList struct {
|
||||
// @Router /v2/_catalog [get].
|
||||
func (rh *RouteHandler) ListRepositories(response http.ResponseWriter, request *http.Request) {
|
||||
response.Header().Set("Access-Control-Allow-Methods", "HEAD,GET,POST,OPTIONS")
|
||||
response.Header().Set("Access-Control-Allow-Headers", "Authorization,content-type")
|
||||
response.Header().Set("Access-Control-Allow-Headers", "Authorization,content-type,"+constants.SessionClientHeaderName)
|
||||
response.Header().Set("Access-Control-Allow-Credentials", "true")
|
||||
|
||||
if request.Method == http.MethodOptions {
|
||||
return
|
||||
@@ -1642,7 +1677,8 @@ func (rh *RouteHandler) ListRepositories(response http.ResponseWriter, request *
|
||||
// @Router /v2/_oci/ext/discover [get].
|
||||
func (rh *RouteHandler) ListExtensions(w http.ResponseWriter, r *http.Request) {
|
||||
w.Header().Set("Access-Control-Allow-Methods", "HEAD,GET,POST,OPTIONS")
|
||||
w.Header().Set("Access-Control-Allow-Headers", "Authorization,content-type")
|
||||
w.Header().Set("Access-Control-Allow-Headers", "Authorization,content-type,"+constants.SessionClientHeaderName)
|
||||
w.Header().Set("Access-Control-Allow-Credentials", "true")
|
||||
|
||||
if r.Method == http.MethodOptions {
|
||||
return
|
||||
@@ -1653,6 +1689,116 @@ func (rh *RouteHandler) ListExtensions(w http.ResponseWriter, r *http.Request) {
|
||||
zcommon.WriteJSON(w, http.StatusOK, extensionList)
|
||||
}
|
||||
|
||||
// The following routes are specific to zot and NOT part of the OCI dist-spec
|
||||
|
||||
// Logout godoc
|
||||
// @Summary Logout by removing current session
|
||||
// @Description Logout by removing current session
|
||||
// @Router /openid/auth/logout [post]
|
||||
// @Accept json
|
||||
// @Produce json
|
||||
// @Success 200 {string} string "ok".
|
||||
// @Failure 500 {string} string "internal server error".
|
||||
func (rh *RouteHandler) Logout(response http.ResponseWriter, request *http.Request) {
|
||||
response.Header().Set("Access-Control-Allow-Methods", "HEAD,GET,POST,OPTIONS")
|
||||
response.Header().Set("Access-Control-Allow-Headers", "Authorization,content-type,"+constants.SessionClientHeaderName)
|
||||
response.Header().Set("Access-Control-Allow-Credentials", "true")
|
||||
|
||||
if request.Method == http.MethodOptions {
|
||||
return
|
||||
}
|
||||
|
||||
session, _ := rh.c.CookieStore.Get(request, "session")
|
||||
session.Options.MaxAge = -1
|
||||
|
||||
err := session.Save(request, response)
|
||||
if err != nil {
|
||||
response.WriteHeader(http.StatusInternalServerError)
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
response.WriteHeader(http.StatusOK)
|
||||
}
|
||||
|
||||
// github Oauth2 CodeExchange callback.
|
||||
func (rh *RouteHandler) GithubCodeExchangeCallback() rp.CodeExchangeCallback {
|
||||
return func(w http.ResponseWriter, r *http.Request,
|
||||
tokens *oidc.Tokens, state string, relyingParty rp.RelyingParty,
|
||||
) {
|
||||
ctx := r.Context()
|
||||
|
||||
client := github.NewClient(relyingParty.OAuthConfig().Client(ctx, tokens.Token))
|
||||
|
||||
email, groups, err := GetGithubUserInfo(ctx, client, rh.c.Log)
|
||||
if email == "" || err != nil {
|
||||
w.WriteHeader(http.StatusUnauthorized)
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
callbackUI, err := OAuth2Callback(rh.c, w, r, state, email, groups) //nolint: contextcheck
|
||||
if err != nil {
|
||||
if errors.Is(err, zerr.ErrInvalidStateCookie) {
|
||||
w.WriteHeader(http.StatusUnauthorized)
|
||||
}
|
||||
|
||||
w.WriteHeader(http.StatusInternalServerError)
|
||||
}
|
||||
|
||||
if callbackUI != "" {
|
||||
http.Redirect(w, r, callbackUI, http.StatusFound)
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
w.WriteHeader(http.StatusCreated)
|
||||
}
|
||||
}
|
||||
|
||||
// Openid CodeExchange callback.
|
||||
func (rh *RouteHandler) OpenIDCodeExchangeCallback() rp.CodeExchangeUserinfoCallback {
|
||||
return func(w http.ResponseWriter, r *http.Request, tokens *oidc.Tokens, state string,
|
||||
relyingParty rp.RelyingParty, info oidc.UserInfo,
|
||||
) {
|
||||
email := info.GetEmail()
|
||||
if email == "" {
|
||||
rh.c.Log.Error().Msg("couldn't set user record for empty email value")
|
||||
w.WriteHeader(http.StatusUnauthorized)
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
var groups []string
|
||||
|
||||
val, ok := info.GetClaim("groups").([]interface{})
|
||||
if !ok {
|
||||
rh.c.Log.Info().Msgf("couldn't find any 'groups' claim for user %s", email)
|
||||
}
|
||||
|
||||
for _, group := range val {
|
||||
groups = append(groups, fmt.Sprint(group))
|
||||
}
|
||||
|
||||
callbackUI, err := OAuth2Callback(rh.c, w, r, state, email, groups)
|
||||
if err != nil {
|
||||
if errors.Is(err, zerr.ErrInvalidStateCookie) {
|
||||
w.WriteHeader(http.StatusUnauthorized)
|
||||
}
|
||||
|
||||
w.WriteHeader(http.StatusInternalServerError)
|
||||
}
|
||||
|
||||
if callbackUI != "" {
|
||||
http.Redirect(w, r, callbackUI, http.StatusFound)
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
w.WriteHeader(http.StatusCreated)
|
||||
}
|
||||
}
|
||||
|
||||
func (rh *RouteHandler) GetMetrics(w http.ResponseWriter, r *http.Request) {
|
||||
m := rh.c.Metrics.ReceiveMetrics()
|
||||
zcommon.WriteJSON(w, http.StatusOK, m)
|
||||
|
||||
+175
-4
@@ -1,26 +1,36 @@
|
||||
//go:build sync && scrub && metrics && search && lint
|
||||
// +build sync,scrub,metrics,search,lint
|
||||
//go:build sync && scrub && metrics && search && lint && apikey
|
||||
// +build sync,scrub,metrics,search,lint,apikey
|
||||
|
||||
package api_test
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"os"
|
||||
"testing"
|
||||
|
||||
"github.com/google/uuid"
|
||||
"github.com/gorilla/mux"
|
||||
godigest "github.com/opencontainers/go-digest"
|
||||
ispec "github.com/opencontainers/image-spec/specs-go/v1"
|
||||
"github.com/project-zot/mockoidc"
|
||||
. "github.com/smartystreets/goconvey/convey"
|
||||
"github.com/zitadel/oidc/pkg/client/rp"
|
||||
"github.com/zitadel/oidc/pkg/oidc"
|
||||
"golang.org/x/oauth2"
|
||||
|
||||
zerr "zotregistry.io/zot/errors"
|
||||
"zotregistry.io/zot/pkg/api"
|
||||
"zotregistry.io/zot/pkg/api/config"
|
||||
"zotregistry.io/zot/pkg/api/constants"
|
||||
"zotregistry.io/zot/pkg/extensions"
|
||||
extconf "zotregistry.io/zot/pkg/extensions/config"
|
||||
"zotregistry.io/zot/pkg/meta/repodb"
|
||||
localCtx "zotregistry.io/zot/pkg/requestcontext"
|
||||
storageTypes "zotregistry.io/zot/pkg/storage/types"
|
||||
"zotregistry.io/zot/pkg/test"
|
||||
@@ -29,6 +39,8 @@ import (
|
||||
|
||||
var ErrUnexpectedError = errors.New("error: unexpected error")
|
||||
|
||||
const sessionStr = "session"
|
||||
|
||||
func TestRoutes(t *testing.T) {
|
||||
Convey("Make a new controller", t, func() {
|
||||
port := test.GetFreePort()
|
||||
@@ -36,6 +48,45 @@ func TestRoutes(t *testing.T) {
|
||||
conf := config.New()
|
||||
conf.HTTP.Port = port
|
||||
|
||||
htpasswdPath := test.MakeHtpasswdFile()
|
||||
defer os.Remove(htpasswdPath)
|
||||
mockOIDCServer, err := mockoidc.Run()
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
defer func() {
|
||||
err := mockOIDCServer.Shutdown()
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
}()
|
||||
|
||||
mockOIDCConfig := mockOIDCServer.Config()
|
||||
conf.HTTP.Auth = &config.AuthConfig{
|
||||
HTPasswd: config.AuthHTPasswd{
|
||||
Path: htpasswdPath,
|
||||
},
|
||||
OpenID: &config.OpenIDConfig{
|
||||
Providers: map[string]config.OpenIDProviderConfig{
|
||||
"dex": {
|
||||
ClientID: mockOIDCConfig.ClientID,
|
||||
ClientSecret: mockOIDCConfig.ClientSecret,
|
||||
KeyPath: "",
|
||||
Issuer: mockOIDCConfig.Issuer,
|
||||
Scopes: []string{"openid", "email"},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
defaultVal := true
|
||||
apiKeyConfig := &extconf.APIKeyConfig{
|
||||
BaseConfig: extconf.BaseConfig{Enable: &defaultVal},
|
||||
}
|
||||
conf.Extensions = &extconf.ExtensionConfig{
|
||||
APIKey: apiKeyConfig,
|
||||
}
|
||||
|
||||
ctlr := api.NewController(conf)
|
||||
|
||||
ctlr.Config.Storage.RootDirectory = t.TempDir()
|
||||
@@ -50,6 +101,52 @@ func TestRoutes(t *testing.T) {
|
||||
// NOTE: the url or method itself doesn't matter below since we are calling the handlers directly,
|
||||
// so path routing is bypassed
|
||||
|
||||
Convey("Test GithubCodeExchangeCallback", func() {
|
||||
callback := rthdlr.GithubCodeExchangeCallback()
|
||||
ctx := context.TODO()
|
||||
|
||||
request, _ := http.NewRequestWithContext(ctx, http.MethodGet, baseURL, nil)
|
||||
response := httptest.NewRecorder()
|
||||
|
||||
tokens := &oidc.Tokens{}
|
||||
relyingParty, err := rp.NewRelyingPartyOAuth(&oauth2.Config{})
|
||||
So(err, ShouldBeNil)
|
||||
|
||||
callback(response, request, tokens, "state", relyingParty)
|
||||
|
||||
resp := response.Result()
|
||||
defer resp.Body.Close()
|
||||
So(resp, ShouldNotBeNil)
|
||||
So(resp.StatusCode, ShouldEqual, http.StatusUnauthorized)
|
||||
})
|
||||
|
||||
Convey("Test OAuth2Callback errors", func() {
|
||||
ctx := context.TODO()
|
||||
|
||||
request, _ := http.NewRequestWithContext(ctx, http.MethodGet, baseURL, nil)
|
||||
response := httptest.NewRecorder()
|
||||
|
||||
_, err := api.OAuth2Callback(ctlr, response, request, "state", "email", []string{"group"})
|
||||
So(err, ShouldEqual, zerr.ErrInvalidStateCookie)
|
||||
|
||||
session, _ := ctlr.CookieStore.Get(request, "statecookie")
|
||||
|
||||
session.Options.Secure = true
|
||||
session.Options.HttpOnly = true
|
||||
session.Options.SameSite = http.SameSiteDefaultMode
|
||||
|
||||
state := uuid.New().String()
|
||||
|
||||
session.Values["state"] = state
|
||||
|
||||
// let the session set its own id
|
||||
err = session.Save(request, response)
|
||||
So(err, ShouldBeNil)
|
||||
|
||||
_, err = api.OAuth2Callback(ctlr, response, request, "state", "email", []string{"group"})
|
||||
So(err, ShouldEqual, zerr.ErrInvalidStateCookie)
|
||||
})
|
||||
|
||||
Convey("List repositories authz error", func() {
|
||||
var invalid struct{}
|
||||
|
||||
@@ -575,7 +672,7 @@ func TestRoutes(t *testing.T) {
|
||||
},
|
||||
&mocks.MockedImageStore{
|
||||
FullBlobUploadFn: func(repo string, body io.Reader, digest godigest.Digest) (string, int64, error) {
|
||||
return "session", 0, zerr.ErrBadBlobDigest
|
||||
return sessionStr, 0, zerr.ErrBadBlobDigest
|
||||
},
|
||||
})
|
||||
So(statusCode, ShouldEqual, http.StatusInternalServerError)
|
||||
@@ -591,7 +688,7 @@ func TestRoutes(t *testing.T) {
|
||||
},
|
||||
&mocks.MockedImageStore{
|
||||
FullBlobUploadFn: func(repo string, body io.Reader, digest godigest.Digest) (string, int64, error) {
|
||||
return "session", 20, nil
|
||||
return sessionStr, 20, nil
|
||||
},
|
||||
})
|
||||
So(statusCode, ShouldEqual, http.StatusInternalServerError)
|
||||
@@ -1327,6 +1424,80 @@ func TestRoutes(t *testing.T) {
|
||||
So(resp.StatusCode, ShouldEqual, http.StatusOK)
|
||||
})
|
||||
|
||||
Convey("Test API keys", func() {
|
||||
var invalid struct{}
|
||||
|
||||
ctx := context.TODO()
|
||||
key := localCtx.GetContextKey()
|
||||
ctx = context.WithValue(ctx, key, invalid)
|
||||
|
||||
request, _ := http.NewRequestWithContext(ctx, http.MethodPost, baseURL, bytes.NewReader([]byte{}))
|
||||
response := httptest.NewRecorder()
|
||||
|
||||
extensions.CreateAPIKey(response, request, ctlr.RepoDB, ctlr.CookieStore, ctlr.Log)
|
||||
|
||||
resp := response.Result()
|
||||
defer resp.Body.Close()
|
||||
So(resp.StatusCode, ShouldEqual, http.StatusInternalServerError)
|
||||
|
||||
acCtx := localCtx.AccessControlContext{
|
||||
Username: username,
|
||||
}
|
||||
|
||||
ctx = context.TODO()
|
||||
key = localCtx.GetContextKey()
|
||||
ctx = context.WithValue(ctx, key, acCtx)
|
||||
|
||||
request, _ = http.NewRequestWithContext(ctx, http.MethodPost, baseURL, bytes.NewReader([]byte{}))
|
||||
response = httptest.NewRecorder()
|
||||
|
||||
extensions.CreateAPIKey(response, request, ctlr.RepoDB, ctlr.CookieStore, ctlr.Log)
|
||||
|
||||
resp = response.Result()
|
||||
defer resp.Body.Close()
|
||||
|
||||
So(resp.StatusCode, ShouldEqual, http.StatusInternalServerError)
|
||||
|
||||
payload := extensions.APIKeyPayload{
|
||||
Label: "test",
|
||||
Scopes: []string{"test"},
|
||||
}
|
||||
reqBody, err := json.Marshal(payload)
|
||||
So(err, ShouldBeNil)
|
||||
|
||||
request, _ = http.NewRequestWithContext(ctx, http.MethodPost, baseURL, bytes.NewReader(reqBody))
|
||||
response = httptest.NewRecorder()
|
||||
|
||||
extensions.CreateAPIKey(response, request, mocks.RepoDBMock{
|
||||
AddUserAPIKeyFn: func(ctx context.Context, hashedKey string, apiKeyDetails *repodb.APIKeyDetails) error {
|
||||
return ErrUnexpectedError
|
||||
},
|
||||
}, ctlr.CookieStore, ctlr.Log)
|
||||
|
||||
resp = response.Result()
|
||||
defer resp.Body.Close()
|
||||
|
||||
So(resp.StatusCode, ShouldEqual, http.StatusInternalServerError)
|
||||
|
||||
request, _ = http.NewRequestWithContext(ctx, http.MethodDelete, baseURL, bytes.NewReader([]byte{}))
|
||||
response = httptest.NewRecorder()
|
||||
|
||||
q := request.URL.Query()
|
||||
q.Add("id", "apikeyid")
|
||||
request.URL.RawQuery = q.Encode()
|
||||
|
||||
extensions.RevokeAPIKey(response, request, mocks.RepoDBMock{
|
||||
DeleteUserAPIKeyFn: func(ctx context.Context, id string) error {
|
||||
return ErrUnexpectedError
|
||||
},
|
||||
}, ctlr.CookieStore, ctlr.Log)
|
||||
|
||||
resp = response.Result()
|
||||
defer resp.Body.Close()
|
||||
|
||||
So(resp.StatusCode, ShouldEqual, http.StatusInternalServerError)
|
||||
})
|
||||
|
||||
Convey("Helper functions", func() {
|
||||
testUpdateBlobUpload := func(
|
||||
query []struct{ k, v string },
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
//go:build sync && scrub && metrics && search
|
||||
// +build sync,scrub,metrics,search
|
||||
//go:build sync && scrub && metrics && search && apikey
|
||||
// +build sync,scrub,metrics,search,apikey
|
||||
|
||||
package cli_test
|
||||
|
||||
@@ -857,6 +857,67 @@ func TestServeMgmtExtension(t *testing.T) {
|
||||
})
|
||||
}
|
||||
|
||||
func TestServeAPIKeyExtension(t *testing.T) {
|
||||
oldArgs := os.Args
|
||||
|
||||
defer func() { os.Args = oldArgs }()
|
||||
|
||||
Convey("apikey implicitly enabled", t, func(c C) {
|
||||
content := `{
|
||||
"storage": {
|
||||
"rootDirectory": "%s"
|
||||
},
|
||||
"http": {
|
||||
"address": "127.0.0.1",
|
||||
"port": "%s"
|
||||
},
|
||||
"log": {
|
||||
"level": "debug",
|
||||
"output": "%s"
|
||||
},
|
||||
"extensions": {
|
||||
"apikey": {
|
||||
}
|
||||
}
|
||||
}`
|
||||
|
||||
logPath, err := runCLIWithConfig(t.TempDir(), content)
|
||||
So(err, ShouldBeNil)
|
||||
data, err := os.ReadFile(logPath)
|
||||
So(err, ShouldBeNil)
|
||||
defer os.Remove(logPath) // clean up
|
||||
So(string(data), ShouldContainSubstring, "\"APIKey\":{\"Enable\":true}")
|
||||
})
|
||||
|
||||
Convey("apikey disabled", t, func(c C) {
|
||||
content := `{
|
||||
"storage": {
|
||||
"rootDirectory": "%s"
|
||||
},
|
||||
"http": {
|
||||
"address": "127.0.0.1",
|
||||
"port": "%s"
|
||||
},
|
||||
"log": {
|
||||
"level": "debug",
|
||||
"output": "%s"
|
||||
},
|
||||
"extensions": {
|
||||
"apikey": {
|
||||
"enable": "false"
|
||||
}
|
||||
}
|
||||
}`
|
||||
|
||||
logPath, err := runCLIWithConfig(t.TempDir(), content)
|
||||
So(err, ShouldBeNil)
|
||||
data, err := os.ReadFile(logPath)
|
||||
So(err, ShouldBeNil)
|
||||
defer os.Remove(logPath) // clean up
|
||||
So(string(data), ShouldContainSubstring, "\"APIKey\":{\"Enable\":false}")
|
||||
})
|
||||
}
|
||||
|
||||
func readLogFileAndSearchString(logPath string, stringToMatch string, timeout time.Duration) (bool, error) { //nolint:unparam,lll
|
||||
ctx, cancelFunc := context.WithTimeout(context.Background(), timeout)
|
||||
defer cancelFunc()
|
||||
|
||||
+52
-4
@@ -361,6 +361,10 @@ func validateConfiguration(config *config.Config) error {
|
||||
return err
|
||||
}
|
||||
|
||||
if err := validateOpenIDConfig(config); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if err := validateSync(config); err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -377,7 +381,7 @@ func validateConfiguration(config *config.Config) error {
|
||||
return err
|
||||
}
|
||||
|
||||
// check authorization config, it should have basic auth enabled or ldap
|
||||
// check authorization config, it should have basic auth enabled or ldap, api keys or OpenID
|
||||
if config.HTTP.AccessControl != nil {
|
||||
// checking for anonymous policy only authorization config: no users, no policies but anonymous policy
|
||||
if err := validateAuthzPolicies(config); err != nil {
|
||||
@@ -435,11 +439,42 @@ func validateConfiguration(config *config.Config) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func validateOpenIDConfig(config *config.Config) error {
|
||||
if config.HTTP.Auth != nil && config.HTTP.Auth.OpenID != nil {
|
||||
for provider, providerConfig := range config.HTTP.Auth.OpenID.Providers {
|
||||
//nolint: gocritic
|
||||
if api.IsOpenIDSupported(provider) {
|
||||
if providerConfig.ClientID == "" || providerConfig.Issuer == "" ||
|
||||
len(providerConfig.Scopes) == 0 {
|
||||
log.Error().Err(errors.ErrBadConfig).
|
||||
Msg("OpenID provider config requires clientid, issuer and scopes parameters")
|
||||
|
||||
return errors.ErrBadConfig
|
||||
}
|
||||
} else if api.IsOauth2Supported(provider) {
|
||||
if providerConfig.ClientID == "" || len(providerConfig.Scopes) == 0 {
|
||||
log.Error().Err(errors.ErrBadConfig).
|
||||
Msg("OAuth2 provider config requires clientid and scopes parameters")
|
||||
|
||||
return errors.ErrBadConfig
|
||||
}
|
||||
} else {
|
||||
log.Error().Err(errors.ErrBadConfig).
|
||||
Msg("unsupported openid/oauth2 provider")
|
||||
|
||||
return errors.ErrBadConfig
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func validateAuthzPolicies(config *config.Config) error {
|
||||
if (config.HTTP.Auth == nil || (config.HTTP.Auth.HTPasswd.Path == "" && config.HTTP.Auth.LDAP == nil)) &&
|
||||
!authzContainsOnlyAnonymousPolicy(config) {
|
||||
if (config.HTTP.Auth == nil || (config.HTTP.Auth.HTPasswd.Path == "" && config.HTTP.Auth.LDAP == nil &&
|
||||
config.HTTP.Auth.OpenID == nil)) && !authzContainsOnlyAnonymousPolicy(config) {
|
||||
log.Error().Err(errors.ErrBadConfig).
|
||||
Msg("access control config requires httpasswd, ldap authentication " +
|
||||
Msg("access control config requires one of httpasswd, ldap or openid authentication " +
|
||||
"or using only 'anonymousPolicy' policies")
|
||||
|
||||
return errors.ErrBadConfig
|
||||
@@ -484,6 +519,13 @@ func applyDefaultValues(config *config.Config, viperInstance *viper.Viper) {
|
||||
// Note: In case mgmt is not empty the config.Extensions will not be nil and we will not reach here
|
||||
config.Extensions.Mgmt = &extconf.MgmtConfig{}
|
||||
}
|
||||
|
||||
_, ok = extMap["apikey"]
|
||||
if ok {
|
||||
// we found a config like `"extensions": {"mgmt:": {}}`
|
||||
// Note: In case mgmt is not empty the config.Extensions will not be nil and we will not reach here
|
||||
config.Extensions.APIKey = &extconf.APIKeyConfig{}
|
||||
}
|
||||
}
|
||||
|
||||
if config.Extensions != nil {
|
||||
@@ -550,6 +592,12 @@ func applyDefaultValues(config *config.Config, viperInstance *viper.Viper) {
|
||||
}
|
||||
}
|
||||
|
||||
if config.Extensions.APIKey != nil {
|
||||
if config.Extensions.APIKey.Enable == nil {
|
||||
config.Extensions.APIKey.Enable = &defaultVal
|
||||
}
|
||||
}
|
||||
|
||||
if config.Extensions.Scrub != nil {
|
||||
if config.Extensions.Scrub.Enable == nil {
|
||||
config.Extensions.Scrub.Enable = &defaultVal
|
||||
|
||||
@@ -952,6 +952,71 @@ func TestVerify(t *testing.T) {
|
||||
So(func() { _ = cli.NewServerRootCmd().Execute() }, ShouldPanic)
|
||||
})
|
||||
|
||||
Convey("Test verify openid config with missing parameter", t, func(c C) {
|
||||
tmpfile, err := os.CreateTemp("", "zot-test*.json")
|
||||
So(err, ShouldBeNil)
|
||||
defer os.Remove(tmpfile.Name()) // clean up
|
||||
content := []byte(`{"distSpecVersion":"1.1.0-dev","storage":{"rootDirectory":"/tmp/zot"},
|
||||
"http":{"address":"127.0.0.1","port":"8080","realm":"zot",
|
||||
"auth":{"openid":{"providers":{"dex":{"issuer":"http://127.0.0.1:5556/dex"}}}}},
|
||||
"log":{"level":"debug"}}`)
|
||||
_, err = tmpfile.Write(content)
|
||||
So(err, ShouldBeNil)
|
||||
err = tmpfile.Close()
|
||||
So(err, ShouldBeNil)
|
||||
os.Args = []string{"cli_test", "verify", tmpfile.Name()}
|
||||
So(func() { _ = cli.NewServerRootCmd().Execute() }, ShouldPanic)
|
||||
})
|
||||
|
||||
Convey("Test verify oauth2 config with missing parameter", t, func(c C) {
|
||||
tmpfile, err := os.CreateTemp("", "zot-test*.json")
|
||||
So(err, ShouldBeNil)
|
||||
defer os.Remove(tmpfile.Name()) // clean up
|
||||
content := []byte(`{"distSpecVersion":"1.1.0-dev","storage":{"rootDirectory":"/tmp/zot"},
|
||||
"http":{"address":"127.0.0.1","port":"8080","realm":"zot",
|
||||
"auth":{"openid":{"providers":{"github":{"clientid":"client_id"}}}}},
|
||||
"log":{"level":"debug"}}`)
|
||||
_, err = tmpfile.Write(content)
|
||||
So(err, ShouldBeNil)
|
||||
err = tmpfile.Close()
|
||||
So(err, ShouldBeNil)
|
||||
os.Args = []string{"cli_test", "verify", tmpfile.Name()}
|
||||
So(func() { _ = cli.NewServerRootCmd().Execute() }, ShouldPanic)
|
||||
})
|
||||
|
||||
Convey("Test verify openid config with unsupported provider", t, func(c C) {
|
||||
tmpfile, err := os.CreateTemp("", "zot-test*.json")
|
||||
So(err, ShouldBeNil)
|
||||
defer os.Remove(tmpfile.Name()) // clean up
|
||||
content := []byte(`{"distSpecVersion":"1.1.0-dev","storage":{"rootDirectory":"/tmp/zot"},
|
||||
"http":{"address":"127.0.0.1","port":"8080","realm":"zot",
|
||||
"auth":{"openid":{"providers":{"unsupported":{"issuer":"http://127.0.0.1:5556/dex"}}}}},
|
||||
"log":{"level":"debug"}}`)
|
||||
_, err = tmpfile.Write(content)
|
||||
So(err, ShouldBeNil)
|
||||
err = tmpfile.Close()
|
||||
So(err, ShouldBeNil)
|
||||
os.Args = []string{"cli_test", "verify", tmpfile.Name()}
|
||||
So(func() { _ = cli.NewServerRootCmd().Execute() }, ShouldPanic)
|
||||
})
|
||||
|
||||
Convey("Test verify openid config without apikey extension enabled", t, func(c C) {
|
||||
tmpfile, err := os.CreateTemp("", "zot-test*.json")
|
||||
So(err, ShouldBeNil)
|
||||
defer os.Remove(tmpfile.Name()) // clean up
|
||||
content := []byte(`{"distSpecVersion":"1.1.0-dev","storage":{"rootDirectory":"/tmp/zot"},
|
||||
"http":{"address":"127.0.0.1","port":"8080","realm":"zot",
|
||||
"auth":{"openid":{"providers":{"dex":{"issuer":"http://127.0.0.1:5556/dex",
|
||||
"clientid":"client_id","scopes":["openid"]}}}}},
|
||||
"log":{"level":"debug"}}`)
|
||||
_, err = tmpfile.Write(content)
|
||||
So(err, ShouldBeNil)
|
||||
err = tmpfile.Close()
|
||||
So(err, ShouldBeNil)
|
||||
os.Args = []string{"cli_test", "verify", tmpfile.Name()}
|
||||
So(func() { _ = cli.NewServerRootCmd().Execute() }, ShouldNotPanic)
|
||||
})
|
||||
|
||||
Convey("Test verify config with missing basedn key", t, func(c C) {
|
||||
tmpfile, err := os.CreateTemp("", "zot-test*.json")
|
||||
So(err, ShouldBeNil)
|
||||
|
||||
@@ -2,14 +2,17 @@ package common
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/gorilla/mux"
|
||||
"github.com/gorilla/sessions"
|
||||
jsoniter "github.com/json-iterator/go"
|
||||
|
||||
"zotregistry.io/zot/pkg/api/constants"
|
||||
apiErr "zotregistry.io/zot/pkg/api/errors"
|
||||
"zotregistry.io/zot/pkg/log"
|
||||
)
|
||||
|
||||
func AllowedMethods(methods ...string) []string {
|
||||
@@ -32,7 +35,8 @@ func ACHeadersHandler(allowedMethods ...string) mux.MiddlewareFunc {
|
||||
return func(next http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(resp http.ResponseWriter, req *http.Request) {
|
||||
resp.Header().Set("Access-Control-Allow-Methods", headerValue)
|
||||
resp.Header().Set("Access-Control-Allow-Headers", "Authorization,content-type")
|
||||
resp.Header().Set("Access-Control-Allow-Headers", "Authorization,content-type,"+constants.SessionClientHeaderName)
|
||||
resp.Header().Set("Access-Control-Allow-Credentials", "true")
|
||||
|
||||
if req.Method == http.MethodOptions {
|
||||
return
|
||||
@@ -43,9 +47,20 @@ func ACHeadersHandler(allowedMethods ...string) mux.MiddlewareFunc {
|
||||
}
|
||||
}
|
||||
|
||||
func AuthzFail(w http.ResponseWriter, realm string, delay int) {
|
||||
func AuthzFail(w http.ResponseWriter, r *http.Request, realm string, delay int) {
|
||||
time.Sleep(time.Duration(delay) * time.Second)
|
||||
w.Header().Set("WWW-Authenticate", realm)
|
||||
|
||||
// don't send auth headers if request is coming from UI
|
||||
if r.Header.Get(constants.SessionClientHeaderName) != constants.SessionClientHeaderValue {
|
||||
if realm == "" {
|
||||
realm = "Authorization Required"
|
||||
}
|
||||
|
||||
realm = "Basic realm=" + strconv.Quote(realm)
|
||||
|
||||
w.Header().Set("WWW-Authenticate", realm)
|
||||
}
|
||||
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
WriteJSON(w, http.StatusForbidden, apiErr.NewErrorList(apiErr.NewError(apiErr.DENIED)))
|
||||
}
|
||||
@@ -66,3 +81,39 @@ func WriteData(w http.ResponseWriter, status int, mediaType string, data []byte)
|
||||
w.WriteHeader(status)
|
||||
_, _ = w.Write(data)
|
||||
}
|
||||
|
||||
/*
|
||||
GetAuthUserFromRequestSession returns identity
|
||||
and auth status if on the request's cookie session is a logged in user.
|
||||
*/
|
||||
func GetAuthUserFromRequestSession(cookieStore sessions.Store, request *http.Request, log log.Logger,
|
||||
) (string, bool) {
|
||||
session, err := cookieStore.Get(request, "session")
|
||||
if err != nil {
|
||||
log.Error().Err(err).Msg("can not decode existing session")
|
||||
// expired cookie, no need to return err
|
||||
return "", false
|
||||
}
|
||||
|
||||
// at this point we should have a session set on cookie.
|
||||
// if created in the earlier Get() call then user is not logged in with sessions.
|
||||
if session.IsNew {
|
||||
return "", false
|
||||
}
|
||||
|
||||
authenticated := session.Values["authStatus"]
|
||||
if authenticated != true {
|
||||
log.Error().Msg("can not get `user` session value")
|
||||
|
||||
return "", false
|
||||
}
|
||||
|
||||
identity, ok := session.Values["user"].(string)
|
||||
if !ok {
|
||||
log.Error().Msg("can not get `user` session value")
|
||||
|
||||
return "", false
|
||||
}
|
||||
|
||||
return identity, true
|
||||
}
|
||||
|
||||
@@ -0,0 +1,66 @@
|
||||
# `API keys`
|
||||
|
||||
zot allows authentication for REST API calls using your API key as an alternative to your password.
|
||||
|
||||
* User can create/revoke his API key.
|
||||
|
||||
* Can not be retrieved, it is shown to the user only the first time is created.
|
||||
|
||||
* An API key has the same rights as the user who generated it.
|
||||
|
||||
## API keys REST API
|
||||
|
||||
|
||||
### Create API Key
|
||||
**Description**: Create an API key for the current user.
|
||||
|
||||
**Usage**: POST /v2/_zot/ext/apikey
|
||||
|
||||
**Produces**: application/json
|
||||
|
||||
**Sample input**:
|
||||
```
|
||||
POST /api/security/apiKey
|
||||
Body: {"label": "git", "scopes": ["repo1", "repo2"]}'
|
||||
```
|
||||
|
||||
**Example cURL**
|
||||
```
|
||||
curl -u user:password -X POST http://localhost:8080/v2/_zot/ext/apikey -d '{"label": "myLabel", "scopes": ["repo1", "repo2"]}'
|
||||
```
|
||||
|
||||
**Sample output**:
|
||||
```json
|
||||
{
|
||||
"createdAt": "2023-05-05T15:39:28.420926+03:00",
|
||||
"creatorUa": "curl/7.68.0",
|
||||
"generatedBy": "manual",
|
||||
"lastUsed": "2023-05-05T15:39:28.4209282+03:00",
|
||||
"label": "git",
|
||||
"scopes": [
|
||||
"repo1",
|
||||
"repo2"
|
||||
],
|
||||
"uuid": "46a45ce7-5d92-498a-a9cb-9654b1da3da1",
|
||||
"apiKey": "zak_e77bcb9e9f634f1581756abbf9ecd269"
|
||||
}
|
||||
```
|
||||
|
||||
**Using API keys cURL**
|
||||
```
|
||||
curl -u user:zak_e77bcb9e9f634f1581756abbf9ecd269 http://localhost:8080/v2/_catalog
|
||||
```
|
||||
|
||||
|
||||
### Revoke API Key
|
||||
**Description**: Revokes one current user API key by api key UUID
|
||||
|
||||
**Usage**: DELETE /api/security/apiKey?id=$uuid
|
||||
|
||||
**Produces**: application/json
|
||||
|
||||
|
||||
**Example cURL**
|
||||
```
|
||||
curl -u user:password -X DELETE http://localhost:8080/v2/_zot/ext/apikey?id=46a45ce7-5d92-498a-a9cb-9654b1da3da1
|
||||
```
|
||||
@@ -8,6 +8,7 @@ Component | Endpoint | Description
|
||||
[`search`](search/search.md) | `/v2/_zot/ext/search` | efficient and enhanced registry search capabilities using graphQL backend
|
||||
[`mgmt`](mgmt.md) | `/v2/_zot/ext/mgmt` | config management
|
||||
[`userprefs`](userprefs.md) | `/v2/_zot/ext/userprefs` | change user preferences
|
||||
[`apikey`](README_apikey.md) | `/v2/_zot/ext/apikey` | user api keys management
|
||||
|
||||
|
||||
# References
|
||||
|
||||
@@ -19,6 +19,11 @@ type ExtensionConfig struct {
|
||||
Lint *LintConfig
|
||||
UI *UIConfig
|
||||
Mgmt *MgmtConfig
|
||||
APIKey *APIKeyConfig
|
||||
}
|
||||
|
||||
type APIKeyConfig struct {
|
||||
BaseConfig `mapstructure:",squash"`
|
||||
}
|
||||
|
||||
type MgmtConfig struct {
|
||||
|
||||
@@ -0,0 +1,197 @@
|
||||
//go:build apikey
|
||||
// +build apikey
|
||||
|
||||
package extensions
|
||||
|
||||
import (
|
||||
"crypto/sha256"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
guuid "github.com/gofrs/uuid"
|
||||
"github.com/gorilla/mux"
|
||||
"github.com/gorilla/sessions"
|
||||
jsoniter "github.com/json-iterator/go"
|
||||
godigest "github.com/opencontainers/go-digest"
|
||||
|
||||
"zotregistry.io/zot/pkg/api/config"
|
||||
"zotregistry.io/zot/pkg/api/constants"
|
||||
zcommon "zotregistry.io/zot/pkg/common"
|
||||
"zotregistry.io/zot/pkg/log"
|
||||
"zotregistry.io/zot/pkg/meta/repodb"
|
||||
)
|
||||
|
||||
func SetupAPIKeyRoutes(config *config.Config, router *mux.Router, repoDB repodb.RepoDB,
|
||||
cookieStore sessions.Store, log log.Logger,
|
||||
) {
|
||||
if config.Extensions.APIKey != nil && *config.Extensions.APIKey.Enable {
|
||||
log.Info().Msg("setting up api key routes")
|
||||
|
||||
allowedMethods := zcommon.AllowedMethods(http.MethodPost, http.MethodDelete)
|
||||
|
||||
apiKeyRouter := router.PathPrefix(constants.ExtAPIKey).Subrouter()
|
||||
apiKeyRouter.Use(zcommon.ACHeadersHandler(allowedMethods...))
|
||||
apiKeyRouter.Use(zcommon.AddExtensionSecurityHeaders())
|
||||
apiKeyRouter.Methods(allowedMethods...).Handler(HandleAPIKeyRequest(repoDB, cookieStore, log))
|
||||
}
|
||||
}
|
||||
|
||||
type APIKeyPayload struct { //nolint:revive
|
||||
Label string `json:"label"`
|
||||
Scopes []string `json:"scopes"`
|
||||
}
|
||||
|
||||
func HandleAPIKeyRequest(repoDB repodb.RepoDB, cookieStore sessions.Store,
|
||||
log log.Logger,
|
||||
) http.Handler {
|
||||
return http.HandlerFunc(func(resp http.ResponseWriter, req *http.Request) {
|
||||
switch req.Method {
|
||||
case http.MethodPost:
|
||||
CreateAPIKey(resp, req, repoDB, cookieStore, log) //nolint:contextcheck
|
||||
|
||||
return
|
||||
case http.MethodDelete:
|
||||
RevokeAPIKey(resp, req, repoDB, cookieStore, log) //nolint:contextcheck
|
||||
|
||||
return
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
// CreateAPIKey godoc
|
||||
// @Summary Create an API key for the current user
|
||||
// @Description Can create an api key for a logged in user, based on the provided label and scopes.
|
||||
// @Accept json
|
||||
// @Produce json
|
||||
// @Success 201 {string} string "created"
|
||||
// @Failure 401 {string} string "unauthorized"
|
||||
// @Failure 500 {string} string "internal server error"
|
||||
// @Router /v2/_zot/ext/apikey [post].
|
||||
func CreateAPIKey(resp http.ResponseWriter, req *http.Request, repoDB repodb.RepoDB,
|
||||
cookieStore sessions.Store, log log.Logger,
|
||||
) {
|
||||
var payload APIKeyPayload
|
||||
|
||||
body, err := io.ReadAll(req.Body)
|
||||
if err != nil {
|
||||
log.Error().Msg("unable to read request body")
|
||||
resp.WriteHeader(http.StatusInternalServerError)
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
err = json.Unmarshal(body, &payload)
|
||||
if err != nil {
|
||||
log.Error().Err(err).Msg("unable to unmarshal body")
|
||||
resp.WriteHeader(http.StatusInternalServerError)
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
apiKeyBase, err := guuid.NewV4()
|
||||
if err != nil {
|
||||
log.Error().Err(err).Msg("unable to generate uuid")
|
||||
resp.WriteHeader(http.StatusInternalServerError)
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
apiKey := strings.ReplaceAll(apiKeyBase.String(), "-", "")
|
||||
|
||||
hashedAPIKey := hashUUID(apiKey)
|
||||
|
||||
// will be used for identifying a specific api key
|
||||
apiKeyID, err := guuid.NewV4()
|
||||
if err != nil {
|
||||
log.Error().Err(err).Msg("unable to generate uuid")
|
||||
resp.WriteHeader(http.StatusInternalServerError)
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
apiKeyDetails := &repodb.APIKeyDetails{
|
||||
CreatedAt: time.Now(),
|
||||
LastUsed: time.Now(),
|
||||
CreatorUA: req.UserAgent(),
|
||||
GeneratedBy: "manual",
|
||||
Label: payload.Label,
|
||||
Scopes: payload.Scopes,
|
||||
UUID: apiKeyID.String(),
|
||||
}
|
||||
|
||||
err = repoDB.AddUserAPIKey(req.Context(), hashedAPIKey, apiKeyDetails)
|
||||
if err != nil {
|
||||
log.Error().Err(err).Msg("error storing API key")
|
||||
resp.WriteHeader(http.StatusInternalServerError)
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
apiKeyResponse := struct {
|
||||
repodb.APIKeyDetails
|
||||
APIKey string `json:"apiKey"`
|
||||
}{
|
||||
APIKey: fmt.Sprintf("%s%s", constants.APIKeysPrefix, apiKey),
|
||||
APIKeyDetails: *apiKeyDetails,
|
||||
}
|
||||
|
||||
json := jsoniter.ConfigCompatibleWithStandardLibrary
|
||||
|
||||
data, err := json.Marshal(apiKeyResponse)
|
||||
if err != nil {
|
||||
log.Error().Err(err).Msg("unable to marshal api key response")
|
||||
|
||||
resp.WriteHeader(http.StatusInternalServerError)
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
resp.Header().Set("Content-Type", constants.DefaultMediaType)
|
||||
resp.WriteHeader(http.StatusCreated)
|
||||
_, _ = resp.Write(data)
|
||||
}
|
||||
|
||||
// RevokeAPIKey godoc
|
||||
// @Summary Revokes one current user API key
|
||||
// @Description Revokes one current user API key based on given key ID
|
||||
// @Accept json
|
||||
// @Produce json
|
||||
// @Param id path string true "api token id (UUID)"
|
||||
// @Success 200 {string} string "ok"
|
||||
// @Failure 500 {string} string "internal server error"
|
||||
// @Failure 401 {string} string "unauthorized"
|
||||
// @Failure 400 {string} string "bad request"
|
||||
// @Router /v2/_zot/ext/apikey?id=UUID [delete].
|
||||
func RevokeAPIKey(resp http.ResponseWriter, req *http.Request, repoDB repodb.RepoDB,
|
||||
cookieStore sessions.Store, log log.Logger,
|
||||
) {
|
||||
ids, ok := req.URL.Query()["id"]
|
||||
if !ok || len(ids) != 1 {
|
||||
resp.WriteHeader(http.StatusBadRequest)
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
keyID := ids[0]
|
||||
|
||||
err := repoDB.DeleteUserAPIKey(req.Context(), keyID)
|
||||
if err != nil {
|
||||
log.Error().Err(err).Str("keyID", keyID).Msg("error deleting API key")
|
||||
resp.WriteHeader(http.StatusInternalServerError)
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
resp.WriteHeader(http.StatusOK)
|
||||
}
|
||||
|
||||
func hashUUID(uuid string) string {
|
||||
digester := sha256.New()
|
||||
digester.Write([]byte(uuid))
|
||||
|
||||
return godigest.NewDigestFromEncoded(godigest.SHA256, fmt.Sprintf("%x", digester.Sum(nil))).Encoded()
|
||||
}
|
||||
@@ -0,0 +1,20 @@
|
||||
//go:build !apikey
|
||||
// +build !apikey
|
||||
|
||||
package extensions
|
||||
|
||||
import (
|
||||
"github.com/gorilla/mux"
|
||||
"github.com/gorilla/sessions"
|
||||
|
||||
"zotregistry.io/zot/pkg/api/config"
|
||||
"zotregistry.io/zot/pkg/log"
|
||||
"zotregistry.io/zot/pkg/meta/repodb"
|
||||
)
|
||||
|
||||
func SetupAPIKeyRoutes(config *config.Config, router *mux.Router, repoDB repodb.RepoDB,
|
||||
cookieStore sessions.Store, log log.Logger,
|
||||
) {
|
||||
log.Warn().Msg("skipping setting up API key routes because given zot binary doesn't include this feature," +
|
||||
"please build a binary that does so")
|
||||
}
|
||||
@@ -0,0 +1,531 @@
|
||||
//go:build apikey
|
||||
// +build apikey
|
||||
|
||||
package extensions_test
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"net/http"
|
||||
"os"
|
||||
"testing"
|
||||
|
||||
"github.com/project-zot/mockoidc"
|
||||
. "github.com/smartystreets/goconvey/convey"
|
||||
"gopkg.in/resty.v1"
|
||||
|
||||
"zotregistry.io/zot/pkg/api"
|
||||
"zotregistry.io/zot/pkg/api/config"
|
||||
"zotregistry.io/zot/pkg/api/constants"
|
||||
"zotregistry.io/zot/pkg/extensions"
|
||||
extconf "zotregistry.io/zot/pkg/extensions/config"
|
||||
"zotregistry.io/zot/pkg/meta/repodb"
|
||||
localCtx "zotregistry.io/zot/pkg/requestcontext"
|
||||
"zotregistry.io/zot/pkg/test"
|
||||
"zotregistry.io/zot/pkg/test/mocks"
|
||||
)
|
||||
|
||||
type (
|
||||
apiKeyResponse struct {
|
||||
repodb.APIKeyDetails
|
||||
APIKey string `json:"apiKey"`
|
||||
}
|
||||
)
|
||||
|
||||
var ErrUnexpectedError = errors.New("unexpected err")
|
||||
|
||||
func TestAPIKeys(t *testing.T) {
|
||||
Convey("Make a new controller", t, func() {
|
||||
port := test.GetFreePort()
|
||||
baseURL := test.GetBaseURL(port)
|
||||
|
||||
conf := config.New()
|
||||
conf.HTTP.Port = port
|
||||
|
||||
htpasswdPath := test.MakeHtpasswdFile()
|
||||
defer os.Remove(htpasswdPath)
|
||||
|
||||
mockOIDCServer, err := test.MockOIDCRun()
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
defer func() {
|
||||
err := mockOIDCServer.Shutdown()
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
}()
|
||||
|
||||
mockOIDCConfig := mockOIDCServer.Config()
|
||||
conf.HTTP.Auth = &config.AuthConfig{
|
||||
HTPasswd: config.AuthHTPasswd{
|
||||
Path: htpasswdPath,
|
||||
},
|
||||
OpenID: &config.OpenIDConfig{
|
||||
Providers: map[string]config.OpenIDProviderConfig{
|
||||
"dex": {
|
||||
ClientID: mockOIDCConfig.ClientID,
|
||||
ClientSecret: mockOIDCConfig.ClientSecret,
|
||||
KeyPath: "",
|
||||
Issuer: mockOIDCConfig.Issuer,
|
||||
Scopes: []string{"openid", "email", "groups"},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
conf.HTTP.AccessControl = &config.AccessControlConfig{}
|
||||
|
||||
defaultVal := true
|
||||
apiKeyConfig := &extconf.APIKeyConfig{
|
||||
BaseConfig: extconf.BaseConfig{Enable: &defaultVal},
|
||||
}
|
||||
|
||||
mgmtConfg := &extconf.MgmtConfig{
|
||||
BaseConfig: extconf.BaseConfig{Enable: &defaultVal},
|
||||
}
|
||||
|
||||
conf.Extensions = &extconf.ExtensionConfig{
|
||||
APIKey: apiKeyConfig,
|
||||
Mgmt: mgmtConfg,
|
||||
}
|
||||
|
||||
ctlr := api.NewController(conf)
|
||||
dir := t.TempDir()
|
||||
|
||||
ctlr.Config.Storage.RootDirectory = dir
|
||||
|
||||
cm := test.NewControllerManager(ctlr)
|
||||
|
||||
cm.StartServer()
|
||||
defer cm.StopServer()
|
||||
test.WaitTillServerReady(baseURL)
|
||||
|
||||
payload := extensions.APIKeyPayload{
|
||||
Label: "test",
|
||||
Scopes: []string{"test"},
|
||||
}
|
||||
reqBody, err := json.Marshal(payload)
|
||||
So(err, ShouldBeNil)
|
||||
|
||||
Convey("API key retrieved with basic auth", func() {
|
||||
// call endpoint with session ( added to client after previous request)
|
||||
resp, err := resty.R().
|
||||
SetBody(reqBody).
|
||||
SetBasicAuth("test", "test").
|
||||
Post(baseURL + constants.FullAPIKeyPrefix)
|
||||
So(err, ShouldBeNil)
|
||||
So(resp, ShouldNotBeNil)
|
||||
So(resp.StatusCode(), ShouldEqual, http.StatusCreated)
|
||||
|
||||
user := mockoidc.DefaultUser()
|
||||
|
||||
// get API key and email from apikey route response
|
||||
var apiKeyResponse apiKeyResponse
|
||||
err = json.Unmarshal(resp.Body(), &apiKeyResponse)
|
||||
So(err, ShouldBeNil)
|
||||
|
||||
email := user.Email
|
||||
So(email, ShouldNotBeEmpty)
|
||||
|
||||
resp, err = resty.R().
|
||||
SetBasicAuth("test", apiKeyResponse.APIKey).
|
||||
Get(baseURL + "/v2/_catalog")
|
||||
So(err, ShouldBeNil)
|
||||
So(resp, ShouldNotBeNil)
|
||||
So(resp.StatusCode(), ShouldEqual, http.StatusOK)
|
||||
|
||||
// add another one
|
||||
resp, err = resty.R().
|
||||
SetBody(reqBody).
|
||||
SetBasicAuth("test", "test").
|
||||
Post(baseURL + constants.FullAPIKeyPrefix)
|
||||
So(err, ShouldBeNil)
|
||||
So(resp, ShouldNotBeNil)
|
||||
So(resp.StatusCode(), ShouldEqual, http.StatusCreated)
|
||||
|
||||
err = json.Unmarshal(resp.Body(), &apiKeyResponse)
|
||||
So(err, ShouldBeNil)
|
||||
|
||||
resp, err = resty.R().
|
||||
SetBasicAuth("test", apiKeyResponse.APIKey).
|
||||
Get(baseURL + "/v2/_catalog")
|
||||
So(err, ShouldBeNil)
|
||||
So(resp, ShouldNotBeNil)
|
||||
So(resp.StatusCode(), ShouldEqual, http.StatusOK)
|
||||
})
|
||||
|
||||
Convey("API key retrieved with openID", func() {
|
||||
client := resty.New()
|
||||
client.SetRedirectPolicy(test.CustomRedirectPolicy(20))
|
||||
|
||||
// first login user
|
||||
resp, err := client.R().
|
||||
SetHeader(constants.SessionClientHeaderName, constants.SessionClientHeaderValue).
|
||||
SetQueryParam("provider", "dex").
|
||||
Get(baseURL + constants.LoginPath)
|
||||
So(err, ShouldBeNil)
|
||||
So(resp, ShouldNotBeNil)
|
||||
|
||||
cookies := resp.Cookies()
|
||||
|
||||
// call endpoint without session
|
||||
resp, err = client.R().
|
||||
SetBody(reqBody).
|
||||
SetHeader(constants.SessionClientHeaderName, constants.SessionClientHeaderValue).
|
||||
Post(baseURL + constants.FullAPIKeyPrefix)
|
||||
So(err, ShouldBeNil)
|
||||
So(resp, ShouldNotBeNil)
|
||||
So(resp.StatusCode(), ShouldEqual, http.StatusUnauthorized)
|
||||
|
||||
client.SetCookies(cookies)
|
||||
|
||||
// call endpoint with session ( added to client after previous request)
|
||||
resp, err = client.R().
|
||||
SetBody(reqBody).
|
||||
SetHeader(constants.SessionClientHeaderName, constants.SessionClientHeaderValue).
|
||||
Post(baseURL + constants.FullAPIKeyPrefix)
|
||||
So(err, ShouldBeNil)
|
||||
So(resp, ShouldNotBeNil)
|
||||
So(resp.StatusCode(), ShouldEqual, http.StatusCreated)
|
||||
|
||||
user := mockoidc.DefaultUser()
|
||||
|
||||
// get API key and email from apikey route response
|
||||
var apiKeyResponse apiKeyResponse
|
||||
err = json.Unmarshal(resp.Body(), &apiKeyResponse)
|
||||
So(err, ShouldBeNil)
|
||||
|
||||
email := user.Email
|
||||
So(email, ShouldNotBeEmpty)
|
||||
|
||||
resp, err = client.R().
|
||||
SetBasicAuth(email, apiKeyResponse.APIKey).
|
||||
Get(baseURL + "/v2/_catalog")
|
||||
So(err, ShouldBeNil)
|
||||
So(resp, ShouldNotBeNil)
|
||||
So(resp.StatusCode(), ShouldEqual, http.StatusOK)
|
||||
|
||||
// trigger errors
|
||||
ctlr.RepoDB = mocks.RepoDBMock{
|
||||
GetUserAPIKeyInfoFn: func(hashedKey string) (string, error) {
|
||||
return "", ErrUnexpectedError
|
||||
},
|
||||
}
|
||||
|
||||
resp, err = client.R().
|
||||
SetBasicAuth(email, apiKeyResponse.APIKey).
|
||||
Get(baseURL + "/v2/_catalog")
|
||||
So(err, ShouldBeNil)
|
||||
So(resp, ShouldNotBeNil)
|
||||
So(resp.StatusCode(), ShouldEqual, http.StatusInternalServerError)
|
||||
|
||||
ctlr.RepoDB = mocks.RepoDBMock{
|
||||
GetUserAPIKeyInfoFn: func(hashedKey string) (string, error) {
|
||||
return user.Email, nil
|
||||
},
|
||||
GetUserGroupsFn: func(ctx context.Context) ([]string, error) {
|
||||
return []string{}, ErrUnexpectedError
|
||||
},
|
||||
}
|
||||
|
||||
resp, err = client.R().
|
||||
SetBasicAuth(email, apiKeyResponse.APIKey).
|
||||
Get(baseURL + "/v2/_catalog")
|
||||
So(err, ShouldBeNil)
|
||||
So(resp, ShouldNotBeNil)
|
||||
So(resp.StatusCode(), ShouldEqual, http.StatusInternalServerError)
|
||||
|
||||
ctlr.RepoDB = mocks.RepoDBMock{
|
||||
GetUserAPIKeyInfoFn: func(hashedKey string) (string, error) {
|
||||
return user.Email, nil
|
||||
},
|
||||
UpdateUserAPIKeyLastUsedFn: func(ctx context.Context, hashedKey string) error {
|
||||
return ErrUnexpectedError
|
||||
},
|
||||
}
|
||||
|
||||
resp, err = client.R().
|
||||
SetBasicAuth(email, apiKeyResponse.APIKey).
|
||||
Get(baseURL + "/v2/_catalog")
|
||||
So(err, ShouldBeNil)
|
||||
So(resp, ShouldNotBeNil)
|
||||
So(resp.StatusCode(), ShouldEqual, http.StatusInternalServerError)
|
||||
|
||||
client = resty.New()
|
||||
|
||||
// call endpoint without session
|
||||
resp, err = client.R().
|
||||
SetBody(reqBody).
|
||||
SetHeader(constants.SessionClientHeaderName, constants.SessionClientHeaderValue).
|
||||
Post(baseURL + constants.FullAPIKeyPrefix)
|
||||
So(err, ShouldBeNil)
|
||||
So(resp, ShouldNotBeNil)
|
||||
So(resp.StatusCode(), ShouldEqual, http.StatusUnauthorized)
|
||||
})
|
||||
|
||||
Convey("Login with openid and create API key", func() {
|
||||
client := resty.New()
|
||||
|
||||
// mgmt should work both unauthenticated and authenticated
|
||||
resp, err := client.R().
|
||||
Get(baseURL + constants.FullMgmtPrefix)
|
||||
So(err, ShouldBeNil)
|
||||
So(resp, ShouldNotBeNil)
|
||||
So(resp.StatusCode(), ShouldEqual, http.StatusOK)
|
||||
|
||||
client.SetRedirectPolicy(test.CustomRedirectPolicy(20))
|
||||
// first login user
|
||||
resp, err = client.R().
|
||||
SetHeader(constants.SessionClientHeaderName, constants.SessionClientHeaderValue).
|
||||
SetQueryParam("provider", "dex").
|
||||
Get(baseURL + constants.LoginPath)
|
||||
So(err, ShouldBeNil)
|
||||
So(resp, ShouldNotBeNil)
|
||||
So(resp.StatusCode(), ShouldEqual, http.StatusCreated)
|
||||
|
||||
client.SetCookies(resp.Cookies())
|
||||
|
||||
// call endpoint with session ( added to client after previous request)
|
||||
resp, err = client.R().
|
||||
SetBody(reqBody).
|
||||
SetHeader(constants.SessionClientHeaderName, constants.SessionClientHeaderValue).
|
||||
Post(baseURL + constants.FullAPIKeyPrefix)
|
||||
So(err, ShouldBeNil)
|
||||
So(resp, ShouldNotBeNil)
|
||||
So(resp.StatusCode(), ShouldEqual, http.StatusCreated)
|
||||
|
||||
var apiKeyResponse apiKeyResponse
|
||||
err = json.Unmarshal(resp.Body(), &apiKeyResponse)
|
||||
So(err, ShouldBeNil)
|
||||
|
||||
user := mockoidc.DefaultUser()
|
||||
email := user.Email
|
||||
So(email, ShouldNotBeEmpty)
|
||||
|
||||
resp, err = client.R().
|
||||
SetBasicAuth(email, apiKeyResponse.APIKey).
|
||||
Get(baseURL + "/v2/_catalog")
|
||||
So(err, ShouldBeNil)
|
||||
So(resp, ShouldNotBeNil)
|
||||
So(resp.StatusCode(), ShouldEqual, http.StatusOK)
|
||||
|
||||
// auth with API key
|
||||
// we need new client without session cookie set
|
||||
client = resty.New()
|
||||
client.SetRedirectPolicy(test.CustomRedirectPolicy(20))
|
||||
|
||||
resp, err = client.R().
|
||||
SetBasicAuth(email, apiKeyResponse.APIKey).
|
||||
Get(baseURL + "/v2/_catalog")
|
||||
So(err, ShouldBeNil)
|
||||
So(resp, ShouldNotBeNil)
|
||||
So(resp.StatusCode(), ShouldEqual, http.StatusOK)
|
||||
|
||||
resp, err = client.R().
|
||||
SetBasicAuth(email, apiKeyResponse.APIKey).
|
||||
Get(baseURL + constants.FullMgmtPrefix)
|
||||
So(err, ShouldBeNil)
|
||||
So(resp, ShouldNotBeNil)
|
||||
So(resp.StatusCode(), ShouldEqual, http.StatusOK)
|
||||
|
||||
// invalid api keys
|
||||
resp, err = client.R().
|
||||
SetBasicAuth("invalidEmail", apiKeyResponse.APIKey).
|
||||
Get(baseURL + constants.FullMgmtPrefix)
|
||||
So(err, ShouldBeNil)
|
||||
So(resp, ShouldNotBeNil)
|
||||
So(resp.StatusCode(), ShouldEqual, http.StatusUnauthorized)
|
||||
|
||||
resp, err = client.R().
|
||||
SetBasicAuth(email, "noprefixAPIKey").
|
||||
Get(baseURL + "/v2/_catalog")
|
||||
So(err, ShouldBeNil)
|
||||
So(resp, ShouldNotBeNil)
|
||||
So(resp.StatusCode(), ShouldEqual, http.StatusUnauthorized)
|
||||
|
||||
resp, err = client.R().
|
||||
SetBasicAuth(email, "zak_notworkingAPIKey").
|
||||
Get(baseURL + "/v2/_catalog")
|
||||
So(err, ShouldBeNil)
|
||||
So(resp, ShouldNotBeNil)
|
||||
So(resp.StatusCode(), ShouldEqual, http.StatusUnauthorized)
|
||||
|
||||
authzCtxKey := localCtx.GetContextKey()
|
||||
|
||||
acCtx := localCtx.AccessControlContext{
|
||||
Username: email,
|
||||
}
|
||||
|
||||
ctx := context.WithValue(context.Background(), authzCtxKey, acCtx)
|
||||
|
||||
err = ctlr.RepoDB.DeleteUserData(ctx)
|
||||
So(err, ShouldBeNil)
|
||||
|
||||
resp, err = client.R().
|
||||
SetBasicAuth(email, apiKeyResponse.APIKey).
|
||||
Get(baseURL + constants.FullMgmtPrefix)
|
||||
So(err, ShouldBeNil)
|
||||
So(resp, ShouldNotBeNil)
|
||||
So(resp.StatusCode(), ShouldEqual, http.StatusInternalServerError)
|
||||
|
||||
client = resty.New()
|
||||
client.SetRedirectPolicy(test.CustomRedirectPolicy(20))
|
||||
|
||||
// without creds should work
|
||||
resp, err = client.R().
|
||||
Get(baseURL + constants.FullMgmtPrefix)
|
||||
So(err, ShouldBeNil)
|
||||
So(resp, ShouldNotBeNil)
|
||||
So(resp.StatusCode(), ShouldEqual, http.StatusOK)
|
||||
|
||||
// login again
|
||||
resp, err = client.R().
|
||||
SetHeader(constants.SessionClientHeaderName, constants.SessionClientHeaderValue).
|
||||
SetQueryParam("provider", "dex").
|
||||
Get(baseURL + constants.LoginPath)
|
||||
So(err, ShouldBeNil)
|
||||
So(resp, ShouldNotBeNil)
|
||||
So(resp.StatusCode(), ShouldEqual, http.StatusCreated)
|
||||
|
||||
client.SetCookies(resp.Cookies())
|
||||
|
||||
resp, err = client.R().
|
||||
SetBody(reqBody).
|
||||
SetHeader(constants.SessionClientHeaderName, constants.SessionClientHeaderValue).
|
||||
Post(baseURL + constants.FullAPIKeyPrefix)
|
||||
So(err, ShouldBeNil)
|
||||
So(resp, ShouldNotBeNil)
|
||||
So(resp.StatusCode(), ShouldEqual, http.StatusCreated)
|
||||
|
||||
err = json.Unmarshal(resp.Body(), &apiKeyResponse)
|
||||
So(err, ShouldBeNil)
|
||||
|
||||
// should work with session
|
||||
resp, err = client.R().
|
||||
SetHeader(constants.SessionClientHeaderName, constants.SessionClientHeaderValue).
|
||||
Get(baseURL + constants.FullMgmtPrefix)
|
||||
So(err, ShouldBeNil)
|
||||
So(resp, ShouldNotBeNil)
|
||||
So(resp.StatusCode(), ShouldEqual, http.StatusOK)
|
||||
|
||||
// should work with api key
|
||||
resp, err = client.R().
|
||||
SetBasicAuth(email, apiKeyResponse.APIKey).
|
||||
Get(baseURL + constants.FullMgmtPrefix)
|
||||
So(err, ShouldBeNil)
|
||||
So(resp, ShouldNotBeNil)
|
||||
So(resp.StatusCode(), ShouldEqual, http.StatusOK)
|
||||
|
||||
resp, err = client.R().
|
||||
SetBasicAuth(email, apiKeyResponse.APIKey).
|
||||
Get(baseURL + "/v2/_catalog")
|
||||
So(err, ShouldBeNil)
|
||||
So(resp, ShouldNotBeNil)
|
||||
So(resp.StatusCode(), ShouldEqual, http.StatusOK)
|
||||
|
||||
err = json.Unmarshal(resp.Body(), &apiKeyResponse)
|
||||
So(err, ShouldBeNil)
|
||||
|
||||
// delete api key
|
||||
resp, err = client.R().
|
||||
SetHeader(constants.SessionClientHeaderName, constants.SessionClientHeaderValue).
|
||||
SetQueryParam("id", apiKeyResponse.UUID).
|
||||
Delete(baseURL + constants.FullAPIKeyPrefix)
|
||||
So(err, ShouldBeNil)
|
||||
So(resp, ShouldNotBeNil)
|
||||
So(resp.StatusCode(), ShouldEqual, http.StatusOK)
|
||||
|
||||
resp, err = client.R().
|
||||
SetHeader(constants.SessionClientHeaderName, constants.SessionClientHeaderValue).
|
||||
Delete(baseURL + constants.FullAPIKeyPrefix)
|
||||
So(err, ShouldBeNil)
|
||||
So(resp, ShouldNotBeNil)
|
||||
So(resp.StatusCode(), ShouldEqual, http.StatusBadRequest)
|
||||
|
||||
resp, err = client.R().
|
||||
SetBasicAuth(email, apiKeyResponse.APIKey).
|
||||
Get(baseURL + "/v2/_catalog")
|
||||
So(err, ShouldBeNil)
|
||||
So(resp, ShouldNotBeNil)
|
||||
So(resp.StatusCode(), ShouldEqual, http.StatusUnauthorized)
|
||||
|
||||
resp, err = client.R().
|
||||
SetBasicAuth("test", "test").
|
||||
SetQueryParam("id", apiKeyResponse.UUID).
|
||||
Delete(baseURL + constants.FullAPIKeyPrefix)
|
||||
So(err, ShouldBeNil)
|
||||
So(resp, ShouldNotBeNil)
|
||||
So(resp.StatusCode(), ShouldEqual, http.StatusOK)
|
||||
|
||||
// unsupported method
|
||||
resp, err = client.R().
|
||||
Put(baseURL + constants.FullAPIKeyPrefix)
|
||||
So(err, ShouldBeNil)
|
||||
So(resp, ShouldNotBeNil)
|
||||
So(resp.StatusCode(), ShouldEqual, http.StatusMethodNotAllowed)
|
||||
})
|
||||
})
|
||||
}
|
||||
|
||||
func TestAPIKeysOpenDBError(t *testing.T) {
|
||||
Convey("Test API keys - unable to create database", t, func() {
|
||||
conf := config.New()
|
||||
htpasswdPath := test.MakeHtpasswdFile()
|
||||
defer os.Remove(htpasswdPath)
|
||||
|
||||
mockOIDCServer, err := test.MockOIDCRun()
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
defer func() {
|
||||
err := mockOIDCServer.Shutdown()
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
}()
|
||||
|
||||
mockOIDCConfig := mockOIDCServer.Config()
|
||||
conf.HTTP.Auth = &config.AuthConfig{
|
||||
HTPasswd: config.AuthHTPasswd{
|
||||
Path: htpasswdPath,
|
||||
},
|
||||
|
||||
OpenID: &config.OpenIDConfig{
|
||||
Providers: map[string]config.OpenIDProviderConfig{
|
||||
"dex": {
|
||||
ClientID: mockOIDCConfig.ClientID,
|
||||
ClientSecret: mockOIDCConfig.ClientSecret,
|
||||
KeyPath: "",
|
||||
Issuer: mockOIDCConfig.Issuer,
|
||||
Scopes: []string{"openid", "email"},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
defaultVal := true
|
||||
apiKeyConfig := &extconf.APIKeyConfig{
|
||||
BaseConfig: extconf.BaseConfig{Enable: &defaultVal},
|
||||
}
|
||||
conf.Extensions = &extconf.ExtensionConfig{
|
||||
APIKey: apiKeyConfig,
|
||||
}
|
||||
|
||||
ctlr := api.NewController(conf)
|
||||
dir := t.TempDir()
|
||||
|
||||
err = os.Chmod(dir, 0o000)
|
||||
So(err, ShouldBeNil)
|
||||
|
||||
ctlr.Config.Storage.RootDirectory = dir
|
||||
cm := test.NewControllerManager(ctlr)
|
||||
|
||||
So(func() {
|
||||
cm.StartServer()
|
||||
}, ShouldPanic)
|
||||
})
|
||||
}
|
||||
@@ -36,12 +36,19 @@ type BearerConfig struct {
|
||||
Service string `json:"service,omitempty"`
|
||||
}
|
||||
|
||||
type OpenIDProviderConfig struct{}
|
||||
|
||||
type OpenIDConfig struct {
|
||||
Providers map[string]OpenIDProviderConfig `json:"providers,omitempty" mapstructure:"providers"`
|
||||
}
|
||||
|
||||
type Auth struct {
|
||||
HTPasswd *HTPasswd `json:"htpasswd,omitempty" mapstructure:"htpasswd"`
|
||||
Bearer *BearerConfig `json:"bearer,omitempty" mapstructure:"bearer"`
|
||||
LDAP *struct {
|
||||
Address string `json:"address,omitempty" mapstructure:"address"`
|
||||
} `json:"ldap,omitempty" mapstructure:"ldap"`
|
||||
OpenID *OpenIDConfig `json:"openid,omitempty" mapstructure:"openid"`
|
||||
}
|
||||
|
||||
type StrippedConfig struct {
|
||||
@@ -60,8 +67,10 @@ func (auth Auth) MarshalJSON() ([]byte, error) {
|
||||
type localAuth Auth
|
||||
|
||||
if auth.Bearer == nil && auth.LDAP == nil &&
|
||||
auth.HTPasswd.Path == "" {
|
||||
auth.HTPasswd.Path == "" &&
|
||||
(auth.OpenID == nil || len(auth.OpenID.Providers) == 0) {
|
||||
auth.HTPasswd = nil
|
||||
auth.OpenID = nil
|
||||
|
||||
return json.Marshal((localAuth)(auth))
|
||||
}
|
||||
@@ -72,6 +81,10 @@ func (auth Auth) MarshalJSON() ([]byte, error) {
|
||||
auth.HTPasswd.Path = ""
|
||||
}
|
||||
|
||||
if auth.OpenID != nil && len(auth.OpenID.Providers) == 0 {
|
||||
auth.OpenID = nil
|
||||
}
|
||||
|
||||
auth.LDAP = nil
|
||||
|
||||
return json.Marshal((localAuth)(auth))
|
||||
|
||||
@@ -39,6 +39,7 @@ func SetupUserPreferencesRoutes(config *config.Config, router *mux.Router, store
|
||||
userprefsRouter := router.PathPrefix(constants.ExtUserPreferences).Subrouter()
|
||||
userprefsRouter.Use(zcommon.ACHeadersHandler(allowedMethods...))
|
||||
userprefsRouter.Use(zcommon.AddExtensionSecurityHeaders())
|
||||
|
||||
userprefsRouter.HandleFunc("", HandleUserPrefs(repoDB, log)).Methods(allowedMethods...)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
//go:build sync || metrics || mgmt
|
||||
// +build sync metrics mgmt
|
||||
//go:build sync || metrics || mgmt || apikey
|
||||
// +build sync metrics mgmt apikey
|
||||
|
||||
package extensions_test
|
||||
|
||||
@@ -128,6 +128,20 @@ func TestMgmtExtension(t *testing.T) {
|
||||
|
||||
defaultValue := true
|
||||
|
||||
mockOIDCServer, err := test.MockOIDCRun()
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
defer func() {
|
||||
err := mockOIDCServer.Shutdown()
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
}()
|
||||
|
||||
mockOIDCConfig := mockOIDCServer.Config()
|
||||
|
||||
Convey("Verify mgmt route enabled with htpasswd", t, func() {
|
||||
htpasswdPath := test.MakeHtpasswdFile()
|
||||
conf.HTTP.Auth.HTPasswd.Path = htpasswdPath
|
||||
@@ -145,7 +159,7 @@ func TestMgmtExtension(t *testing.T) {
|
||||
ctlr := api.NewController(conf)
|
||||
|
||||
subPaths := make(map[string]config.StorageConfig)
|
||||
subPaths["/a"] = config.StorageConfig{}
|
||||
subPaths["/a"] = config.StorageConfig{RootDirectory: t.TempDir()}
|
||||
|
||||
ctlr.Config.Storage.RootDirectory = globalDir
|
||||
ctlr.Config.Storage.SubPaths = subPaths
|
||||
@@ -158,6 +172,13 @@ func TestMgmtExtension(t *testing.T) {
|
||||
|
||||
So(string(data), ShouldContainSubstring, "setting up mgmt routes")
|
||||
|
||||
Convey("unsupported http method call", func() {
|
||||
// without credentials
|
||||
resp, err := resty.R().Patch(baseURL + constants.FullMgmtPrefix)
|
||||
So(err, ShouldBeNil)
|
||||
So(resp.StatusCode(), ShouldEqual, http.StatusMethodNotAllowed)
|
||||
})
|
||||
|
||||
// without credentials
|
||||
resp, err := resty.R().Get(baseURL + constants.FullMgmtPrefix)
|
||||
So(err, ShouldBeNil)
|
||||
@@ -210,9 +231,9 @@ func TestMgmtExtension(t *testing.T) {
|
||||
ctlr := api.NewController(conf)
|
||||
|
||||
subPaths := make(map[string]config.StorageConfig)
|
||||
subPaths["/a"] = config.StorageConfig{}
|
||||
subPaths["/a"] = config.StorageConfig{RootDirectory: t.TempDir()}
|
||||
|
||||
ctlr.Config.Storage.RootDirectory = globalDir
|
||||
ctlr.Config.Storage.RootDirectory = t.TempDir()
|
||||
ctlr.Config.Storage.SubPaths = subPaths
|
||||
|
||||
ctlrManager := test.NewControllerManager(ctlr)
|
||||
@@ -259,9 +280,9 @@ func TestMgmtExtension(t *testing.T) {
|
||||
ctlr := api.NewController(conf)
|
||||
|
||||
subPaths := make(map[string]config.StorageConfig)
|
||||
subPaths["/a"] = config.StorageConfig{}
|
||||
subPaths["/a"] = config.StorageConfig{RootDirectory: t.TempDir()}
|
||||
|
||||
ctlr.Config.Storage.RootDirectory = globalDir
|
||||
ctlr.Config.Storage.RootDirectory = t.TempDir()
|
||||
ctlr.Config.Storage.SubPaths = subPaths
|
||||
|
||||
ctlrManager := test.NewControllerManager(ctlr)
|
||||
@@ -325,11 +346,7 @@ func TestMgmtExtension(t *testing.T) {
|
||||
|
||||
ctlr := api.NewController(conf)
|
||||
|
||||
subPaths := make(map[string]config.StorageConfig)
|
||||
subPaths["/a"] = config.StorageConfig{}
|
||||
|
||||
ctlr.Config.Storage.RootDirectory = globalDir
|
||||
ctlr.Config.Storage.SubPaths = subPaths
|
||||
ctlr.Config.Storage.RootDirectory = t.TempDir()
|
||||
|
||||
ctlrManager := test.NewControllerManager(ctlr)
|
||||
ctlrManager.StartAndWait(port)
|
||||
@@ -396,9 +413,9 @@ func TestMgmtExtension(t *testing.T) {
|
||||
ctlr := api.NewController(conf)
|
||||
|
||||
subPaths := make(map[string]config.StorageConfig)
|
||||
subPaths["/a"] = config.StorageConfig{}
|
||||
subPaths["/a"] = config.StorageConfig{RootDirectory: t.TempDir()}
|
||||
|
||||
ctlr.Config.Storage.RootDirectory = globalDir
|
||||
ctlr.Config.Storage.RootDirectory = t.TempDir()
|
||||
ctlr.Config.Storage.SubPaths = subPaths
|
||||
|
||||
ctlrManager := test.NewControllerManager(ctlr)
|
||||
@@ -445,11 +462,7 @@ func TestMgmtExtension(t *testing.T) {
|
||||
|
||||
ctlr := api.NewController(conf)
|
||||
|
||||
subPaths := make(map[string]config.StorageConfig)
|
||||
subPaths["/a"] = config.StorageConfig{}
|
||||
|
||||
ctlr.Config.Storage.RootDirectory = globalDir
|
||||
ctlr.Config.Storage.SubPaths = subPaths
|
||||
ctlr.Config.Storage.RootDirectory = t.TempDir()
|
||||
|
||||
ctlrManager := test.NewControllerManager(ctlr)
|
||||
ctlrManager.StartAndWait(port)
|
||||
@@ -474,6 +487,110 @@ func TestMgmtExtension(t *testing.T) {
|
||||
So(mgmtResp.HTTP.Auth.Bearer.Service, ShouldEqual, "service")
|
||||
})
|
||||
|
||||
Convey("Verify mgmt route enabled with openID", t, func() {
|
||||
conf.HTTP.Auth.HTPasswd.Path = ""
|
||||
conf.HTTP.Auth.LDAP = nil
|
||||
conf.HTTP.Auth.Bearer = nil
|
||||
|
||||
openIDProviders := make(map[string]config.OpenIDProviderConfig)
|
||||
openIDProviders["dex"] = config.OpenIDProviderConfig{
|
||||
ClientID: mockOIDCConfig.ClientID,
|
||||
ClientSecret: mockOIDCConfig.ClientSecret,
|
||||
Issuer: mockOIDCConfig.Issuer,
|
||||
}
|
||||
|
||||
conf.HTTP.Auth.OpenID = &config.OpenIDConfig{
|
||||
Providers: openIDProviders,
|
||||
}
|
||||
|
||||
conf.Extensions = &extconf.ExtensionConfig{}
|
||||
conf.Extensions.Mgmt = &extconf.MgmtConfig{
|
||||
BaseConfig: extconf.BaseConfig{
|
||||
Enable: &defaultValue,
|
||||
},
|
||||
}
|
||||
|
||||
conf.Log.Output = logFile.Name()
|
||||
defer os.Remove(logFile.Name()) // cleanup
|
||||
|
||||
ctlr := api.NewController(conf)
|
||||
|
||||
ctlr.Config.Storage.RootDirectory = t.TempDir()
|
||||
|
||||
ctlrManager := test.NewControllerManager(ctlr)
|
||||
ctlrManager.StartAndWait(port)
|
||||
defer ctlrManager.StopServer()
|
||||
|
||||
data, _ := os.ReadFile(logFile.Name())
|
||||
|
||||
So(string(data), ShouldContainSubstring, "setting up mgmt routes")
|
||||
|
||||
// without credentials
|
||||
resp, err := resty.R().Get(baseURL + constants.FullMgmtPrefix)
|
||||
So(err, ShouldBeNil)
|
||||
So(resp.StatusCode(), ShouldEqual, http.StatusOK)
|
||||
|
||||
mgmtResp := extensions.StrippedConfig{}
|
||||
err = json.Unmarshal(resp.Body(), &mgmtResp)
|
||||
t.Logf("resp: %v", mgmtResp.HTTP.Auth.OpenID)
|
||||
So(err, ShouldBeNil)
|
||||
So(mgmtResp.HTTP.Auth.HTPasswd, ShouldBeNil)
|
||||
So(mgmtResp.HTTP.Auth.LDAP, ShouldBeNil)
|
||||
So(mgmtResp.HTTP.Auth.Bearer, ShouldBeNil)
|
||||
So(mgmtResp.HTTP.Auth.OpenID, ShouldNotBeNil)
|
||||
So(mgmtResp.HTTP.Auth.OpenID.Providers, ShouldNotBeEmpty)
|
||||
})
|
||||
|
||||
Convey("Verify mgmt route enabled with empty openID provider list", t, func() {
|
||||
htpasswdPath := test.MakeHtpasswdFile()
|
||||
|
||||
conf.HTTP.Auth.HTPasswd.Path = htpasswdPath
|
||||
conf.HTTP.Auth.LDAP = nil
|
||||
conf.HTTP.Auth.Bearer = nil
|
||||
|
||||
openIDProviders := make(map[string]config.OpenIDProviderConfig)
|
||||
|
||||
conf.HTTP.Auth.OpenID = &config.OpenIDConfig{
|
||||
Providers: openIDProviders,
|
||||
}
|
||||
|
||||
conf.Extensions = &extconf.ExtensionConfig{}
|
||||
conf.Extensions.Mgmt = &extconf.MgmtConfig{
|
||||
BaseConfig: extconf.BaseConfig{
|
||||
Enable: &defaultValue,
|
||||
},
|
||||
}
|
||||
|
||||
conf.Log.Output = logFile.Name()
|
||||
defer os.Remove(logFile.Name()) // cleanup
|
||||
|
||||
ctlr := api.NewController(conf)
|
||||
|
||||
ctlr.Config.Storage.RootDirectory = t.TempDir()
|
||||
|
||||
ctlrManager := test.NewControllerManager(ctlr)
|
||||
ctlrManager.StartAndWait(port)
|
||||
defer ctlrManager.StopServer()
|
||||
|
||||
data, _ := os.ReadFile(logFile.Name())
|
||||
|
||||
So(string(data), ShouldContainSubstring, "setting up mgmt routes")
|
||||
|
||||
// without credentials
|
||||
resp, err := resty.R().Get(baseURL + constants.FullMgmtPrefix)
|
||||
So(err, ShouldBeNil)
|
||||
So(resp.StatusCode(), ShouldEqual, http.StatusOK)
|
||||
|
||||
mgmtResp := extensions.StrippedConfig{}
|
||||
err = json.Unmarshal(resp.Body(), &mgmtResp)
|
||||
t.Logf("resp: %v", mgmtResp.HTTP.Auth.OpenID)
|
||||
So(err, ShouldBeNil)
|
||||
So(mgmtResp.HTTP.Auth.HTPasswd, ShouldNotBeNil)
|
||||
So(mgmtResp.HTTP.Auth.LDAP, ShouldBeNil)
|
||||
So(mgmtResp.HTTP.Auth.Bearer, ShouldBeNil)
|
||||
So(mgmtResp.HTTP.Auth.OpenID, ShouldBeNil)
|
||||
})
|
||||
|
||||
Convey("Verify mgmt route enabled without any auth", t, func() {
|
||||
globalDir := t.TempDir()
|
||||
conf := config.New()
|
||||
@@ -499,11 +616,7 @@ func TestMgmtExtension(t *testing.T) {
|
||||
|
||||
ctlr := api.NewController(conf)
|
||||
|
||||
subPaths := make(map[string]config.StorageConfig)
|
||||
subPaths["/a"] = config.StorageConfig{}
|
||||
|
||||
ctlr.Config.Storage.RootDirectory = globalDir
|
||||
ctlr.Config.Storage.SubPaths = subPaths
|
||||
ctlr.Config.Storage.RootDirectory = t.TempDir()
|
||||
|
||||
ctlrManager := test.NewControllerManager(ctlr)
|
||||
ctlrManager.StartAndWait(port)
|
||||
@@ -856,3 +969,32 @@ func TestAllowedMethodsHeaderMgmt(t *testing.T) {
|
||||
So(resp.StatusCode(), ShouldEqual, http.StatusNoContent)
|
||||
})
|
||||
}
|
||||
|
||||
func TestAllowedMethodsHeaderAPIKey(t *testing.T) {
|
||||
defaultVal := true
|
||||
|
||||
Convey("Test http options response", t, func() {
|
||||
conf := config.New()
|
||||
port := test.GetFreePort()
|
||||
conf.HTTP.Port = port
|
||||
conf.Extensions = &extconf.ExtensionConfig{
|
||||
APIKey: &extconf.APIKeyConfig{
|
||||
BaseConfig: extconf.BaseConfig{Enable: &defaultVal},
|
||||
},
|
||||
}
|
||||
baseURL := test.GetBaseURL(port)
|
||||
|
||||
ctlr := api.NewController(conf)
|
||||
ctlr.Config.Storage.RootDirectory = t.TempDir()
|
||||
|
||||
ctrlManager := test.NewControllerManager(ctlr)
|
||||
|
||||
ctrlManager.StartAndWait(port)
|
||||
defer ctrlManager.StopServer()
|
||||
|
||||
resp, _ := resty.R().Options(baseURL + constants.FullAPIKeyPrefix)
|
||||
So(resp, ShouldNotBeNil)
|
||||
So(resp.Header().Get("Access-Control-Allow-Methods"), ShouldResemble, "POST,DELETE,OPTIONS")
|
||||
So(resp.StatusCode(), ShouldEqual, http.StatusNoContent)
|
||||
})
|
||||
}
|
||||
|
||||
@@ -7,6 +7,5 @@ const (
|
||||
RepoMetadataBucket = "RepoMetadata"
|
||||
UserDataBucket = "UserData"
|
||||
VersionBucket = "Version"
|
||||
StarredReposKey = "StarredReposKey"
|
||||
BookmarkedReposKey = "BookmarkedReposKey"
|
||||
UserAPIKeysBucket = "UserAPIKeys"
|
||||
)
|
||||
|
||||
@@ -10,7 +10,7 @@ import (
|
||||
|
||||
type DBDriverParameters struct {
|
||||
Endpoint, Region, RepoMetaTablename, ManifestDataTablename, IndexDataTablename,
|
||||
VersionTablename, UserDataTablename string
|
||||
UserDataTablename, APIKeyTablename, VersionTablename string
|
||||
}
|
||||
|
||||
func GetDynamoClient(params DBDriverParameters) (*dynamodb.Client, error) {
|
||||
|
||||
@@ -3,6 +3,7 @@ package bolt
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"strings"
|
||||
"time"
|
||||
@@ -60,6 +61,11 @@ func NewBoltDBWrapper(boltDB *bbolt.DB, log log.Logger) (*DBWrapper, error) {
|
||||
return err
|
||||
}
|
||||
|
||||
_, err = transaction.CreateBucketIfNotExists([]byte(bolt.UserAPIKeysBucket))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
})
|
||||
if err != nil {
|
||||
@@ -1680,39 +1686,26 @@ func (bdw *DBWrapper) ToggleStarRepo(ctx context.Context, repo string) (repodb.T
|
||||
var res repodb.ToggleState
|
||||
|
||||
if err := bdw.DB.Update(func(tx *bbolt.Tx) error { //nolint:varnamelen
|
||||
userdb := tx.Bucket([]byte(bolt.UserDataBucket))
|
||||
userBucket, err := userdb.CreateBucketIfNotExists([]byte(userid))
|
||||
if err != nil {
|
||||
// this is a serious failure
|
||||
return zerr.ErrUnableToCreateUserBucket
|
||||
var userData repodb.UserData
|
||||
|
||||
err := bdw.getUserData(userid, tx, &userData)
|
||||
if err != nil && !errors.Is(err, zerr.ErrUserDataNotFound) {
|
||||
return err
|
||||
}
|
||||
|
||||
mdata := userBucket.Get([]byte(bolt.StarredReposKey))
|
||||
unpacked := []string{}
|
||||
if mdata != nil {
|
||||
if err = json.Unmarshal(mdata, &unpacked); err != nil {
|
||||
return zerr.ErrInvalidOldUserStarredRepos
|
||||
}
|
||||
}
|
||||
|
||||
isRepoStarred := zcommon.Contains(unpacked, repo)
|
||||
isRepoStarred := zcommon.Contains(userData.StarredRepos, repo)
|
||||
|
||||
if isRepoStarred {
|
||||
res = repodb.Removed
|
||||
unpacked = zcommon.RemoveFrom(unpacked, repo)
|
||||
userData.StarredRepos = zcommon.RemoveFrom(userData.StarredRepos, repo)
|
||||
} else {
|
||||
res = repodb.Added
|
||||
unpacked = append(unpacked, repo)
|
||||
userData.StarredRepos = append(userData.StarredRepos, repo)
|
||||
}
|
||||
|
||||
var repacked []byte
|
||||
if repacked, err = json.Marshal(unpacked); err != nil {
|
||||
return zerr.ErrCouldNotMarshalStarredRepos
|
||||
}
|
||||
|
||||
err = userBucket.Put([]byte(bolt.StarredReposKey), repacked)
|
||||
err = bdw.setUserData(userid, tx, userData)
|
||||
if err != nil {
|
||||
return zerr.ErrCouldNotPersistData
|
||||
return err
|
||||
}
|
||||
|
||||
repoBuck := tx.Bucket([]byte(bolt.RepoMetadataBucket))
|
||||
@@ -1755,46 +1748,12 @@ func (bdw *DBWrapper) ToggleStarRepo(ctx context.Context, repo string) (repodb.T
|
||||
}
|
||||
|
||||
func (bdw *DBWrapper) GetStarredRepos(ctx context.Context) ([]string, error) {
|
||||
starredRepos := make([]string, 0)
|
||||
|
||||
acCtx, err := localCtx.GetAccessControlContext(ctx)
|
||||
if err != nil {
|
||||
return starredRepos, err
|
||||
userData, err := bdw.GetUserData(ctx)
|
||||
if errors.Is(err, zerr.ErrUserDataNotFound) || errors.Is(err, zerr.ErrUserDataNotAllowed) {
|
||||
return []string{}, nil
|
||||
}
|
||||
|
||||
userid := localCtx.GetUsernameFromContext(acCtx)
|
||||
|
||||
err = bdw.DB.View(func(tx *bbolt.Tx) error { //nolint:dupl
|
||||
if userid == "" {
|
||||
return nil
|
||||
}
|
||||
|
||||
userdb := tx.Bucket([]byte(bolt.UserDataBucket))
|
||||
userBucket := userdb.Bucket([]byte(userid))
|
||||
|
||||
if userBucket == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
mdata := userBucket.Get([]byte(bolt.StarredReposKey))
|
||||
if mdata == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
if err := json.Unmarshal(mdata, &starredRepos); err != nil {
|
||||
bdw.Log.Info().Str("user", userid).Err(err).Msg("unmarshal error")
|
||||
|
||||
return zerr.ErrInvalidOldUserStarredRepos
|
||||
}
|
||||
|
||||
if starredRepos == nil {
|
||||
starredRepos = make([]string, 0)
|
||||
}
|
||||
|
||||
return nil
|
||||
})
|
||||
|
||||
return starredRepos, err
|
||||
return userData.StarredRepos, err
|
||||
}
|
||||
|
||||
func (bdw *DBWrapper) ToggleBookmarkRepo(ctx context.Context, repo string) (repodb.ToggleState, error) {
|
||||
@@ -1815,43 +1774,25 @@ func (bdw *DBWrapper) ToggleBookmarkRepo(ctx context.Context, repo string) (repo
|
||||
|
||||
var res repodb.ToggleState
|
||||
|
||||
if err := bdw.DB.Update(func(tx *bbolt.Tx) error { //nolint:dupl
|
||||
userdb := tx.Bucket([]byte(bolt.UserDataBucket))
|
||||
userBucket, err := userdb.CreateBucketIfNotExists([]byte(userid))
|
||||
if err != nil {
|
||||
// this is a serious failure
|
||||
return zerr.ErrUnableToCreateUserBucket
|
||||
if err := bdw.DB.Update(func(transaction *bbolt.Tx) error { //nolint:dupl
|
||||
var userData repodb.UserData
|
||||
|
||||
err := bdw.getUserData(userid, transaction, &userData)
|
||||
if err != nil && !errors.Is(err, zerr.ErrUserDataNotFound) {
|
||||
return err
|
||||
}
|
||||
|
||||
mdata := userBucket.Get([]byte(bolt.BookmarkedReposKey))
|
||||
unpacked := []string{}
|
||||
if mdata != nil {
|
||||
if err = json.Unmarshal(mdata, &unpacked); err != nil {
|
||||
return zerr.ErrInvalidOldUserBookmarkedRepos
|
||||
}
|
||||
}
|
||||
|
||||
isRepoBookmarked := zcommon.Contains(unpacked, repo)
|
||||
isRepoBookmarked := zcommon.Contains(userData.BookmarkedRepos, repo)
|
||||
|
||||
if isRepoBookmarked {
|
||||
res = repodb.Removed
|
||||
unpacked = zcommon.RemoveFrom(unpacked, repo)
|
||||
userData.BookmarkedRepos = zcommon.RemoveFrom(userData.BookmarkedRepos, repo)
|
||||
} else {
|
||||
res = repodb.Added
|
||||
unpacked = append(unpacked, repo)
|
||||
userData.BookmarkedRepos = append(userData.BookmarkedRepos, repo)
|
||||
}
|
||||
|
||||
var repacked []byte
|
||||
if repacked, err = json.Marshal(unpacked); err != nil {
|
||||
return zerr.ErrCouldNotMarshalBookmarkedRepos
|
||||
}
|
||||
|
||||
err = userBucket.Put([]byte(bolt.BookmarkedReposKey), repacked)
|
||||
if err != nil {
|
||||
return zerr.ErrUnableToCreateUserBucket
|
||||
}
|
||||
|
||||
return nil
|
||||
return bdw.setUserData(userid, transaction, userData)
|
||||
}); err != nil {
|
||||
return repodb.NotChanged, err
|
||||
}
|
||||
@@ -1860,46 +1801,12 @@ func (bdw *DBWrapper) ToggleBookmarkRepo(ctx context.Context, repo string) (repo
|
||||
}
|
||||
|
||||
func (bdw *DBWrapper) GetBookmarkedRepos(ctx context.Context) ([]string, error) {
|
||||
bookmarkedRepos := []string{}
|
||||
|
||||
acCtx, err := localCtx.GetAccessControlContext(ctx)
|
||||
if err != nil {
|
||||
return bookmarkedRepos, err
|
||||
userData, err := bdw.GetUserData(ctx)
|
||||
if errors.Is(err, zerr.ErrUserDataNotFound) || errors.Is(err, zerr.ErrUserDataNotAllowed) {
|
||||
return []string{}, nil
|
||||
}
|
||||
|
||||
userid := localCtx.GetUsernameFromContext(acCtx)
|
||||
|
||||
err = bdw.DB.View(func(tx *bbolt.Tx) error { //nolint:dupl
|
||||
if userid == "" {
|
||||
return nil
|
||||
}
|
||||
|
||||
userdb := tx.Bucket([]byte(bolt.UserDataBucket))
|
||||
userBucket := userdb.Bucket([]byte(userid))
|
||||
|
||||
if userBucket == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
mdata := userBucket.Get([]byte(bolt.BookmarkedReposKey))
|
||||
if mdata == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
if err := json.Unmarshal(mdata, &bookmarkedRepos); err != nil {
|
||||
bdw.Log.Info().Str("user", userid).Err(err).Msg("unmarshal error")
|
||||
|
||||
return zerr.ErrInvalidOldUserBookmarkedRepos
|
||||
}
|
||||
|
||||
if bookmarkedRepos == nil {
|
||||
bookmarkedRepos = make([]string, 0)
|
||||
}
|
||||
|
||||
return nil
|
||||
})
|
||||
|
||||
return bookmarkedRepos, err
|
||||
return userData.BookmarkedRepos, err
|
||||
}
|
||||
|
||||
func (bdw *DBWrapper) PatchDB() error {
|
||||
@@ -1940,30 +1847,25 @@ func getUserStars(ctx context.Context, transaction *bbolt.Tx) []string {
|
||||
}
|
||||
|
||||
var (
|
||||
userid = localCtx.GetUsernameFromContext(acCtx)
|
||||
starredRepos = []string{}
|
||||
userdb = transaction.Bucket([]byte(bolt.UserDataBucket))
|
||||
userBucket = userdb.Bucket([]byte(userid))
|
||||
userData repodb.UserData
|
||||
userid = localCtx.GetUsernameFromContext(acCtx)
|
||||
userdb = transaction.Bucket([]byte(bolt.UserDataBucket))
|
||||
)
|
||||
|
||||
if userid == "" {
|
||||
if userid == "" || userdb == nil {
|
||||
return []string{}
|
||||
}
|
||||
|
||||
if userBucket == nil {
|
||||
return []string{}
|
||||
}
|
||||
|
||||
mdata := userBucket.Get([]byte(bolt.StarredReposKey))
|
||||
mdata := userdb.Get([]byte(userid))
|
||||
if mdata == nil {
|
||||
return []string{}
|
||||
}
|
||||
|
||||
if err := json.Unmarshal(mdata, &starredRepos); err != nil {
|
||||
if err := json.Unmarshal(mdata, &userData); err != nil {
|
||||
return []string{}
|
||||
}
|
||||
|
||||
return starredRepos
|
||||
return userData.StarredRepos
|
||||
}
|
||||
|
||||
func getUserBookmarks(ctx context.Context, transaction *bbolt.Tx) []string {
|
||||
@@ -1973,28 +1875,309 @@ func getUserBookmarks(ctx context.Context, transaction *bbolt.Tx) []string {
|
||||
}
|
||||
|
||||
var (
|
||||
userid = localCtx.GetUsernameFromContext(acCtx)
|
||||
bookmarkedRepos = []string{}
|
||||
userdb = transaction.Bucket([]byte(bolt.UserDataBucket))
|
||||
userBucket = userdb.Bucket([]byte(userid))
|
||||
userData repodb.UserData
|
||||
userid = localCtx.GetUsernameFromContext(acCtx)
|
||||
userdb = transaction.Bucket([]byte(bolt.UserDataBucket))
|
||||
)
|
||||
|
||||
if userid == "" {
|
||||
if userid == "" || userdb == nil {
|
||||
return []string{}
|
||||
}
|
||||
|
||||
if userBucket == nil {
|
||||
return []string{}
|
||||
}
|
||||
|
||||
mdata := userBucket.Get([]byte(bolt.BookmarkedReposKey))
|
||||
mdata := userdb.Get([]byte(userid))
|
||||
if mdata == nil {
|
||||
return []string{}
|
||||
}
|
||||
|
||||
if err := json.Unmarshal(mdata, &bookmarkedRepos); err != nil {
|
||||
if err := json.Unmarshal(mdata, &userData); err != nil {
|
||||
return []string{}
|
||||
}
|
||||
|
||||
return bookmarkedRepos
|
||||
return userData.BookmarkedRepos
|
||||
}
|
||||
|
||||
func (bdw *DBWrapper) SetUserGroups(ctx context.Context, groups []string) error {
|
||||
acCtx, err := localCtx.GetAccessControlContext(ctx)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
userid := localCtx.GetUsernameFromContext(acCtx)
|
||||
|
||||
if userid == "" {
|
||||
// empty user is anonymous
|
||||
return zerr.ErrUserDataNotAllowed
|
||||
}
|
||||
|
||||
err = bdw.DB.Update(func(tx *bbolt.Tx) error { //nolint:varnamelen
|
||||
var userData repodb.UserData
|
||||
|
||||
err := bdw.getUserData(userid, tx, &userData)
|
||||
if err != nil && !errors.Is(err, zerr.ErrUserDataNotFound) {
|
||||
return err
|
||||
}
|
||||
|
||||
userData.Groups = append(userData.Groups, groups...)
|
||||
|
||||
err = bdw.setUserData(userid, tx, userData)
|
||||
|
||||
return err
|
||||
})
|
||||
|
||||
return err
|
||||
}
|
||||
|
||||
func (bdw *DBWrapper) GetUserGroups(ctx context.Context) ([]string, error) {
|
||||
userData, err := bdw.GetUserData(ctx)
|
||||
|
||||
return userData.Groups, err
|
||||
}
|
||||
|
||||
func (bdw *DBWrapper) UpdateUserAPIKeyLastUsed(ctx context.Context, hashedKey string) error {
|
||||
acCtx, err := localCtx.GetAccessControlContext(ctx)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
userid := localCtx.GetUsernameFromContext(acCtx)
|
||||
|
||||
if userid == "" {
|
||||
// empty user is anonymous
|
||||
return zerr.ErrUserDataNotAllowed
|
||||
}
|
||||
|
||||
err = bdw.DB.Update(func(tx *bbolt.Tx) error { //nolint:varnamelen
|
||||
var userData repodb.UserData
|
||||
|
||||
err := bdw.getUserData(userid, tx, &userData)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
apiKeyDetails := userData.APIKeys[hashedKey]
|
||||
apiKeyDetails.LastUsed = time.Now()
|
||||
|
||||
userData.APIKeys[hashedKey] = apiKeyDetails
|
||||
|
||||
err = bdw.setUserData(userid, tx, userData)
|
||||
|
||||
return err
|
||||
})
|
||||
|
||||
return err
|
||||
}
|
||||
|
||||
func (bdw *DBWrapper) AddUserAPIKey(ctx context.Context, hashedKey string, apiKeyDetails *repodb.APIKeyDetails) error {
|
||||
acCtx, err := localCtx.GetAccessControlContext(ctx)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
userid := localCtx.GetUsernameFromContext(acCtx)
|
||||
if userid == "" {
|
||||
// empty user is anonymous
|
||||
return zerr.ErrUserDataNotAllowed
|
||||
}
|
||||
|
||||
err = bdw.DB.Update(func(transaction *bbolt.Tx) error {
|
||||
var userData repodb.UserData
|
||||
|
||||
apiKeysbuck := transaction.Bucket([]byte(bolt.UserAPIKeysBucket))
|
||||
if apiKeysbuck == nil {
|
||||
return zerr.ErrBucketDoesNotExist
|
||||
}
|
||||
|
||||
err := apiKeysbuck.Put([]byte(hashedKey), []byte(userid))
|
||||
if err != nil {
|
||||
return fmt.Errorf("repoDB: error while setting userData for identity %s %w", userid, err)
|
||||
}
|
||||
|
||||
err = bdw.getUserData(userid, transaction, &userData)
|
||||
if err != nil && !errors.Is(err, zerr.ErrUserDataNotFound) {
|
||||
return err
|
||||
}
|
||||
|
||||
if userData.APIKeys == nil {
|
||||
userData.APIKeys = make(map[string]repodb.APIKeyDetails)
|
||||
}
|
||||
|
||||
userData.APIKeys[hashedKey] = *apiKeyDetails
|
||||
|
||||
err = bdw.setUserData(userid, transaction, userData)
|
||||
|
||||
return err
|
||||
})
|
||||
|
||||
return err
|
||||
}
|
||||
|
||||
func (bdw *DBWrapper) DeleteUserAPIKey(ctx context.Context, keyID string) error {
|
||||
acCtx, err := localCtx.GetAccessControlContext(ctx)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
userid := localCtx.GetUsernameFromContext(acCtx)
|
||||
if userid == "" {
|
||||
// empty user is anonymous
|
||||
return zerr.ErrUserDataNotAllowed
|
||||
}
|
||||
|
||||
err = bdw.DB.Update(func(transaction *bbolt.Tx) error {
|
||||
var userData repodb.UserData
|
||||
|
||||
apiKeysbuck := transaction.Bucket([]byte(bolt.UserAPIKeysBucket))
|
||||
if apiKeysbuck == nil {
|
||||
return zerr.ErrBucketDoesNotExist
|
||||
}
|
||||
|
||||
err := bdw.getUserData(userid, transaction, &userData)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
for hash, apiKeyDetails := range userData.APIKeys {
|
||||
if apiKeyDetails.UUID == keyID {
|
||||
delete(userData.APIKeys, hash)
|
||||
|
||||
err := apiKeysbuck.Delete([]byte(hash))
|
||||
if err != nil {
|
||||
return fmt.Errorf("userDB: error while deleting userAPIKey entry for hash %s %w", hash, err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return bdw.setUserData(userid, transaction, userData)
|
||||
})
|
||||
|
||||
return err
|
||||
}
|
||||
|
||||
func (bdw *DBWrapper) GetUserAPIKeyInfo(hashedKey string) (string, error) {
|
||||
var userid string
|
||||
err := bdw.DB.View(func(tx *bbolt.Tx) error {
|
||||
buck := tx.Bucket([]byte(bolt.UserAPIKeysBucket))
|
||||
if buck == nil {
|
||||
return zerr.ErrBucketDoesNotExist
|
||||
}
|
||||
|
||||
uiBlob := buck.Get([]byte(hashedKey))
|
||||
if len(uiBlob) == 0 {
|
||||
return zerr.ErrUserAPIKeyNotFound
|
||||
}
|
||||
|
||||
userid = string(uiBlob)
|
||||
|
||||
return nil
|
||||
})
|
||||
|
||||
return userid, err
|
||||
}
|
||||
|
||||
func (bdw *DBWrapper) GetUserData(ctx context.Context) (repodb.UserData, error) {
|
||||
var userData repodb.UserData
|
||||
|
||||
acCtx, err := localCtx.GetAccessControlContext(ctx)
|
||||
if err != nil {
|
||||
return userData, err
|
||||
}
|
||||
|
||||
userid := localCtx.GetUsernameFromContext(acCtx)
|
||||
if userid == "" {
|
||||
// empty user is anonymous
|
||||
return userData, zerr.ErrUserDataNotAllowed
|
||||
}
|
||||
|
||||
err = bdw.DB.View(func(tx *bbolt.Tx) error {
|
||||
return bdw.getUserData(userid, tx, &userData)
|
||||
})
|
||||
|
||||
return userData, err
|
||||
}
|
||||
|
||||
func (bdw *DBWrapper) getUserData(userid string, transaction *bbolt.Tx, res *repodb.UserData) error {
|
||||
buck := transaction.Bucket([]byte(bolt.UserDataBucket))
|
||||
if buck == nil {
|
||||
return zerr.ErrBucketDoesNotExist
|
||||
}
|
||||
|
||||
upBlob := buck.Get([]byte(userid))
|
||||
|
||||
if len(upBlob) == 0 {
|
||||
return zerr.ErrUserDataNotFound
|
||||
}
|
||||
|
||||
err := json.Unmarshal(upBlob, res)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (bdw *DBWrapper) SetUserData(ctx context.Context, userData repodb.UserData) error {
|
||||
acCtx, err := localCtx.GetAccessControlContext(ctx)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
userid := localCtx.GetUsernameFromContext(acCtx)
|
||||
if userid == "" {
|
||||
// empty user is anonymous
|
||||
return zerr.ErrUserDataNotAllowed
|
||||
}
|
||||
|
||||
err = bdw.DB.Update(func(tx *bbolt.Tx) error {
|
||||
return bdw.setUserData(userid, tx, userData)
|
||||
})
|
||||
|
||||
return err
|
||||
}
|
||||
|
||||
func (bdw *DBWrapper) setUserData(userid string, transaction *bbolt.Tx, userData repodb.UserData) error {
|
||||
buck := transaction.Bucket([]byte(bolt.UserDataBucket))
|
||||
if buck == nil {
|
||||
return zerr.ErrBucketDoesNotExist
|
||||
}
|
||||
|
||||
upBlob, err := json.Marshal(userData)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
err = buck.Put([]byte(userid), upBlob)
|
||||
if err != nil {
|
||||
return fmt.Errorf("repoDB: error while setting userData for identity %s %w", userid, err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (bdw *DBWrapper) DeleteUserData(ctx context.Context) error {
|
||||
acCtx, err := localCtx.GetAccessControlContext(ctx)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
userid := localCtx.GetUsernameFromContext(acCtx)
|
||||
if userid == "" {
|
||||
// empty user is anonymous
|
||||
return zerr.ErrUserDataNotAllowed
|
||||
}
|
||||
|
||||
err = bdw.DB.Update(func(tx *bbolt.Tx) error {
|
||||
buck := tx.Bucket([]byte(bolt.UserDataBucket))
|
||||
if buck == nil {
|
||||
return zerr.ErrBucketDoesNotExist
|
||||
}
|
||||
|
||||
err := buck.Delete([]byte(userid))
|
||||
if err != nil {
|
||||
return fmt.Errorf("repoDB: error while deleting userData for identity %s %w", userid, err)
|
||||
}
|
||||
|
||||
return nil
|
||||
})
|
||||
|
||||
return err
|
||||
}
|
||||
|
||||
@@ -2,7 +2,10 @@ package bolt_test
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/rand"
|
||||
"encoding/base64"
|
||||
"encoding/json"
|
||||
"math"
|
||||
"testing"
|
||||
|
||||
"github.com/opencontainers/go-digest"
|
||||
@@ -10,6 +13,7 @@ import (
|
||||
. "github.com/smartystreets/goconvey/convey"
|
||||
"go.etcd.io/bbolt"
|
||||
|
||||
zerr "zotregistry.io/zot/errors"
|
||||
"zotregistry.io/zot/pkg/log"
|
||||
"zotregistry.io/zot/pkg/meta/bolt"
|
||||
"zotregistry.io/zot/pkg/meta/repodb"
|
||||
@@ -21,7 +25,6 @@ import (
|
||||
|
||||
func TestWrapperErrors(t *testing.T) {
|
||||
Convey("Errors", t, func() {
|
||||
ctx := context.Background()
|
||||
tmpDir := t.TempDir()
|
||||
boltDBParams := bolt.DBParameters{RootDir: tmpDir}
|
||||
boltDriver, err := bolt.GetBoltDriver(boltDBParams)
|
||||
@@ -41,6 +44,231 @@ func TestWrapperErrors(t *testing.T) {
|
||||
repoMetaBlob, err := json.Marshal(repoMeta)
|
||||
So(err, ShouldBeNil)
|
||||
|
||||
authzCtxKey := localCtx.GetContextKey()
|
||||
|
||||
acCtx := localCtx.AccessControlContext{
|
||||
Username: "test",
|
||||
}
|
||||
|
||||
ctx := context.WithValue(context.Background(), authzCtxKey, acCtx)
|
||||
|
||||
Convey("AddUserAPIKey", func() {
|
||||
Convey("no userid found", func() {
|
||||
acCtx := localCtx.AccessControlContext{
|
||||
Username: "",
|
||||
}
|
||||
|
||||
ctx := context.WithValue(context.Background(), authzCtxKey, acCtx)
|
||||
err = boltdbWrapper.AddUserAPIKey(ctx, "", &repodb.APIKeyDetails{})
|
||||
So(err, ShouldNotBeNil)
|
||||
})
|
||||
|
||||
err = boltdbWrapper.AddUserAPIKey(ctx, "", &repodb.APIKeyDetails{})
|
||||
So(err, ShouldNotBeNil)
|
||||
|
||||
err = boltdbWrapper.DB.Update(func(tx *bbolt.Tx) error {
|
||||
return tx.DeleteBucket([]byte(bolt.UserDataBucket))
|
||||
})
|
||||
So(err, ShouldBeNil)
|
||||
|
||||
err = boltdbWrapper.AddUserAPIKey(ctx, "test", &repodb.APIKeyDetails{})
|
||||
So(err, ShouldNotBeNil)
|
||||
|
||||
err = boltdbWrapper.DB.Update(func(tx *bbolt.Tx) error {
|
||||
return tx.DeleteBucket([]byte(bolt.UserAPIKeysBucket))
|
||||
})
|
||||
|
||||
So(err, ShouldBeNil)
|
||||
|
||||
err = boltdbWrapper.AddUserAPIKey(ctx, "", &repodb.APIKeyDetails{})
|
||||
So(err, ShouldEqual, zerr.ErrBucketDoesNotExist)
|
||||
})
|
||||
|
||||
Convey("UpdateUserAPIKey", func() {
|
||||
err = boltdbWrapper.UpdateUserAPIKeyLastUsed(ctx, "")
|
||||
So(err, ShouldNotBeNil)
|
||||
|
||||
acCtx := localCtx.AccessControlContext{
|
||||
Username: "",
|
||||
}
|
||||
|
||||
ctx := context.WithValue(context.Background(), authzCtxKey, acCtx)
|
||||
err = boltdbWrapper.UpdateUserAPIKeyLastUsed(ctx, "") //nolint: contextcheck
|
||||
So(err, ShouldNotBeNil)
|
||||
})
|
||||
|
||||
Convey("DeleteUserAPIKey", func() {
|
||||
err = boltdbWrapper.SetUserData(ctx, repodb.UserData{})
|
||||
So(err, ShouldBeNil)
|
||||
|
||||
err = boltdbWrapper.AddUserAPIKey(ctx, "hashedKey", &repodb.APIKeyDetails{})
|
||||
So(err, ShouldBeNil)
|
||||
|
||||
Convey("no such bucket", func() {
|
||||
err = boltdbWrapper.DB.Update(func(tx *bbolt.Tx) error {
|
||||
return tx.DeleteBucket([]byte(bolt.UserAPIKeysBucket))
|
||||
})
|
||||
So(err, ShouldBeNil)
|
||||
|
||||
authzCtxKey := localCtx.GetContextKey()
|
||||
|
||||
acCtx := localCtx.AccessControlContext{
|
||||
Username: "test",
|
||||
}
|
||||
|
||||
ctx := context.WithValue(context.Background(), authzCtxKey, acCtx)
|
||||
|
||||
err = boltdbWrapper.DeleteUserAPIKey(ctx, "")
|
||||
So(err, ShouldEqual, zerr.ErrBucketDoesNotExist)
|
||||
})
|
||||
|
||||
Convey("userdata not found", func() {
|
||||
authzCtxKey := localCtx.GetContextKey()
|
||||
acCtx := localCtx.AccessControlContext{
|
||||
Username: "test",
|
||||
}
|
||||
|
||||
ctx := context.WithValue(context.Background(), authzCtxKey, acCtx)
|
||||
err := boltdbWrapper.DeleteUserData(ctx)
|
||||
So(err, ShouldBeNil)
|
||||
|
||||
err = boltdbWrapper.DeleteUserAPIKey(ctx, "")
|
||||
So(err, ShouldNotBeNil)
|
||||
})
|
||||
|
||||
authzCtxKey := localCtx.GetContextKey()
|
||||
|
||||
acCtx := localCtx.AccessControlContext{
|
||||
Username: "",
|
||||
}
|
||||
|
||||
ctx := context.WithValue(context.Background(), authzCtxKey, acCtx) //nolint: contextcheck
|
||||
|
||||
err = boltdbWrapper.DeleteUserAPIKey(ctx, "test") //nolint: contextcheck
|
||||
So(err, ShouldNotBeNil)
|
||||
|
||||
err = boltdbWrapper.DB.Update(func(tx *bbolt.Tx) error {
|
||||
return tx.DeleteBucket([]byte(bolt.UserDataBucket))
|
||||
})
|
||||
So(err, ShouldBeNil)
|
||||
|
||||
err = boltdbWrapper.DeleteUserAPIKey(ctx, "") //nolint: contextcheck
|
||||
So(err, ShouldNotBeNil)
|
||||
})
|
||||
|
||||
Convey("GetUserAPIKeyInfo", func() {
|
||||
err = boltdbWrapper.DB.Update(func(tx *bbolt.Tx) error {
|
||||
return tx.DeleteBucket([]byte(bolt.UserAPIKeysBucket))
|
||||
})
|
||||
So(err, ShouldBeNil)
|
||||
|
||||
_, err = boltdbWrapper.GetUserAPIKeyInfo("")
|
||||
So(err, ShouldNotBeNil)
|
||||
})
|
||||
|
||||
Convey("GetUserData", func() {
|
||||
err = boltdbWrapper.DB.Update(func(tx *bbolt.Tx) error {
|
||||
buck := tx.Bucket([]byte(bolt.UserDataBucket))
|
||||
So(buck, ShouldNotBeNil)
|
||||
|
||||
return buck.Put([]byte("test"), []byte("dsa8"))
|
||||
})
|
||||
|
||||
So(err, ShouldBeNil)
|
||||
|
||||
_, err = boltdbWrapper.GetUserData(ctx)
|
||||
So(err, ShouldNotBeNil)
|
||||
|
||||
err = boltdbWrapper.DB.Update(func(tx *bbolt.Tx) error {
|
||||
return tx.DeleteBucket([]byte(bolt.UserAPIKeysBucket))
|
||||
})
|
||||
So(err, ShouldBeNil)
|
||||
|
||||
_, err = boltdbWrapper.GetUserData(ctx)
|
||||
So(err, ShouldNotBeNil)
|
||||
})
|
||||
|
||||
Convey("SetUserData", func() {
|
||||
acCtx = localCtx.AccessControlContext{
|
||||
Username: "",
|
||||
}
|
||||
|
||||
ctx = context.WithValue(context.Background(), authzCtxKey, acCtx)
|
||||
|
||||
err = boltdbWrapper.SetUserData(ctx, repodb.UserData{})
|
||||
So(err, ShouldNotBeNil)
|
||||
|
||||
buff := make([]byte, int(math.Ceil(float64(1000000)/float64(1.33333333333))))
|
||||
_, err := rand.Read(buff)
|
||||
So(err, ShouldBeNil)
|
||||
|
||||
longString := base64.RawURLEncoding.EncodeToString(buff)
|
||||
|
||||
authzCtxKey := localCtx.GetContextKey()
|
||||
|
||||
acCtx := localCtx.AccessControlContext{
|
||||
Username: longString,
|
||||
}
|
||||
|
||||
ctx := context.WithValue(context.Background(), authzCtxKey, acCtx)
|
||||
|
||||
err = boltdbWrapper.SetUserData(ctx, repodb.UserData{}) //nolint: contextcheck
|
||||
So(err, ShouldNotBeNil)
|
||||
|
||||
err = boltdbWrapper.DB.Update(func(tx *bbolt.Tx) error {
|
||||
return tx.DeleteBucket([]byte(bolt.UserDataBucket))
|
||||
})
|
||||
So(err, ShouldBeNil)
|
||||
|
||||
acCtx = localCtx.AccessControlContext{
|
||||
Username: "test",
|
||||
}
|
||||
|
||||
ctx = context.WithValue(context.Background(), authzCtxKey, acCtx)
|
||||
|
||||
err = boltdbWrapper.SetUserData(ctx, repodb.UserData{}) //nolint: contextcheck
|
||||
So(err, ShouldNotBeNil)
|
||||
})
|
||||
|
||||
Convey("DeleteUserData", func() {
|
||||
acCtx = localCtx.AccessControlContext{
|
||||
Username: "",
|
||||
}
|
||||
|
||||
ctx = context.WithValue(context.Background(), authzCtxKey, acCtx)
|
||||
|
||||
err = boltdbWrapper.DeleteUserData(ctx)
|
||||
So(err, ShouldNotBeNil)
|
||||
|
||||
err = boltdbWrapper.DB.Update(func(tx *bbolt.Tx) error {
|
||||
return tx.DeleteBucket([]byte(bolt.UserDataBucket))
|
||||
})
|
||||
So(err, ShouldBeNil)
|
||||
|
||||
acCtx = localCtx.AccessControlContext{
|
||||
Username: "test",
|
||||
}
|
||||
|
||||
ctx = context.WithValue(context.Background(), authzCtxKey, acCtx)
|
||||
|
||||
err = boltdbWrapper.DeleteUserData(ctx)
|
||||
So(err, ShouldNotBeNil)
|
||||
})
|
||||
|
||||
Convey("GetUserGroups and SetUserGroups", func() {
|
||||
acCtx = localCtx.AccessControlContext{
|
||||
Username: "",
|
||||
}
|
||||
|
||||
ctx = context.WithValue(context.Background(), authzCtxKey, acCtx)
|
||||
|
||||
_, err := boltdbWrapper.GetUserGroups(ctx)
|
||||
So(err, ShouldNotBeNil)
|
||||
|
||||
err = boltdbWrapper.SetUserGroups(ctx, []string{})
|
||||
So(err, ShouldNotBeNil)
|
||||
})
|
||||
|
||||
Convey("GetManifestData", func() {
|
||||
err := boltdbWrapper.DB.Update(func(tx *bbolt.Tx) error {
|
||||
dataBuck := tx.Bucket([]byte(bolt.ManifestDataBucket))
|
||||
@@ -732,60 +960,6 @@ func TestWrapperErrors(t *testing.T) {
|
||||
So(err, ShouldNotBeNil)
|
||||
})
|
||||
|
||||
Convey("ToggleStarRepo, getting StarredRepoKey from bucket fails", func() {
|
||||
acCtx := localCtx.AccessControlContext{
|
||||
ReadGlobPatterns: map[string]bool{
|
||||
"repo": true,
|
||||
},
|
||||
Username: "username",
|
||||
}
|
||||
authzCtxKey := localCtx.GetContextKey()
|
||||
ctx := context.WithValue(context.Background(), authzCtxKey, acCtx)
|
||||
|
||||
err := boltdbWrapper.DB.Update(func(tx *bbolt.Tx) error {
|
||||
userdb, err := tx.CreateBucketIfNotExists([]byte(bolt.UserDataBucket))
|
||||
So(err, ShouldBeNil)
|
||||
userBucket, err := userdb.CreateBucketIfNotExists([]byte(acCtx.Username))
|
||||
So(err, ShouldBeNil)
|
||||
|
||||
err = userBucket.Put([]byte(bolt.StarredReposKey), []byte("bad array"))
|
||||
So(err, ShouldBeNil)
|
||||
|
||||
return nil
|
||||
})
|
||||
So(err, ShouldBeNil)
|
||||
|
||||
_, err = boltdbWrapper.ToggleStarRepo(ctx, "repo")
|
||||
So(err, ShouldNotBeNil)
|
||||
})
|
||||
|
||||
Convey("ToggleBookmarkRepo, unmarshal error", func() {
|
||||
acCtx := localCtx.AccessControlContext{
|
||||
ReadGlobPatterns: map[string]bool{
|
||||
"repo": true,
|
||||
},
|
||||
Username: "username",
|
||||
}
|
||||
authzCtxKey := localCtx.GetContextKey()
|
||||
ctx := context.WithValue(context.Background(), authzCtxKey, acCtx)
|
||||
|
||||
err := boltdbWrapper.DB.Update(func(tx *bbolt.Tx) error {
|
||||
userdb, err := tx.CreateBucketIfNotExists([]byte(bolt.UserDataBucket))
|
||||
So(err, ShouldBeNil)
|
||||
userBucket, err := userdb.CreateBucketIfNotExists([]byte(acCtx.Username))
|
||||
So(err, ShouldBeNil)
|
||||
|
||||
err = userBucket.Put([]byte(bolt.BookmarkedReposKey), []byte("bad array"))
|
||||
So(err, ShouldBeNil)
|
||||
|
||||
return nil
|
||||
})
|
||||
So(err, ShouldBeNil)
|
||||
|
||||
_, err = boltdbWrapper.ToggleBookmarkRepo(ctx, "repo")
|
||||
So(err, ShouldNotBeNil)
|
||||
})
|
||||
|
||||
Convey("ToggleStarRepo, no repoMeta found", func() {
|
||||
acCtx := localCtx.AccessControlContext{
|
||||
ReadGlobPatterns: map[string]bool{
|
||||
@@ -832,6 +1006,73 @@ func TestWrapperErrors(t *testing.T) {
|
||||
So(err, ShouldNotBeNil)
|
||||
})
|
||||
|
||||
Convey("GetUserData bad context errors", func() {
|
||||
authzCtxKey := localCtx.GetContextKey()
|
||||
ctx := context.WithValue(context.Background(), authzCtxKey, "bad context")
|
||||
|
||||
_, err := boltdbWrapper.GetUserData(ctx)
|
||||
So(err, ShouldNotBeNil)
|
||||
})
|
||||
|
||||
Convey("SetUserData bad context errors", func() {
|
||||
authzCtxKey := localCtx.GetContextKey()
|
||||
ctx := context.WithValue(context.Background(), authzCtxKey, "bad context")
|
||||
|
||||
err := boltdbWrapper.SetUserData(ctx, repodb.UserData{})
|
||||
So(err, ShouldNotBeNil)
|
||||
})
|
||||
|
||||
Convey("GetUserGroups bad context errors", func() {
|
||||
_, err := boltdbWrapper.GetUserGroups(ctx)
|
||||
So(err, ShouldNotBeNil)
|
||||
|
||||
authzCtxKey := localCtx.GetContextKey()
|
||||
ctx := context.WithValue(context.Background(), authzCtxKey, "bad context")
|
||||
|
||||
_, err = boltdbWrapper.GetUserGroups(ctx) //nolint: contextcheck
|
||||
So(err, ShouldNotBeNil)
|
||||
})
|
||||
|
||||
Convey("SetUserGroups bad context errors", func() {
|
||||
authzCtxKey := localCtx.GetContextKey()
|
||||
ctx := context.WithValue(context.Background(), authzCtxKey, "bad context")
|
||||
|
||||
err := boltdbWrapper.SetUserGroups(ctx, []string{})
|
||||
So(err, ShouldNotBeNil)
|
||||
})
|
||||
|
||||
Convey("AddUserAPIKey bad context errors", func() {
|
||||
authzCtxKey := localCtx.GetContextKey()
|
||||
ctx := context.WithValue(context.Background(), authzCtxKey, "bad context")
|
||||
|
||||
err := boltdbWrapper.AddUserAPIKey(ctx, "", &repodb.APIKeyDetails{})
|
||||
So(err, ShouldNotBeNil)
|
||||
})
|
||||
|
||||
Convey("DeleteUserAPIKey bad context errors", func() {
|
||||
authzCtxKey := localCtx.GetContextKey()
|
||||
ctx := context.WithValue(context.Background(), authzCtxKey, "bad context")
|
||||
|
||||
err := boltdbWrapper.DeleteUserAPIKey(ctx, "")
|
||||
So(err, ShouldNotBeNil)
|
||||
})
|
||||
|
||||
Convey("UpdateUserAPIKeyLastUsed bad context errors", func() {
|
||||
authzCtxKey := localCtx.GetContextKey()
|
||||
ctx := context.WithValue(context.Background(), authzCtxKey, "bad context")
|
||||
|
||||
err := boltdbWrapper.UpdateUserAPIKeyLastUsed(ctx, "")
|
||||
So(err, ShouldNotBeNil)
|
||||
})
|
||||
|
||||
Convey("DeleteUserData bad context errors", func() {
|
||||
authzCtxKey := localCtx.GetContextKey()
|
||||
ctx := context.WithValue(context.Background(), authzCtxKey, "bad context")
|
||||
|
||||
err := boltdbWrapper.DeleteUserData(ctx)
|
||||
So(err, ShouldNotBeNil)
|
||||
})
|
||||
|
||||
Convey("GetStarredRepos bad context errors", func() {
|
||||
authzCtxKey := localCtx.GetContextKey()
|
||||
ctx := context.WithValue(context.Background(), authzCtxKey, "bad context")
|
||||
@@ -840,60 +1081,6 @@ func TestWrapperErrors(t *testing.T) {
|
||||
So(err, ShouldNotBeNil)
|
||||
})
|
||||
|
||||
Convey("GetStarredRepos user data unmarshal error", func() {
|
||||
acCtx := localCtx.AccessControlContext{
|
||||
ReadGlobPatterns: map[string]bool{
|
||||
"repo": true,
|
||||
},
|
||||
Username: "username",
|
||||
}
|
||||
authzCtxKey := localCtx.GetContextKey()
|
||||
ctx := context.WithValue(context.Background(), authzCtxKey, acCtx)
|
||||
|
||||
err := boltdbWrapper.DB.Update(func(tx *bbolt.Tx) error {
|
||||
userdb, err := tx.CreateBucketIfNotExists([]byte(bolt.UserDataBucket))
|
||||
So(err, ShouldBeNil)
|
||||
userBucket, err := userdb.CreateBucketIfNotExists([]byte(acCtx.Username))
|
||||
So(err, ShouldBeNil)
|
||||
|
||||
err = userBucket.Put([]byte(bolt.StarredReposKey), []byte("bad array"))
|
||||
So(err, ShouldBeNil)
|
||||
|
||||
return nil
|
||||
})
|
||||
So(err, ShouldBeNil)
|
||||
|
||||
_, err = boltdbWrapper.GetStarredRepos(ctx)
|
||||
So(err, ShouldNotBeNil)
|
||||
})
|
||||
|
||||
Convey("GetBookmarkedRepos user data unmarshal error", func() {
|
||||
acCtx := localCtx.AccessControlContext{
|
||||
ReadGlobPatterns: map[string]bool{
|
||||
"repo": true,
|
||||
},
|
||||
Username: "username",
|
||||
}
|
||||
authzCtxKey := localCtx.GetContextKey()
|
||||
ctx := context.WithValue(context.Background(), authzCtxKey, acCtx)
|
||||
|
||||
err := boltdbWrapper.DB.Update(func(tx *bbolt.Tx) error {
|
||||
userdb, err := tx.CreateBucketIfNotExists([]byte(bolt.UserDataBucket))
|
||||
So(err, ShouldBeNil)
|
||||
userBucket, err := userdb.CreateBucketIfNotExists([]byte(acCtx.Username))
|
||||
So(err, ShouldBeNil)
|
||||
|
||||
err = userBucket.Put([]byte(bolt.BookmarkedReposKey), []byte("bad array"))
|
||||
So(err, ShouldBeNil)
|
||||
|
||||
return nil
|
||||
})
|
||||
So(err, ShouldBeNil)
|
||||
|
||||
_, err = boltdbWrapper.GetBookmarkedRepos(ctx)
|
||||
So(err, ShouldNotBeNil)
|
||||
})
|
||||
|
||||
Convey("GetBookmarkedRepos bad context errors", func() {
|
||||
authzCtxKey := localCtx.GetContextKey()
|
||||
ctx := context.WithValue(context.Background(), authzCtxKey, "bad context")
|
||||
|
||||
@@ -31,6 +31,7 @@ func TestWrapperErrors(t *testing.T) {
|
||||
manifestDataTablename := "ManifestDataTable" + uuid.String()
|
||||
indexDataTablename := "IndexDataTable" + uuid.String()
|
||||
userDataTablename := "UserDataTable" + uuid.String()
|
||||
apiKeyTablename := "ApiKeyTable" + uuid.String()
|
||||
|
||||
versionTablename := "Version" + uuid.String()
|
||||
|
||||
@@ -58,6 +59,7 @@ func TestWrapperErrors(t *testing.T) {
|
||||
IndexDataTablename: indexDataTablename,
|
||||
VersionTablename: versionTablename,
|
||||
UserDataTablename: userDataTablename,
|
||||
APIKeyTablename: apiKeyTablename,
|
||||
Patches: version.GetDynamoDBPatches(),
|
||||
Log: log.Logger{Logger: zerolog.New(os.Stdout)},
|
||||
}
|
||||
@@ -74,6 +76,9 @@ func TestWrapperErrors(t *testing.T) {
|
||||
|
||||
err = dynamoWrapper.createVersionTable()
|
||||
So(err, ShouldNotBeNil)
|
||||
|
||||
err = dynamoWrapper.createAPIKeyTable()
|
||||
So(err, ShouldNotBeNil)
|
||||
})
|
||||
|
||||
Convey("Delete table errors", t, func() {
|
||||
|
||||
@@ -43,6 +43,7 @@ func TestIterator(t *testing.T) {
|
||||
versionTablename := "Version" + uuid.String()
|
||||
indexDataTablename := "IndexDataTable" + uuid.String()
|
||||
userDataTablename := "UserDataTable" + uuid.String()
|
||||
apiKeyTablename := "ApiKeyTable" + uuid.String()
|
||||
|
||||
log := log.NewLogger("debug", "")
|
||||
|
||||
@@ -54,6 +55,7 @@ func TestIterator(t *testing.T) {
|
||||
ManifestDataTablename: manifestDataTablename,
|
||||
IndexDataTablename: indexDataTablename,
|
||||
VersionTablename: versionTablename,
|
||||
APIKeyTablename: apiKeyTablename,
|
||||
UserDataTablename: userDataTablename,
|
||||
}
|
||||
client, err := dynamo.GetDynamoClient(params)
|
||||
@@ -144,8 +146,8 @@ func TestWrapperErrors(t *testing.T) {
|
||||
versionTablename := "Version" + uuid.String()
|
||||
indexDataTablename := "IndexDataTable" + uuid.String()
|
||||
userDataTablename := "UserDataTable" + uuid.String()
|
||||
|
||||
ctx := context.Background()
|
||||
apiKeyTablename := "ApiKeyTable" + uuid.String()
|
||||
wrongTableName := "WRONG Tables"
|
||||
|
||||
log := log.NewLogger("debug", "")
|
||||
|
||||
@@ -157,6 +159,7 @@ func TestWrapperErrors(t *testing.T) {
|
||||
ManifestDataTablename: manifestDataTablename,
|
||||
IndexDataTablename: indexDataTablename,
|
||||
UserDataTablename: userDataTablename,
|
||||
APIKeyTablename: apiKeyTablename,
|
||||
VersionTablename: versionTablename,
|
||||
}
|
||||
client, err := dynamo.GetDynamoClient(params) //nolint:contextcheck
|
||||
@@ -168,6 +171,61 @@ func TestWrapperErrors(t *testing.T) {
|
||||
So(dynamoWrapper.ResetManifestDataTable(), ShouldBeNil) //nolint:contextcheck
|
||||
So(dynamoWrapper.ResetRepoMetaTable(), ShouldBeNil) //nolint:contextcheck
|
||||
|
||||
authzCtxKey := localCtx.GetContextKey()
|
||||
|
||||
acCtx := localCtx.AccessControlContext{
|
||||
Username: "test",
|
||||
}
|
||||
|
||||
ctx := context.WithValue(context.Background(), authzCtxKey, acCtx)
|
||||
|
||||
Convey("SetUserData", func() {
|
||||
hashKey := "id"
|
||||
apiKeys := make(map[string]repodb.APIKeyDetails)
|
||||
apiKeyDetails := repodb.APIKeyDetails{
|
||||
Label: "apiKey",
|
||||
Scopes: []string{"repo"},
|
||||
UUID: hashKey,
|
||||
}
|
||||
|
||||
apiKeys[hashKey] = apiKeyDetails
|
||||
|
||||
userProfileSrc := repodb.UserData{
|
||||
Groups: []string{"group1", "group2"},
|
||||
APIKeys: apiKeys,
|
||||
}
|
||||
|
||||
err := dynamoWrapper.SetUserData(ctx, userProfileSrc)
|
||||
So(err, ShouldBeNil)
|
||||
|
||||
authzCtxKey := localCtx.GetContextKey()
|
||||
|
||||
acCtx := localCtx.AccessControlContext{
|
||||
Username: "",
|
||||
}
|
||||
|
||||
ctx := context.WithValue(context.Background(), authzCtxKey, acCtx)
|
||||
|
||||
err = dynamoWrapper.SetUserData(ctx, repodb.UserData{}) //nolint: contextcheck
|
||||
So(err, ShouldNotBeNil)
|
||||
})
|
||||
|
||||
Convey("DeleteUserData", func() {
|
||||
err := dynamoWrapper.DeleteUserData(ctx)
|
||||
So(err, ShouldBeNil)
|
||||
|
||||
authzCtxKey := localCtx.GetContextKey()
|
||||
|
||||
acCtx := localCtx.AccessControlContext{
|
||||
Username: "",
|
||||
}
|
||||
|
||||
ctx := context.WithValue(context.Background(), authzCtxKey, acCtx)
|
||||
|
||||
err = dynamoWrapper.DeleteUserData(ctx) //nolint: contextcheck
|
||||
So(err, ShouldNotBeNil)
|
||||
})
|
||||
|
||||
Convey("ToggleBookmarkRepo no access", func() {
|
||||
acCtx := localCtx.AccessControlContext{
|
||||
ReadGlobPatterns: map[string]bool{
|
||||
@@ -290,17 +348,17 @@ func TestWrapperErrors(t *testing.T) {
|
||||
So(err, ShouldNotBeNil)
|
||||
})
|
||||
|
||||
Convey("GetUserMeta bad context", func() {
|
||||
Convey("GetUserData bad context", func() {
|
||||
authzCtxKey := localCtx.GetContextKey()
|
||||
ctx := context.WithValue(context.Background(), authzCtxKey, "bad context")
|
||||
|
||||
userData, err := dynamoWrapper.GetUserMeta(ctx)
|
||||
userData, err := dynamoWrapper.GetUserData(ctx)
|
||||
So(err, ShouldNotBeNil)
|
||||
So(userData.BookmarkedRepos, ShouldBeEmpty)
|
||||
So(userData.StarredRepos, ShouldBeEmpty)
|
||||
})
|
||||
|
||||
Convey("GetUserMeta client error", func() {
|
||||
Convey("GetUserData client error", func() {
|
||||
acCtx := localCtx.AccessControlContext{
|
||||
ReadGlobPatterns: map[string]bool{
|
||||
"repo": true,
|
||||
@@ -312,7 +370,7 @@ func TestWrapperErrors(t *testing.T) {
|
||||
|
||||
dynamoWrapper.UserDataTablename = badTablename
|
||||
|
||||
_, err := dynamoWrapper.GetUserMeta(ctx)
|
||||
_, err := dynamoWrapper.GetUserData(ctx)
|
||||
So(err, ShouldNotBeNil)
|
||||
})
|
||||
|
||||
@@ -329,27 +387,155 @@ func TestWrapperErrors(t *testing.T) {
|
||||
err := setBadUserData(dynamoWrapper.Client, userDataTablename, acCtx.Username)
|
||||
So(err, ShouldBeNil)
|
||||
|
||||
_, err = dynamoWrapper.GetUserMeta(ctx)
|
||||
_, err = dynamoWrapper.GetUserData(ctx)
|
||||
So(err, ShouldNotBeNil)
|
||||
})
|
||||
|
||||
Convey("SetUserMeta bad context", func() {
|
||||
Convey("SetUserData bad context", func() {
|
||||
authzCtxKey := localCtx.GetContextKey()
|
||||
ctx := context.WithValue(context.Background(), authzCtxKey, "bad context")
|
||||
|
||||
err := dynamoWrapper.SetUserMeta(ctx, repodb.UserData{})
|
||||
err := dynamoWrapper.SetUserData(ctx, repodb.UserData{})
|
||||
So(err, ShouldNotBeNil)
|
||||
})
|
||||
|
||||
Convey("GetUserData bad context errors", func() {
|
||||
authzCtxKey := localCtx.GetContextKey()
|
||||
ctx := context.WithValue(context.Background(), authzCtxKey, "bad context")
|
||||
|
||||
_, err := dynamoWrapper.GetUserData(ctx)
|
||||
So(err, ShouldNotBeNil)
|
||||
})
|
||||
|
||||
Convey("SetUserData bad context errors", func() {
|
||||
authzCtxKey := localCtx.GetContextKey()
|
||||
ctx := context.WithValue(context.Background(), authzCtxKey, "bad context")
|
||||
|
||||
err := dynamoWrapper.SetUserData(ctx, repodb.UserData{})
|
||||
So(err, ShouldNotBeNil)
|
||||
})
|
||||
|
||||
Convey("AddUserAPIKey bad context errors", func() {
|
||||
authzCtxKey := localCtx.GetContextKey()
|
||||
ctx := context.WithValue(context.Background(), authzCtxKey, "bad context")
|
||||
|
||||
err := dynamoWrapper.AddUserAPIKey(ctx, "", &repodb.APIKeyDetails{})
|
||||
So(err, ShouldNotBeNil)
|
||||
})
|
||||
|
||||
Convey("DeleteUserAPIKey bad context errors", func() {
|
||||
authzCtxKey := localCtx.GetContextKey()
|
||||
ctx := context.WithValue(context.Background(), authzCtxKey, "bad context")
|
||||
|
||||
err := dynamoWrapper.DeleteUserAPIKey(ctx, "")
|
||||
So(err, ShouldNotBeNil)
|
||||
})
|
||||
|
||||
Convey("UpdateUserAPIKeyLastUsed bad context errors", func() {
|
||||
authzCtxKey := localCtx.GetContextKey()
|
||||
ctx := context.WithValue(context.Background(), authzCtxKey, "bad context")
|
||||
|
||||
err := dynamoWrapper.UpdateUserAPIKeyLastUsed(ctx, "")
|
||||
So(err, ShouldNotBeNil)
|
||||
})
|
||||
|
||||
Convey("DeleteUserData bad context errors", func() {
|
||||
authzCtxKey := localCtx.GetContextKey()
|
||||
ctx := context.WithValue(context.Background(), authzCtxKey, "bad context")
|
||||
|
||||
err := dynamoWrapper.DeleteUserData(ctx)
|
||||
So(err, ShouldNotBeNil)
|
||||
})
|
||||
|
||||
Convey("DeleteUserAPIKey returns nil", func() {
|
||||
authzCtxKey := localCtx.GetContextKey()
|
||||
|
||||
acCtx := localCtx.AccessControlContext{
|
||||
Username: "email",
|
||||
}
|
||||
|
||||
ctx := context.WithValue(context.Background(), authzCtxKey, acCtx)
|
||||
|
||||
apiKeyDetails := make(map[string]repodb.APIKeyDetails)
|
||||
apiKeyDetails["id"] = repodb.APIKeyDetails{
|
||||
UUID: "id",
|
||||
}
|
||||
err := dynamoWrapper.SetUserData(ctx, repodb.UserData{
|
||||
APIKeys: apiKeyDetails,
|
||||
})
|
||||
So(err, ShouldBeNil)
|
||||
|
||||
dynamoWrapper.APIKeyTablename = wrongTableName
|
||||
err = dynamoWrapper.DeleteUserAPIKey(ctx, "id")
|
||||
So(err, ShouldNotBeNil)
|
||||
})
|
||||
|
||||
Convey("AddUserAPIKey", func() {
|
||||
Convey("no userid found", func() {
|
||||
authzCtxKey := localCtx.GetContextKey()
|
||||
|
||||
acCtx := localCtx.AccessControlContext{
|
||||
Username: "",
|
||||
}
|
||||
|
||||
ctx := context.WithValue(context.Background(), authzCtxKey, acCtx)
|
||||
|
||||
err = dynamoWrapper.AddUserAPIKey(ctx, "key", &repodb.APIKeyDetails{})
|
||||
So(err, ShouldNotBeNil)
|
||||
})
|
||||
authzCtxKey := localCtx.GetContextKey()
|
||||
|
||||
acCtx := localCtx.AccessControlContext{
|
||||
Username: "email",
|
||||
}
|
||||
|
||||
ctx := context.WithValue(context.Background(), authzCtxKey, acCtx)
|
||||
|
||||
err := dynamoWrapper.AddUserAPIKey(ctx, "key", &repodb.APIKeyDetails{})
|
||||
So(err, ShouldBeNil)
|
||||
|
||||
dynamoWrapper.APIKeyTablename = wrongTableName
|
||||
err = dynamoWrapper.AddUserAPIKey(ctx, "key", &repodb.APIKeyDetails{})
|
||||
So(err, ShouldNotBeNil)
|
||||
})
|
||||
|
||||
Convey("GetUserAPIKeyInfo", func() {
|
||||
dynamoWrapper.APIKeyTablename = wrongTableName
|
||||
_, err := dynamoWrapper.GetUserAPIKeyInfo("key")
|
||||
So(err, ShouldNotBeNil)
|
||||
})
|
||||
|
||||
Convey("GetUserData", func() {
|
||||
authzCtxKey := localCtx.GetContextKey()
|
||||
|
||||
acCtx := localCtx.AccessControlContext{
|
||||
Username: "",
|
||||
}
|
||||
|
||||
ctx := context.WithValue(context.Background(), authzCtxKey, acCtx)
|
||||
_, err := dynamoWrapper.GetUserData(ctx)
|
||||
So(err, ShouldNotBeNil)
|
||||
|
||||
acCtx = localCtx.AccessControlContext{
|
||||
Username: "email",
|
||||
}
|
||||
|
||||
ctx = context.WithValue(context.Background(), authzCtxKey, acCtx)
|
||||
|
||||
dynamoWrapper.UserDataTablename = wrongTableName
|
||||
_, err = dynamoWrapper.GetUserData(ctx)
|
||||
So(err, ShouldNotBeNil)
|
||||
})
|
||||
|
||||
Convey("SetManifestData", func() {
|
||||
dynamoWrapper.ManifestDataTablename = "WRONG tables"
|
||||
dynamoWrapper.ManifestDataTablename = wrongTableName
|
||||
|
||||
err := dynamoWrapper.SetManifestData("dig", repodb.ManifestData{})
|
||||
So(err, ShouldNotBeNil)
|
||||
})
|
||||
|
||||
Convey("GetManifestData", func() {
|
||||
dynamoWrapper.ManifestDataTablename = "WRONG table"
|
||||
dynamoWrapper.ManifestDataTablename = wrongTableName
|
||||
|
||||
_, err := dynamoWrapper.GetManifestData("dig")
|
||||
So(err, ShouldNotBeNil)
|
||||
@@ -364,7 +550,7 @@ func TestWrapperErrors(t *testing.T) {
|
||||
})
|
||||
|
||||
Convey("GetIndexData", func() {
|
||||
dynamoWrapper.IndexDataTablename = "WRONG table"
|
||||
dynamoWrapper.IndexDataTablename = wrongTableName
|
||||
|
||||
_, err := dynamoWrapper.GetIndexData("dig")
|
||||
So(err, ShouldNotBeNil)
|
||||
@@ -1091,6 +1277,7 @@ func TestWrapperErrors(t *testing.T) {
|
||||
ManifestDataTablename: manifestDataTablename,
|
||||
IndexDataTablename: indexDataTablename,
|
||||
UserDataTablename: userDataTablename,
|
||||
APIKeyTablename: apiKeyTablename,
|
||||
VersionTablename: versionTablename,
|
||||
}
|
||||
client, err := dynamo.GetDynamoClient(params)
|
||||
@@ -1106,6 +1293,7 @@ func TestWrapperErrors(t *testing.T) {
|
||||
ManifestDataTablename: "",
|
||||
IndexDataTablename: indexDataTablename,
|
||||
UserDataTablename: userDataTablename,
|
||||
APIKeyTablename: apiKeyTablename,
|
||||
VersionTablename: versionTablename,
|
||||
}
|
||||
client, err = dynamo.GetDynamoClient(params)
|
||||
@@ -1121,6 +1309,7 @@ func TestWrapperErrors(t *testing.T) {
|
||||
ManifestDataTablename: manifestDataTablename,
|
||||
IndexDataTablename: "",
|
||||
UserDataTablename: userDataTablename,
|
||||
APIKeyTablename: apiKeyTablename,
|
||||
VersionTablename: versionTablename,
|
||||
}
|
||||
client, err = dynamo.GetDynamoClient(params)
|
||||
@@ -1136,6 +1325,7 @@ func TestWrapperErrors(t *testing.T) {
|
||||
ManifestDataTablename: manifestDataTablename,
|
||||
IndexDataTablename: indexDataTablename,
|
||||
UserDataTablename: userDataTablename,
|
||||
APIKeyTablename: apiKeyTablename,
|
||||
VersionTablename: "",
|
||||
}
|
||||
client, err = dynamo.GetDynamoClient(params)
|
||||
@@ -1150,8 +1340,41 @@ func TestWrapperErrors(t *testing.T) {
|
||||
RepoMetaTablename: repoMetaTablename,
|
||||
ManifestDataTablename: manifestDataTablename,
|
||||
IndexDataTablename: indexDataTablename,
|
||||
UserDataTablename: "",
|
||||
VersionTablename: versionTablename,
|
||||
UserDataTablename: userDataTablename,
|
||||
APIKeyTablename: apiKeyTablename,
|
||||
}
|
||||
client, err = dynamo.GetDynamoClient(params)
|
||||
So(err, ShouldBeNil)
|
||||
|
||||
_, err = dynamoWrapper.NewDynamoDBWrapper(client, params, log)
|
||||
So(err, ShouldBeNil)
|
||||
|
||||
params = dynamo.DBDriverParameters{ //nolint:contextcheck
|
||||
Endpoint: endpoint,
|
||||
Region: region,
|
||||
RepoMetaTablename: repoMetaTablename,
|
||||
ManifestDataTablename: manifestDataTablename,
|
||||
IndexDataTablename: indexDataTablename,
|
||||
VersionTablename: versionTablename,
|
||||
UserDataTablename: "",
|
||||
APIKeyTablename: apiKeyTablename,
|
||||
}
|
||||
client, err = dynamo.GetDynamoClient(params)
|
||||
So(err, ShouldBeNil)
|
||||
|
||||
_, err = dynamoWrapper.NewDynamoDBWrapper(client, params, log)
|
||||
So(err, ShouldNotBeNil)
|
||||
|
||||
params = dynamo.DBDriverParameters{ //nolint:contextcheck
|
||||
Endpoint: endpoint,
|
||||
Region: region,
|
||||
RepoMetaTablename: repoMetaTablename,
|
||||
ManifestDataTablename: manifestDataTablename,
|
||||
IndexDataTablename: indexDataTablename,
|
||||
VersionTablename: versionTablename,
|
||||
UserDataTablename: userDataTablename,
|
||||
APIKeyTablename: "",
|
||||
}
|
||||
client, err = dynamo.GetDynamoClient(params)
|
||||
So(err, ShouldBeNil)
|
||||
@@ -1250,7 +1473,7 @@ func setBadUserData(client *dynamodb.Client, userDataTablename, userID string) e
|
||||
":UserData": userAttributeValue,
|
||||
},
|
||||
Key: map[string]types.AttributeValue{
|
||||
"UserID": &types.AttributeValueMemberS{
|
||||
"Identity": &types.AttributeValueMemberS{
|
||||
Value: userID,
|
||||
},
|
||||
},
|
||||
|
||||
@@ -30,6 +30,7 @@ var errRepodb = errors.New("repodb: error while constructing manifest meta")
|
||||
|
||||
type DBWrapper struct {
|
||||
Client *dynamodb.Client
|
||||
APIKeyTablename string
|
||||
RepoMetaTablename string
|
||||
IndexDataTablename string
|
||||
ManifestDataTablename string
|
||||
@@ -47,6 +48,7 @@ func NewDynamoDBWrapper(client *dynamodb.Client, params dynamo.DBDriverParameter
|
||||
IndexDataTablename: params.IndexDataTablename,
|
||||
VersionTablename: params.VersionTablename,
|
||||
UserDataTablename: params.UserDataTablename,
|
||||
APIKeyTablename: params.APIKeyTablename,
|
||||
Patches: version.GetDynamoDBPatches(),
|
||||
Log: log,
|
||||
}
|
||||
@@ -76,6 +78,11 @@ func NewDynamoDBWrapper(client *dynamodb.Client, params dynamo.DBDriverParameter
|
||||
return nil, err
|
||||
}
|
||||
|
||||
err = dynamoWrapper.createAPIKeyTable()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Using the Config value, create the DynamoDB client
|
||||
return &dynamoWrapper, nil
|
||||
}
|
||||
@@ -580,13 +587,13 @@ func (dwr *DBWrapper) GetUserRepoMeta(ctx context.Context, repo string) (repodb.
|
||||
return repodb.RepoMetadata{}, err
|
||||
}
|
||||
|
||||
userMeta, err := dwr.GetUserMeta(ctx)
|
||||
userData, err := dwr.GetUserData(ctx)
|
||||
if err != nil {
|
||||
return repodb.RepoMetadata{}, err
|
||||
}
|
||||
|
||||
repoMeta.IsBookmarked = zcommon.Contains(userMeta.BookmarkedRepos, repo)
|
||||
repoMeta.IsStarred = zcommon.Contains(userMeta.StarredRepos, repo)
|
||||
repoMeta.IsBookmarked = zcommon.Contains(userData.BookmarkedRepos, repo)
|
||||
repoMeta.IsStarred = zcommon.Contains(userData.StarredRepos, repo)
|
||||
|
||||
return repoMeta, nil
|
||||
}
|
||||
@@ -1802,7 +1809,7 @@ func (dwr *DBWrapper) ToggleBookmarkRepo(ctx context.Context, repo string) (
|
||||
return res, zerr.ErrUserDataNotAllowed
|
||||
}
|
||||
|
||||
userMeta, err := dwr.GetUserMeta(ctx)
|
||||
userData, err := dwr.GetUserData(ctx)
|
||||
if err != nil {
|
||||
if errors.Is(err, zerr.ErrUserDataNotFound) {
|
||||
return repodb.NotChanged, nil
|
||||
@@ -1811,16 +1818,16 @@ func (dwr *DBWrapper) ToggleBookmarkRepo(ctx context.Context, repo string) (
|
||||
return res, err
|
||||
}
|
||||
|
||||
if !zcommon.Contains(userMeta.BookmarkedRepos, repo) {
|
||||
userMeta.BookmarkedRepos = append(userMeta.BookmarkedRepos, repo)
|
||||
if !zcommon.Contains(userData.BookmarkedRepos, repo) {
|
||||
userData.BookmarkedRepos = append(userData.BookmarkedRepos, repo)
|
||||
res = repodb.Added
|
||||
} else {
|
||||
userMeta.BookmarkedRepos = zcommon.RemoveFrom(userMeta.BookmarkedRepos, repo)
|
||||
userData.BookmarkedRepos = zcommon.RemoveFrom(userData.BookmarkedRepos, repo)
|
||||
res = repodb.Removed
|
||||
}
|
||||
|
||||
if res != repodb.NotChanged {
|
||||
err = dwr.SetUserMeta(ctx, userMeta)
|
||||
err = dwr.SetUserData(ctx, userData)
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
@@ -1833,9 +1840,9 @@ func (dwr *DBWrapper) ToggleBookmarkRepo(ctx context.Context, repo string) (
|
||||
}
|
||||
|
||||
func (dwr *DBWrapper) GetBookmarkedRepos(ctx context.Context) ([]string, error) {
|
||||
userMeta, err := dwr.GetUserMeta(ctx)
|
||||
userMeta, err := dwr.GetUserData(ctx)
|
||||
|
||||
if errors.Is(err, zerr.ErrUserDataNotFound) {
|
||||
if errors.Is(err, zerr.ErrUserDataNotFound) || errors.Is(err, zerr.ErrUserDataNotAllowed) {
|
||||
return []string{}, nil
|
||||
}
|
||||
|
||||
@@ -1863,7 +1870,7 @@ func (dwr *DBWrapper) ToggleStarRepo(ctx context.Context, repo string) (
|
||||
return res, zerr.ErrUserDataNotAllowed
|
||||
}
|
||||
|
||||
userData, err := dwr.GetUserMeta(ctx)
|
||||
userData, err := dwr.GetUserData(ctx)
|
||||
if err != nil && !errors.Is(err, zerr.ErrUserDataNotFound) {
|
||||
return res, err
|
||||
}
|
||||
@@ -1902,21 +1909,21 @@ func (dwr *DBWrapper) ToggleStarRepo(ctx context.Context, repo string) (
|
||||
_, err = dwr.Client.TransactWriteItems(ctx, &dynamodb.TransactWriteItemsInput{
|
||||
TransactItems: []types.TransactWriteItem{
|
||||
{
|
||||
// Update User Meta
|
||||
// Update User Profile
|
||||
Update: &types.Update{
|
||||
ExpressionAttributeNames: map[string]string{
|
||||
"#UM": "UserData",
|
||||
"#UP": "UserData",
|
||||
},
|
||||
ExpressionAttributeValues: map[string]types.AttributeValue{
|
||||
":UserData": userAttributeValue,
|
||||
},
|
||||
Key: map[string]types.AttributeValue{
|
||||
"UserID": &types.AttributeValueMemberS{
|
||||
"Identity": &types.AttributeValueMemberS{
|
||||
Value: userid,
|
||||
},
|
||||
},
|
||||
TableName: aws.String(dwr.UserDataTablename),
|
||||
UpdateExpression: aws.String("SET #UM = :UserData"),
|
||||
UpdateExpression: aws.String("SET #UP = :UserData"),
|
||||
},
|
||||
},
|
||||
{
|
||||
@@ -1948,64 +1955,27 @@ func (dwr *DBWrapper) ToggleStarRepo(ctx context.Context, repo string) (
|
||||
}
|
||||
|
||||
func (dwr *DBWrapper) GetStarredRepos(ctx context.Context) ([]string, error) {
|
||||
userMeta, err := dwr.GetUserMeta(ctx)
|
||||
userMeta, err := dwr.GetUserData(ctx)
|
||||
|
||||
if errors.Is(err, zerr.ErrUserDataNotFound) {
|
||||
if errors.Is(err, zerr.ErrUserDataNotFound) || errors.Is(err, zerr.ErrUserDataNotAllowed) {
|
||||
return []string{}, nil
|
||||
}
|
||||
|
||||
return userMeta.StarredRepos, err
|
||||
}
|
||||
|
||||
func (dwr *DBWrapper) GetUserMeta(ctx context.Context) (repodb.UserData, error) {
|
||||
acCtx, err := localCtx.GetAccessControlContext(ctx)
|
||||
if err != nil {
|
||||
return repodb.UserData{}, err
|
||||
}
|
||||
|
||||
userid := localCtx.GetUsernameFromContext(acCtx)
|
||||
|
||||
if userid == "" {
|
||||
// empty user is anonymous, it has no data
|
||||
return repodb.UserData{}, nil
|
||||
}
|
||||
|
||||
resp, err := dwr.Client.GetItem(ctx, &dynamodb.GetItemInput{
|
||||
TableName: aws.String(dwr.UserDataTablename),
|
||||
Key: map[string]types.AttributeValue{
|
||||
"UserID": &types.AttributeValueMemberS{Value: userid},
|
||||
},
|
||||
})
|
||||
if err != nil {
|
||||
return repodb.UserData{}, err
|
||||
}
|
||||
|
||||
if resp.Item == nil {
|
||||
return repodb.UserData{}, zerr.ErrUserDataNotFound
|
||||
}
|
||||
|
||||
var userMeta repodb.UserData
|
||||
|
||||
err = attributevalue.Unmarshal(resp.Item["UserData"], &userMeta)
|
||||
if err != nil {
|
||||
return repodb.UserData{}, err
|
||||
}
|
||||
|
||||
return userMeta, nil
|
||||
}
|
||||
|
||||
func (dwr *DBWrapper) createUserDataTable() error {
|
||||
_, err := dwr.Client.CreateTable(context.Background(), &dynamodb.CreateTableInput{
|
||||
TableName: aws.String(dwr.UserDataTablename),
|
||||
AttributeDefinitions: []types.AttributeDefinition{
|
||||
{
|
||||
AttributeName: aws.String("UserID"),
|
||||
AttributeName: aws.String("Identity"),
|
||||
AttributeType: types.ScalarAttributeTypeS,
|
||||
},
|
||||
},
|
||||
KeySchema: []types.KeySchemaElement{
|
||||
{
|
||||
AttributeName: aws.String("UserID"),
|
||||
AttributeName: aws.String("Identity"),
|
||||
KeyType: types.KeyTypeHash,
|
||||
},
|
||||
},
|
||||
@@ -2019,38 +1989,279 @@ func (dwr *DBWrapper) createUserDataTable() error {
|
||||
return dwr.waitTableToBeCreated(dwr.UserDataTablename)
|
||||
}
|
||||
|
||||
func (dwr *DBWrapper) SetUserMeta(ctx context.Context, userMeta repodb.UserData) error {
|
||||
func (dwr DBWrapper) createAPIKeyTable() error {
|
||||
_, err := dwr.Client.CreateTable(context.Background(), &dynamodb.CreateTableInput{
|
||||
TableName: aws.String(dwr.APIKeyTablename),
|
||||
AttributeDefinitions: []types.AttributeDefinition{
|
||||
{
|
||||
AttributeName: aws.String("HashedKey"),
|
||||
AttributeType: types.ScalarAttributeTypeS,
|
||||
},
|
||||
},
|
||||
KeySchema: []types.KeySchemaElement{
|
||||
{
|
||||
AttributeName: aws.String("HashedKey"),
|
||||
KeyType: types.KeyTypeHash,
|
||||
},
|
||||
},
|
||||
BillingMode: types.BillingModePayPerRequest,
|
||||
})
|
||||
|
||||
if err != nil && !strings.Contains(err.Error(), "Table already exists") {
|
||||
return err
|
||||
}
|
||||
|
||||
return dwr.waitTableToBeCreated(dwr.APIKeyTablename)
|
||||
}
|
||||
|
||||
func (dwr DBWrapper) SetUserGroups(ctx context.Context, groups []string) error {
|
||||
userData, err := dwr.GetUserData(ctx)
|
||||
if err != nil && !errors.Is(err, zerr.ErrUserDataNotFound) {
|
||||
return err
|
||||
}
|
||||
|
||||
userData.Groups = append(userData.Groups, groups...)
|
||||
|
||||
return dwr.SetUserData(ctx, userData)
|
||||
}
|
||||
|
||||
func (dwr DBWrapper) GetUserGroups(ctx context.Context) ([]string, error) {
|
||||
userData, err := dwr.GetUserData(ctx)
|
||||
|
||||
return userData.Groups, err
|
||||
}
|
||||
|
||||
func (dwr DBWrapper) UpdateUserAPIKeyLastUsed(ctx context.Context, hashedKey string) error {
|
||||
userData, err := dwr.GetUserData(ctx)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
apiKeyDetails := userData.APIKeys[hashedKey]
|
||||
apiKeyDetails.LastUsed = time.Now()
|
||||
|
||||
userData.APIKeys[hashedKey] = apiKeyDetails
|
||||
|
||||
err = dwr.SetUserData(ctx, userData)
|
||||
|
||||
return err
|
||||
}
|
||||
|
||||
func (dwr DBWrapper) AddUserAPIKey(ctx context.Context, hashedKey string, apiKeyDetails *repodb.APIKeyDetails) error {
|
||||
acCtx, err := localCtx.GetAccessControlContext(ctx)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
userid := localCtx.GetUsernameFromContext(acCtx)
|
||||
|
||||
if userid == "" {
|
||||
// empty user is anonymous, it has no data
|
||||
// empty user is anonymous
|
||||
return zerr.ErrUserDataNotAllowed
|
||||
}
|
||||
|
||||
userAttributeValue, err := attributevalue.Marshal(userMeta)
|
||||
userData, err := dwr.GetUserData(ctx)
|
||||
if err != nil && !errors.Is(err, zerr.ErrUserDataNotFound) {
|
||||
return fmt.Errorf("repoDB: error while getting userData for identity %s %w", userid, err)
|
||||
}
|
||||
|
||||
if userData.APIKeys == nil {
|
||||
userData.APIKeys = make(map[string]repodb.APIKeyDetails)
|
||||
}
|
||||
|
||||
userData.APIKeys[hashedKey] = *apiKeyDetails
|
||||
|
||||
userAttributeValue, err := attributevalue.Marshal(userData)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
_, err = dwr.Client.TransactWriteItems(ctx, &dynamodb.TransactWriteItemsInput{
|
||||
TransactItems: []types.TransactWriteItem{
|
||||
{
|
||||
// Update UserData
|
||||
Update: &types.Update{
|
||||
ExpressionAttributeNames: map[string]string{
|
||||
"#UP": "UserData",
|
||||
},
|
||||
ExpressionAttributeValues: map[string]types.AttributeValue{
|
||||
":UserData": userAttributeValue,
|
||||
},
|
||||
Key: map[string]types.AttributeValue{
|
||||
"Identity": &types.AttributeValueMemberS{
|
||||
Value: userid,
|
||||
},
|
||||
},
|
||||
TableName: aws.String(dwr.UserDataTablename),
|
||||
UpdateExpression: aws.String("SET #UP = :UserData"),
|
||||
},
|
||||
},
|
||||
{
|
||||
// Update APIKeyInfo
|
||||
Update: &types.Update{
|
||||
ExpressionAttributeNames: map[string]string{
|
||||
"#EM": "Identity",
|
||||
},
|
||||
ExpressionAttributeValues: map[string]types.AttributeValue{
|
||||
":Identity": &types.AttributeValueMemberS{Value: userid},
|
||||
},
|
||||
Key: map[string]types.AttributeValue{
|
||||
"HashedKey": &types.AttributeValueMemberS{
|
||||
Value: hashedKey,
|
||||
},
|
||||
},
|
||||
TableName: aws.String(dwr.APIKeyTablename),
|
||||
UpdateExpression: aws.String("SET #EM = :Identity"),
|
||||
},
|
||||
},
|
||||
},
|
||||
})
|
||||
|
||||
return err
|
||||
}
|
||||
|
||||
func (dwr DBWrapper) DeleteUserAPIKey(ctx context.Context, keyID string) error {
|
||||
userData, err := dwr.GetUserData(ctx)
|
||||
if err != nil {
|
||||
return fmt.Errorf("repoDB: error while getting userData %w", err)
|
||||
}
|
||||
|
||||
for hash, apiKeyDetails := range userData.APIKeys {
|
||||
if apiKeyDetails.UUID == keyID {
|
||||
delete(userData.APIKeys, hash)
|
||||
|
||||
_, err = dwr.Client.DeleteItem(ctx, &dynamodb.DeleteItemInput{
|
||||
TableName: aws.String(dwr.APIKeyTablename),
|
||||
Key: map[string]types.AttributeValue{
|
||||
"HashedKey": &types.AttributeValueMemberS{Value: hash},
|
||||
},
|
||||
})
|
||||
if err != nil {
|
||||
return fmt.Errorf("repoDB: error while deleting userAPIKey entry for hash %s %w", hash, err)
|
||||
}
|
||||
|
||||
err := dwr.SetUserData(ctx, userData)
|
||||
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (dwr DBWrapper) GetUserAPIKeyInfo(hashedKey string) (string, error) {
|
||||
var userid string
|
||||
|
||||
resp, err := dwr.Client.GetItem(context.Background(), &dynamodb.GetItemInput{
|
||||
TableName: aws.String(dwr.APIKeyTablename),
|
||||
Key: map[string]types.AttributeValue{
|
||||
"HashedKey": &types.AttributeValueMemberS{Value: hashedKey},
|
||||
},
|
||||
})
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
if resp.Item == nil {
|
||||
return "", zerr.ErrUserAPIKeyNotFound
|
||||
}
|
||||
|
||||
err = attributevalue.Unmarshal(resp.Item["Identity"], &userid)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
return userid, nil
|
||||
}
|
||||
|
||||
func (dwr DBWrapper) GetUserData(ctx context.Context) (repodb.UserData, error) {
|
||||
var userData repodb.UserData
|
||||
|
||||
acCtx, err := localCtx.GetAccessControlContext(ctx)
|
||||
if err != nil {
|
||||
return userData, err
|
||||
}
|
||||
|
||||
userid := localCtx.GetUsernameFromContext(acCtx)
|
||||
if userid == "" {
|
||||
// empty user is anonymous
|
||||
return userData, zerr.ErrUserDataNotAllowed
|
||||
}
|
||||
|
||||
resp, err := dwr.Client.GetItem(ctx, &dynamodb.GetItemInput{
|
||||
TableName: aws.String(dwr.UserDataTablename),
|
||||
Key: map[string]types.AttributeValue{
|
||||
"Identity": &types.AttributeValueMemberS{Value: userid},
|
||||
},
|
||||
})
|
||||
if err != nil {
|
||||
return repodb.UserData{}, err
|
||||
}
|
||||
|
||||
if resp.Item == nil {
|
||||
return repodb.UserData{}, zerr.ErrUserDataNotFound
|
||||
}
|
||||
|
||||
err = attributevalue.Unmarshal(resp.Item["UserData"], &userData)
|
||||
if err != nil {
|
||||
return repodb.UserData{}, err
|
||||
}
|
||||
|
||||
return userData, nil
|
||||
}
|
||||
|
||||
func (dwr DBWrapper) SetUserData(ctx context.Context, userData repodb.UserData) error {
|
||||
acCtx, err := localCtx.GetAccessControlContext(ctx)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
userid := localCtx.GetUsernameFromContext(acCtx)
|
||||
if userid == "" {
|
||||
// empty user is anonymous
|
||||
return zerr.ErrUserDataNotAllowed
|
||||
}
|
||||
|
||||
userAttributeValue, err := attributevalue.Marshal(userData)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
_, err = dwr.Client.UpdateItem(ctx, &dynamodb.UpdateItemInput{
|
||||
ExpressionAttributeNames: map[string]string{
|
||||
"#UM": "UserData",
|
||||
"#UP": "UserData",
|
||||
},
|
||||
ExpressionAttributeValues: map[string]types.AttributeValue{
|
||||
":UserData": userAttributeValue,
|
||||
},
|
||||
Key: map[string]types.AttributeValue{
|
||||
"UserID": &types.AttributeValueMemberS{
|
||||
"Identity": &types.AttributeValueMemberS{
|
||||
Value: userid,
|
||||
},
|
||||
},
|
||||
TableName: aws.String(dwr.UserDataTablename),
|
||||
UpdateExpression: aws.String("SET #UM = :UserData"),
|
||||
UpdateExpression: aws.String("SET #UP = :UserData"),
|
||||
})
|
||||
|
||||
return err
|
||||
}
|
||||
|
||||
func (dwr DBWrapper) DeleteUserData(ctx context.Context) error {
|
||||
acCtx, err := localCtx.GetAccessControlContext(ctx)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
userid := localCtx.GetUsernameFromContext(acCtx)
|
||||
if userid == "" {
|
||||
// empty user is anonymous
|
||||
return zerr.ErrUserDataNotAllowed
|
||||
}
|
||||
|
||||
_, err = dwr.Client.DeleteItem(ctx, &dynamodb.DeleteItemInput{
|
||||
TableName: aws.String(dwr.UserDataTablename),
|
||||
Key: map[string]types.AttributeValue{
|
||||
"Identity": &types.AttributeValueMemberS{Value: userid},
|
||||
},
|
||||
})
|
||||
|
||||
return err
|
||||
|
||||
@@ -24,6 +24,7 @@ type (
|
||||
)
|
||||
|
||||
type RepoDB interface { //nolint:interfacebloat
|
||||
UserDB
|
||||
// IncrementRepoStars adds 1 to the star count of an image
|
||||
IncrementRepoStars(repo string) error
|
||||
|
||||
@@ -111,6 +112,10 @@ type RepoDB interface { //nolint:interfacebloat
|
||||
FilterTags(ctx context.Context, filter FilterFunc,
|
||||
requestedPage PageInput) ([]RepoMetadata, map[string]ManifestMetadata, map[string]IndexData, common.PageInfo, error)
|
||||
|
||||
PatchDB() error
|
||||
}
|
||||
|
||||
type UserDB interface { //nolint:interfacebloat
|
||||
// GetStarredRepos returns starred repos and takes current user in consideration
|
||||
GetStarredRepos(ctx context.Context) ([]string, error)
|
||||
|
||||
@@ -123,7 +128,24 @@ type RepoDB interface { //nolint:interfacebloat
|
||||
// ToggleBookmarkRepo adds/removes bookmarks on repos
|
||||
ToggleBookmarkRepo(ctx context.Context, reponame string) (ToggleState, error)
|
||||
|
||||
PatchDB() error
|
||||
// UserDB profile/api key CRUD
|
||||
GetUserData(ctx context.Context) (UserData, error)
|
||||
|
||||
SetUserData(ctx context.Context, userData UserData) error
|
||||
|
||||
SetUserGroups(ctx context.Context, groups []string) error
|
||||
|
||||
GetUserGroups(ctx context.Context) ([]string, error)
|
||||
|
||||
DeleteUserData(ctx context.Context) error
|
||||
|
||||
GetUserAPIKeyInfo(hashedKey string) (identity string, err error)
|
||||
|
||||
AddUserAPIKey(ctx context.Context, hashedKey string, apiKeyDetails *APIKeyDetails) error
|
||||
|
||||
UpdateUserAPIKeyLastUsed(ctx context.Context, hashedKey string) error
|
||||
|
||||
DeleteUserAPIKey(ctx context.Context, id string) error
|
||||
}
|
||||
|
||||
type ManifestMetadata struct {
|
||||
@@ -195,12 +217,6 @@ type SignatureMetadata struct {
|
||||
LayersInfo []LayerInfo
|
||||
}
|
||||
|
||||
type UserData struct {
|
||||
// data for each user.
|
||||
StarredRepos []string
|
||||
BookmarkedRepos []string
|
||||
}
|
||||
|
||||
type SortCriteria string
|
||||
|
||||
const (
|
||||
@@ -235,3 +251,20 @@ type FilterData struct {
|
||||
IsStarred bool
|
||||
IsBookmarked bool
|
||||
}
|
||||
|
||||
type UserData struct {
|
||||
StarredRepos []string
|
||||
BookmarkedRepos []string
|
||||
Groups []string
|
||||
APIKeys map[string]APIKeyDetails
|
||||
}
|
||||
|
||||
type APIKeyDetails struct {
|
||||
CreatedAt time.Time `json:"createdAt"`
|
||||
CreatorUA string `json:"creatorUa"`
|
||||
GeneratedBy string `json:"generatedBy"`
|
||||
LastUsed time.Time `json:"lastUsed"`
|
||||
Label string `json:"label"`
|
||||
Scopes []string `json:"scopes"`
|
||||
UUID string `json:"uuid"`
|
||||
}
|
||||
|
||||
@@ -92,6 +92,7 @@ func TestDynamoDBWrapper(t *testing.T) {
|
||||
versionTablename := "Version" + uuid.String()
|
||||
indexDataTablename := "IndexDataTable" + uuid.String()
|
||||
userDataTablename := "UserDataTable" + uuid.String()
|
||||
apiKeyTablename := "ApiKeyTable" + uuid.String()
|
||||
|
||||
Convey("DynamoDB Wrapper", t, func() {
|
||||
dynamoDBDriverParams := dynamo.DBDriverParameters{
|
||||
@@ -101,6 +102,7 @@ func TestDynamoDBWrapper(t *testing.T) {
|
||||
IndexDataTablename: indexDataTablename,
|
||||
VersionTablename: versionTablename,
|
||||
UserDataTablename: userDataTablename,
|
||||
APIKeyTablename: apiKeyTablename,
|
||||
Region: "us-east-2",
|
||||
}
|
||||
|
||||
@@ -137,6 +139,118 @@ func RunRepoDBTests(t *testing.T, repoDB repodb.RepoDB, preparationFuncs ...func
|
||||
So(err, ShouldBeNil)
|
||||
}
|
||||
|
||||
Convey("Test CRUD operations on UserData and API keys", func() {
|
||||
hashKey1 := "id"
|
||||
hashKey2 := "key"
|
||||
apiKeys := make(map[string]repodb.APIKeyDetails)
|
||||
apiKeyDetails := repodb.APIKeyDetails{
|
||||
Label: "apiKey",
|
||||
Scopes: []string{"repo"},
|
||||
UUID: hashKey1,
|
||||
}
|
||||
|
||||
apiKeys[hashKey1] = apiKeyDetails
|
||||
|
||||
userProfileSrc := repodb.UserData{
|
||||
Groups: []string{"group1", "group2"},
|
||||
APIKeys: apiKeys,
|
||||
}
|
||||
|
||||
authzCtxKey := localCtx.GetContextKey()
|
||||
|
||||
acCtx := localCtx.AccessControlContext{
|
||||
Username: "test",
|
||||
}
|
||||
|
||||
ctx := context.WithValue(context.Background(), authzCtxKey, acCtx)
|
||||
|
||||
err := repoDB.AddUserAPIKey(ctx, hashKey1, &apiKeyDetails)
|
||||
So(err, ShouldBeNil)
|
||||
|
||||
err = repoDB.SetUserData(ctx, userProfileSrc)
|
||||
So(err, ShouldBeNil)
|
||||
|
||||
userProfile, err := repoDB.GetUserData(ctx)
|
||||
So(err, ShouldBeNil)
|
||||
So(userProfile.Groups, ShouldResemble, userProfileSrc.Groups)
|
||||
So(userProfile.APIKeys, ShouldContainKey, hashKey1)
|
||||
So(userProfile.APIKeys[hashKey1].Label, ShouldEqual, apiKeyDetails.Label)
|
||||
So(userProfile.APIKeys[hashKey1].Scopes, ShouldResemble, apiKeyDetails.Scopes)
|
||||
|
||||
lastUsed := userProfile.APIKeys[hashKey1].LastUsed
|
||||
|
||||
err = repoDB.UpdateUserAPIKeyLastUsed(ctx, hashKey1)
|
||||
So(err, ShouldBeNil)
|
||||
|
||||
userProfile, err = repoDB.GetUserData(ctx)
|
||||
So(err, ShouldBeNil)
|
||||
So(userProfile.APIKeys[hashKey1].LastUsed, ShouldHappenAfter, lastUsed)
|
||||
|
||||
userGroups, err := repoDB.GetUserGroups(ctx)
|
||||
So(err, ShouldBeNil)
|
||||
So(userGroups, ShouldResemble, userProfileSrc.Groups)
|
||||
|
||||
apiKeyDetails.UUID = hashKey2
|
||||
err = repoDB.AddUserAPIKey(ctx, hashKey2, &apiKeyDetails)
|
||||
So(err, ShouldBeNil)
|
||||
|
||||
userProfile, err = repoDB.GetUserData(ctx)
|
||||
So(err, ShouldBeNil)
|
||||
So(userProfile.Groups, ShouldResemble, userProfileSrc.Groups)
|
||||
So(userProfile.APIKeys, ShouldContainKey, hashKey2)
|
||||
So(userProfile.APIKeys[hashKey2].Label, ShouldEqual, apiKeyDetails.Label)
|
||||
So(userProfile.APIKeys[hashKey2].Scopes, ShouldResemble, apiKeyDetails.Scopes)
|
||||
|
||||
email, err := repoDB.GetUserAPIKeyInfo(hashKey2)
|
||||
So(err, ShouldBeNil)
|
||||
So(email, ShouldEqual, "test")
|
||||
|
||||
err = repoDB.DeleteUserAPIKey(ctx, hashKey1)
|
||||
So(err, ShouldBeNil)
|
||||
|
||||
userProfile, err = repoDB.GetUserData(ctx)
|
||||
So(err, ShouldBeNil)
|
||||
So(len(userProfile.APIKeys), ShouldEqual, 1)
|
||||
So(userProfile.APIKeys, ShouldNotContainKey, hashKey1)
|
||||
|
||||
err = repoDB.DeleteUserAPIKey(ctx, hashKey2)
|
||||
So(err, ShouldBeNil)
|
||||
|
||||
userProfile, err = repoDB.GetUserData(ctx)
|
||||
So(err, ShouldBeNil)
|
||||
So(len(userProfile.APIKeys), ShouldEqual, 0)
|
||||
So(userProfile.APIKeys, ShouldNotContainKey, hashKey2)
|
||||
|
||||
// delete non existent api key
|
||||
err = repoDB.DeleteUserAPIKey(ctx, hashKey2)
|
||||
So(err, ShouldBeNil)
|
||||
|
||||
err = repoDB.DeleteUserData(ctx)
|
||||
So(err, ShouldBeNil)
|
||||
|
||||
email, err = repoDB.GetUserAPIKeyInfo(hashKey2)
|
||||
So(err, ShouldNotBeNil)
|
||||
So(email, ShouldBeEmpty)
|
||||
|
||||
email, err = repoDB.GetUserAPIKeyInfo(hashKey1)
|
||||
So(err, ShouldNotBeNil)
|
||||
So(email, ShouldBeEmpty)
|
||||
|
||||
_, err = repoDB.GetUserData(ctx)
|
||||
So(err, ShouldNotBeNil)
|
||||
|
||||
userGroups, err = repoDB.GetUserGroups(ctx)
|
||||
So(err, ShouldNotBeNil)
|
||||
So(userGroups, ShouldBeEmpty)
|
||||
|
||||
err = repoDB.SetUserGroups(ctx, userProfileSrc.Groups)
|
||||
So(err, ShouldBeNil)
|
||||
|
||||
userGroups, err = repoDB.GetUserGroups(ctx)
|
||||
So(err, ShouldBeNil)
|
||||
So(userGroups, ShouldResemble, userProfileSrc.Groups)
|
||||
})
|
||||
|
||||
Convey("Test SetManifestData and GetManifestData", func() {
|
||||
configBlob, manifestBlob, err := generateTestImage()
|
||||
So(err, ShouldBeNil)
|
||||
|
||||
@@ -95,6 +95,9 @@ func getDynamoParams(cacheDriverConfig map[string]interface{}, log log.Logger) d
|
||||
indexDataTablename, ok := toStringIfOk(cacheDriverConfig, "indexdatatablename", log)
|
||||
allParametersOk = allParametersOk && ok
|
||||
|
||||
apiKeyTablename, ok := toStringIfOk(cacheDriverConfig, "apikeytablename", log)
|
||||
allParametersOk = allParametersOk && ok
|
||||
|
||||
versionTablename, ok := toStringIfOk(cacheDriverConfig, "versiontablename", log)
|
||||
allParametersOk = allParametersOk && ok
|
||||
|
||||
@@ -112,6 +115,7 @@ func getDynamoParams(cacheDriverConfig map[string]interface{}, log log.Logger) d
|
||||
ManifestDataTablename: manifestDataTablename,
|
||||
IndexDataTablename: indexDataTablename,
|
||||
UserDataTablename: userDataTablename,
|
||||
APIKeyTablename: apiKeyTablename,
|
||||
VersionTablename: versionTablename,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -25,6 +25,7 @@ func TestCreateDynamo(t *testing.T) {
|
||||
ManifestDataTablename: "ManifestDataTable",
|
||||
IndexDataTablename: "IndexDataTable",
|
||||
UserDataTablename: "UserDataTable",
|
||||
APIKeyTablename: "ApiKeyTable",
|
||||
VersionTablename: "Version",
|
||||
Region: "us-east-2",
|
||||
}
|
||||
|
||||
@@ -390,6 +390,7 @@ func TestParseStorageDynamoWrapper(t *testing.T) {
|
||||
ManifestDataTablename: "ManifestDataTable",
|
||||
IndexDataTablename: "IndexDataTable",
|
||||
UserDataTablename: "UserDataTable",
|
||||
APIKeyTablename: "ApiKeyTable",
|
||||
VersionTablename: "Version",
|
||||
}
|
||||
|
||||
|
||||
@@ -127,6 +127,7 @@ func TestVersioningDynamoDB(t *testing.T) {
|
||||
ManifestDataTablename: "ManifestDataTable",
|
||||
IndexDataTablename: "IndexDataTable",
|
||||
UserDataTablename: "UserDataTable",
|
||||
APIKeyTablename: "ApiKeyTable",
|
||||
VersionTablename: "Version",
|
||||
}
|
||||
|
||||
|
||||
@@ -92,3 +92,29 @@ func (acCtx *AccessControlContext) matchesRepo(globPatterns map[string]bool, rep
|
||||
|
||||
return allowed
|
||||
}
|
||||
|
||||
// request-local context key.
|
||||
var amwCtxKey = Key(1) //nolint: gochecknoglobals
|
||||
|
||||
// pointer needed for use in context.WithValue.
|
||||
func GetAuthnMiddlewareCtxKey() *Key {
|
||||
return &amwCtxKey
|
||||
}
|
||||
|
||||
type AuthnMiddlewareContext struct {
|
||||
AuthnType string
|
||||
}
|
||||
|
||||
func GetAuthnMiddlewareContext(ctx context.Context) (*AuthnMiddlewareContext, error) {
|
||||
authnMiddlewareCtxKey := GetAuthnMiddlewareCtxKey()
|
||||
if authnMiddlewareCtx := ctx.Value(authnMiddlewareCtxKey); authnMiddlewareCtx != nil {
|
||||
amCtx, ok := authnMiddlewareCtx.(AuthnMiddlewareContext)
|
||||
if !ok {
|
||||
return nil, errors.ErrBadType
|
||||
}
|
||||
|
||||
return &amCtx, nil
|
||||
}
|
||||
|
||||
return nil, nil //nolint: nilnil
|
||||
}
|
||||
|
||||
@@ -15,6 +15,7 @@ import (
|
||||
"log"
|
||||
"math"
|
||||
"math/big"
|
||||
"net"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"os"
|
||||
@@ -38,6 +39,7 @@ import (
|
||||
ispec "github.com/opencontainers/image-spec/specs-go/v1"
|
||||
"github.com/opencontainers/umoci"
|
||||
"github.com/phayes/freeport"
|
||||
"github.com/project-zot/mockoidc"
|
||||
"github.com/sigstore/cosign/v2/cmd/cosign/cli/generate"
|
||||
"github.com/sigstore/cosign/v2/cmd/cosign/cli/options"
|
||||
"github.com/sigstore/cosign/v2/cmd/cosign/cli/sign"
|
||||
@@ -1967,3 +1969,55 @@ func GetIndexBlobWithManifests(manifestDigests []godigest.Digest) ([]byte, error
|
||||
|
||||
return json.Marshal(indexContent)
|
||||
}
|
||||
|
||||
func MockOIDCRun() (*mockoidc.MockOIDC, error) {
|
||||
// Create a fresh RSA Private Key for token signing
|
||||
rsaKey, _ := rsa.GenerateKey(rand.Reader, 2048) //nolint: gomnd
|
||||
|
||||
// Create an unstarted MockOIDC server
|
||||
mockServer, _ := mockoidc.NewServer(rsaKey)
|
||||
|
||||
// Create the net.Listener, kernel will chose a valid port
|
||||
listener, _ := net.Listen("tcp", "127.0.0.1:0")
|
||||
|
||||
bearerMiddleware := func(next http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(response http.ResponseWriter, req *http.Request) {
|
||||
// stateVal := req.Form.Get("state")
|
||||
header := req.Header.Get("Authorization")
|
||||
parts := strings.SplitN(header, " ", 2) //nolint: gomnd
|
||||
if header != "" {
|
||||
if strings.ToLower(parts[0]) == "bearer" {
|
||||
req.Header.Set("Authorization", strings.Join([]string{"Bearer", parts[1]}, " "))
|
||||
}
|
||||
}
|
||||
|
||||
next.ServeHTTP(response, req)
|
||||
})
|
||||
}
|
||||
|
||||
err := mockServer.AddMiddleware(bearerMiddleware)
|
||||
if err != nil {
|
||||
return mockServer, err
|
||||
}
|
||||
// tlsConfig can be nil if you want HTTP
|
||||
return mockServer, mockServer.Start(listener, nil)
|
||||
}
|
||||
|
||||
func CustomRedirectPolicy(noOfRedirect int) resty.RedirectPolicy {
|
||||
return resty.RedirectPolicyFunc(func(req *http.Request, via []*http.Request) error {
|
||||
if len(via) >= noOfRedirect {
|
||||
return fmt.Errorf("stopped after %d redirects", noOfRedirect) //nolint: goerr113
|
||||
}
|
||||
|
||||
for key, val := range via[len(via)-1].Header {
|
||||
req.Header[key] = val
|
||||
}
|
||||
|
||||
respCookies := req.Response.Cookies()
|
||||
for _, cookie := range respCookies {
|
||||
req.AddCookie(cookie)
|
||||
}
|
||||
|
||||
return nil
|
||||
})
|
||||
}
|
||||
|
||||
@@ -95,6 +95,24 @@ type RepoDBMock struct {
|
||||
|
||||
ToggleBookmarkRepoFn func(ctx context.Context, repo string) (repodb.ToggleState, error)
|
||||
|
||||
GetUserDataFn func(ctx context.Context) (repodb.UserData, error)
|
||||
|
||||
SetUserDataFn func(ctx context.Context, userProfile repodb.UserData) error
|
||||
|
||||
SetUserGroupsFn func(ctx context.Context, groups []string) error
|
||||
|
||||
GetUserGroupsFn func(ctx context.Context) ([]string, error)
|
||||
|
||||
DeleteUserDataFn func(ctx context.Context) error
|
||||
|
||||
GetUserAPIKeyInfoFn func(hashedKey string) (string, error)
|
||||
|
||||
AddUserAPIKeyFn func(ctx context.Context, hashedKey string, apiKeyDetails *repodb.APIKeyDetails) error
|
||||
|
||||
UpdateUserAPIKeyLastUsedFn func(ctx context.Context, hashedKey string) error
|
||||
|
||||
DeleteUserAPIKeyFn func(ctx context.Context, id string) error
|
||||
|
||||
PatchDBFn func() error
|
||||
}
|
||||
|
||||
@@ -414,3 +432,75 @@ func (sdm RepoDBMock) ToggleBookmarkRepo(ctx context.Context, repo string) (repo
|
||||
|
||||
return repodb.NotChanged, nil
|
||||
}
|
||||
|
||||
func (sdm RepoDBMock) GetUserData(ctx context.Context) (repodb.UserData, error) {
|
||||
if sdm.GetUserDataFn != nil {
|
||||
return sdm.GetUserDataFn(ctx)
|
||||
}
|
||||
|
||||
return repodb.UserData{}, nil
|
||||
}
|
||||
|
||||
func (sdm RepoDBMock) SetUserData(ctx context.Context, userProfile repodb.UserData) error {
|
||||
if sdm.SetUserDataFn != nil {
|
||||
return sdm.SetUserDataFn(ctx, userProfile)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (sdm RepoDBMock) SetUserGroups(ctx context.Context, groups []string) error {
|
||||
if sdm.SetUserGroupsFn != nil {
|
||||
return sdm.SetUserGroupsFn(ctx, groups)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (sdm RepoDBMock) GetUserGroups(ctx context.Context) ([]string, error) {
|
||||
if sdm.GetUserGroupsFn != nil {
|
||||
return sdm.GetUserGroupsFn(ctx)
|
||||
}
|
||||
|
||||
return []string{}, nil
|
||||
}
|
||||
|
||||
func (sdm RepoDBMock) DeleteUserData(ctx context.Context) error {
|
||||
if sdm.DeleteUserDataFn != nil {
|
||||
return sdm.DeleteUserDataFn(ctx)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (sdm RepoDBMock) GetUserAPIKeyInfo(hashedKey string) (string, error) {
|
||||
if sdm.GetUserAPIKeyInfoFn != nil {
|
||||
return sdm.GetUserAPIKeyInfoFn(hashedKey)
|
||||
}
|
||||
|
||||
return "", nil
|
||||
}
|
||||
|
||||
func (sdm RepoDBMock) AddUserAPIKey(ctx context.Context, hashedKey string, apiKeyDetails *repodb.APIKeyDetails) error {
|
||||
if sdm.AddUserAPIKeyFn != nil {
|
||||
return sdm.AddUserAPIKeyFn(ctx, hashedKey, apiKeyDetails)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (sdm RepoDBMock) UpdateUserAPIKeyLastUsed(ctx context.Context, hashedKey string) error {
|
||||
if sdm.UpdateUserAPIKeyLastUsedFn != nil {
|
||||
return sdm.UpdateUserAPIKeyLastUsedFn(ctx, hashedKey)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (sdm RepoDBMock) DeleteUserAPIKey(ctx context.Context, id string) error {
|
||||
if sdm.DeleteUserAPIKeyFn != nil {
|
||||
return sdm.DeleteUserAPIKeyFn(ctx, id)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user