feat: integrate openID auth logic and user profile management (#1381)

This change introduces OpenID authn by using providers such as Github,
Gitlab, Google and Dex.
User sessions are now used for web clients to identify
and persist an authenticated users session, thus not requiring every request to
use credentials.
Another change is apikey feature, users can create/revoke their api keys and use them
to authenticate when using cli clients such as skopeo.

eg:
login:
/auth/login?provider=github
/auth/login?provider=gitlab
and so on

logout:
/auth/logout

redirectURL:
/auth/callback/github
/auth/callback/gitlab
and so on

If network policy doesn't allow inbound connections, this callback wont work!

for more info read documentation added in this commit.

Signed-off-by: Alex Stan <alexandrustan96@yahoo.ro>
Signed-off-by: Petu Eusebiu <peusebiu@cisco.com>
Co-authored-by: Alex Stan <alexandrustan96@yahoo.ro>
This commit is contained in:
peusebiu
2023-07-07 19:27:10 +03:00
committed by GitHub
parent 5494a1b8d6
commit 17d1338af1
51 changed files with 5467 additions and 624 deletions
+678 -163
View File
@@ -3,139 +3,315 @@ package api
import (
"bufio"
"context"
"crypto/sha256"
"crypto/x509"
"encoding/base64"
"encoding/gob"
"errors"
"fmt"
"net"
"net/http"
"os"
"path"
"strconv"
"strings"
"time"
"github.com/chartmuseum/auth"
"github.com/google/go-github/v52/github"
"github.com/google/uuid"
"github.com/gorilla/mux"
"github.com/gorilla/securecookie"
"github.com/gorilla/sessions"
godigest "github.com/opencontainers/go-digest"
"github.com/zitadel/oidc/pkg/client/rp"
httphelper "github.com/zitadel/oidc/pkg/http"
"github.com/zitadel/oidc/pkg/oidc"
"golang.org/x/crypto/bcrypt"
"golang.org/x/oauth2"
githubOAuth "golang.org/x/oauth2/github"
"zotregistry.io/zot/errors"
zerr "zotregistry.io/zot/errors"
"zotregistry.io/zot/pkg/api/config"
"zotregistry.io/zot/pkg/api/constants"
apiErr "zotregistry.io/zot/pkg/api/errors"
"zotregistry.io/zot/pkg/common"
"zotregistry.io/zot/pkg/log"
localCtx "zotregistry.io/zot/pkg/requestcontext"
storageConstants "zotregistry.io/zot/pkg/storage/constants"
)
const (
bearerAuthDefaultAccessEntryType = "repository"
issuedAtOffset = 5 * time.Second
relyingPartyCookieMaxAge = 120
)
func AuthHandler(c *Controller) mux.MiddlewareFunc {
if isBearerAuthEnabled(c.Config) {
return bearerAuthHandler(c)
}
return basicAuthHandler(c)
type AuthnMiddleware struct {
credMap map[string]string
ldapClient *LDAPClient
}
func bearerAuthHandler(ctlr *Controller) mux.MiddlewareFunc {
authorizer, err := auth.NewAuthorizer(&auth.AuthorizerOptions{
Realm: ctlr.Config.HTTP.Auth.Bearer.Realm,
Service: ctlr.Config.HTTP.Auth.Bearer.Service,
PublicKeyPath: ctlr.Config.HTTP.Auth.Bearer.Cert,
AccessEntryType: bearerAuthDefaultAccessEntryType,
EmptyDefaultNamespace: true,
})
func AuthHandler(ctlr *Controller) mux.MiddlewareFunc {
authnMiddleware := &AuthnMiddleware{}
if isBearerAuthEnabled(ctlr.Config) {
return bearerAuthHandler(ctlr)
}
return authnMiddleware.TryAuthnHandlers(ctlr)
}
func (amw *AuthnMiddleware) sessionAuthn(ctlr *Controller, next http.Handler, response http.ResponseWriter,
request *http.Request, delay int,
) {
clientHeader := request.Header.Get(constants.SessionClientHeaderName)
if clientHeader != constants.SessionClientHeaderValue {
authFail(response, request, ctlr.Config.HTTP.Realm, delay)
return
}
identity, ok := common.GetAuthUserFromRequestSession(ctlr.CookieStore, request, ctlr.Log)
if !ok {
// let the client know that this session is invalid/expired
cookie := &http.Cookie{
Name: "session",
Value: "",
Path: "/",
Expires: time.Unix(0, 0),
HttpOnly: true,
}
http.SetCookie(response, cookie)
authFail(response, request, ctlr.Config.HTTP.Realm, delay)
return
}
ctx := getReqContextWithAuthorization(identity, []string{}, request)
groups, err := ctlr.RepoDB.GetUserGroups(ctx)
if err != nil {
ctlr.Log.Panic().Err(err).Msg("error creating bearer authorizer")
if errors.Is(err, zerr.ErrUserDataNotFound) {
ctlr.Log.Err(err).Str("identity", identity).Msg("can not find user profile in DB")
authFail(response, request, ctlr.Config.HTTP.Realm, delay)
return
}
ctlr.Log.Err(err).Str("identity", identity).Msg("can not get user profile in DB")
response.WriteHeader(http.StatusInternalServerError)
return
}
return func(next http.Handler) http.Handler {
return http.HandlerFunc(func(response http.ResponseWriter, request *http.Request) {
if request.Method == http.MethodOptions {
next.ServeHTTP(response, request)
response.WriteHeader(http.StatusNoContent)
ctx = getReqContextWithAuthorization(identity, groups, request)
return
}
vars := mux.Vars(request)
name := vars["name"]
// we want to bypass auth for mgmt route
isMgmtRequested := request.RequestURI == constants.FullMgmtPrefix
header := request.Header.Get("Authorization")
if (header == "" || header == "Basic Og==") && isMgmtRequested {
next.ServeHTTP(response, request)
return
}
action := auth.PullAction
if m := request.Method; m != http.MethodGet && m != http.MethodHead {
action = auth.PushAction
}
permissions, err := authorizer.Authorize(header, action, name)
if err != nil {
ctlr.Log.Error().Err(err).Msg("issue parsing Authorization header")
response.Header().Set("Content-Type", "application/json")
common.WriteJSON(response, http.StatusInternalServerError, apiErr.NewErrorList(apiErr.NewError(apiErr.UNSUPPORTED)))
return
}
if !permissions.Allowed {
authFail(response, permissions.WWWAuthenticateHeader, 0)
return
}
next.ServeHTTP(response, request)
})
}
next.ServeHTTP(response, request.WithContext(ctx))
}
func noPasswdAuth(realm string, config *config.Config) mux.MiddlewareFunc {
return func(next http.Handler) http.Handler {
return http.HandlerFunc(func(response http.ResponseWriter, request *http.Request) {
if request.Method == http.MethodOptions {
next.ServeHTTP(response, request)
response.WriteHeader(http.StatusNoContent)
func (amw *AuthnMiddleware) basicAuthn(ctlr *Controller, response http.ResponseWriter,
request *http.Request,
) (bool, http.ResponseWriter, *http.Request, error) {
cookieStore := ctlr.CookieStore
return
}
// we want to bypass auth for mgmt route
isMgmtRequested := request.RequestURI == constants.FullMgmtPrefix
// Process request
if request.Header.Get("Authorization") == "" {
if ctlr.Config.HTTP.AccessControl.AnonymousPolicyExists() || isMgmtRequested {
ctx := getReqContextWithAuthorization("", []string{}, request)
next.ServeHTTP(response, request.WithContext(ctx)) //nolint:contextcheck
})
// Process request
return true, response, request.WithContext(ctx), nil
}
}
identity, passphrase, err := getUsernamePasswordBasicAuth(request)
if err != nil {
ctlr.Log.Error().Err(err).Msg("failed to parse authorization header")
return false, nil, nil, nil
}
// some client tools might send Authorization: Basic Og== (decoded into ":")
// empty username and password
if identity == "" && passphrase == "" {
if ctlr.Config.HTTP.AccessControl.AnonymousPolicyExists() || isMgmtRequested {
ctx := getReqContextWithAuthorization("", []string{}, request)
return true, response, request.WithContext(ctx), nil
}
}
passphraseHash, ok := amw.credMap[identity]
if ok {
// first, HTTPPassword authN (which is local)
if err := bcrypt.CompareHashAndPassword([]byte(passphraseHash), []byte(passphrase)); err == nil {
// Process request
var groups []string
if ctlr.Config.HTTP.AccessControl != nil {
ac := NewAccessController(ctlr.Config)
groups = ac.getUserGroups(identity)
}
ctx := getReqContextWithAuthorization(identity, groups, request)
// saved logged session
if err := saveUserLoggedSession(cookieStore, response, request, identity, ctlr.Log); err != nil {
return false, response, request, err
}
if err := ctlr.RepoDB.SetUserGroups(ctx, groups); err != nil {
ctlr.Log.Error().Err(err).Str("identity", identity).Msg("couldn't update user profile")
return false, response, request, err
}
ctlr.Log.Info().Str("identity", identity).Msgf("user profile successfully set")
return true, response, request.WithContext(ctx), nil
}
}
// next, LDAP if configured (network-based which can lose connectivity)
if ctlr.Config.HTTP.Auth != nil && ctlr.Config.HTTP.Auth.LDAP != nil {
ok, _, ldapgroups, err := amw.ldapClient.Authenticate(identity, passphrase)
if ok && err == nil {
// Process request
var groups []string
if ctlr.Config.HTTP.AccessControl != nil {
ac := NewAccessController(ctlr.Config)
groups = ac.getUserGroups(identity)
}
groups = append(groups, ldapgroups...)
ctx := getReqContextWithAuthorization(identity, groups, request)
if err := saveUserLoggedSession(cookieStore, response, request, identity, ctlr.Log); err != nil {
return false, response, request, err
}
if err := ctlr.RepoDB.SetUserGroups(ctx, groups); err != nil {
ctlr.Log.Error().Err(err).Str("identity", identity).Msg("couldn't update user profile")
return false, response, request, err
}
return true, response, request.WithContext(ctx), nil
}
}
// last try API keys
if isAPIKeyEnabled(ctlr.Config) {
apiKey := passphrase
if !strings.HasPrefix(apiKey, constants.APIKeysPrefix) {
ctlr.Log.Error().Msg("api token has invalid format")
return false, nil, nil, nil
}
trimmedAPIKey := strings.TrimPrefix(apiKey, constants.APIKeysPrefix)
hashedKey := hashUUID(trimmedAPIKey)
storedIdentity, err := ctlr.RepoDB.GetUserAPIKeyInfo(hashedKey)
if err != nil {
if errors.Is(err, zerr.ErrUserAPIKeyNotFound) {
ctlr.Log.Info().Err(err).Msgf("can not find any user info for hashed key %s in DB", hashedKey)
return false, nil, nil, nil
}
ctlr.Log.Error().Err(err).Msgf("can not get user info for hashed key %s in DB", hashedKey)
return false, nil, nil, err
}
if storedIdentity == identity {
ctx := getReqContextWithAuthorization(identity, []string{}, request)
err := ctlr.RepoDB.UpdateUserAPIKeyLastUsed(ctx, hashedKey)
if err != nil {
ctlr.Log.Err(err).Str("identity", identity).Msg("can not update user profile in DB")
return false, nil, nil, err
}
groups, err := ctlr.RepoDB.GetUserGroups(ctx)
if err != nil {
ctlr.Log.Err(err).Str("identity", identity).Msg("can not get user's groups in DB")
return false, nil, nil, err
}
ctx = getReqContextWithAuthorization(identity, groups, request)
return true, response, request.WithContext(ctx), nil
}
}
return false, nil, nil, nil
}
//nolint:gocyclo // we use closure making this a complex subroutine
func basicAuthHandler(ctlr *Controller) mux.MiddlewareFunc {
realm := ctlr.Config.HTTP.Realm
if realm == "" {
realm = "Authorization Required"
}
realm = "Basic realm=" + strconv.Quote(realm)
func (amw *AuthnMiddleware) TryAuthnHandlers(ctlr *Controller) mux.MiddlewareFunc { //nolint: gocyclo
// no password based authN, if neither LDAP nor HTTP BASIC is enabled
if ctlr.Config.HTTP.Auth == nil ||
(ctlr.Config.HTTP.Auth.HTPasswd.Path == "" && ctlr.Config.HTTP.Auth.LDAP == nil) {
return noPasswdAuth(realm, ctlr.Config)
(ctlr.Config.HTTP.Auth.HTPasswd.Path == "" && ctlr.Config.HTTP.Auth.LDAP == nil &&
ctlr.Config.HTTP.Auth.OpenID == nil) {
return noPasswdAuth(ctlr.Config)
}
credMap := make(map[string]string)
amw.credMap = make(map[string]string)
delay := ctlr.Config.HTTP.Auth.FailDelay
var ldapClient *LDAPClient
// setup sessions cookie store used to preserve logged in user in web sessions
if isAuthnEnabled(ctlr.Config) || isOpenIDAuthEnabled(ctlr.Config) {
// To store custom types in our cookies,
// we must first register them using gob.Register
gob.Register(map[string]interface{}{})
cookieStoreHashKey := securecookie.GenerateRandomKey(64)
if cookieStoreHashKey == nil {
panic(zerr.ErrHashKeyNotCreated)
}
// if storage is filesystem then use zot's rootDir to store sessions
if ctlr.Config.Storage.StorageDriver == nil {
sessionsDir := path.Join(ctlr.Config.Storage.RootDirectory, "_sessions")
if err := os.MkdirAll(sessionsDir, storageConstants.DefaultDirPerms); err != nil {
panic(err)
}
cookieStore := sessions.NewFilesystemStore(sessionsDir, cookieStoreHashKey)
cookieStore.MaxAge(cookiesMaxAge)
ctlr.CookieStore = cookieStore
} else {
cookieStore := sessions.NewCookieStore(cookieStoreHashKey)
cookieStore.MaxAge(cookiesMaxAge)
ctlr.CookieStore = cookieStore
}
}
// ldap and htpasswd based authN
if ctlr.Config.HTTP.Auth != nil {
if ctlr.Config.HTTP.Auth.LDAP != nil {
ldapConfig := ctlr.Config.HTTP.Auth.LDAP
ldapClient = &LDAPClient{
amw.ldapClient = &LDAPClient{
Host: ldapConfig.Address,
Port: ldapConfig.Port,
UseSSL: !ldapConfig.Insecure,
@@ -160,18 +336,18 @@ func basicAuthHandler(ctlr *Controller) mux.MiddlewareFunc {
caCertPool := x509.NewCertPool()
if !caCertPool.AppendCertsFromPEM(caCert) {
panic(errors.ErrBadCACert)
panic(zerr.ErrBadCACert)
}
ldapClient.ClientCAs = caCertPool
amw.ldapClient.ClientCAs = caCertPool
} else {
// default to system cert pool
caCertPool, err := x509.SystemCertPool()
if err != nil {
panic(errors.ErrBadCACert)
panic(zerr.ErrBadCACert)
}
ldapClient.ClientCAs = caCertPool
amw.ldapClient.ClientCAs = caCertPool
}
}
@@ -188,12 +364,27 @@ func basicAuthHandler(ctlr *Controller) mux.MiddlewareFunc {
line := scanner.Text()
if strings.Contains(line, ":") {
tokens := strings.Split(scanner.Text(), ":")
credMap[tokens[0]] = tokens[1]
amw.credMap[tokens[0]] = tokens[1]
}
}
}
}
// openid based authN
if ctlr.Config.HTTP.Auth.OpenID != nil {
ctlr.RelyingParties = make(map[string]rp.RelyingParty)
for provider := range ctlr.Config.HTTP.Auth.OpenID.Providers {
if IsOpenIDSupported(provider) {
rp := NewRelyingPartyOIDC(ctlr.Config, provider)
ctlr.RelyingParties[provider] = rp
} else if IsOauth2Supported(provider) {
rp := NewRelyingPartyGithub(ctlr.Config, provider)
ctlr.RelyingParties[provider] = rp
}
}
}
return func(next http.Handler) http.Handler {
return http.HandlerFunc(func(response http.ResponseWriter, request *http.Request) {
if request.Method == http.MethodOptions {
@@ -203,84 +394,231 @@ func basicAuthHandler(ctlr *Controller) mux.MiddlewareFunc {
return
}
// we want to bypass auth for mgmt route
isMgmtRequested := request.RequestURI == constants.FullMgmtPrefix
if request.Header.Get("Authorization") == "" {
if ctlr.Config.HTTP.AccessControl.AnonymousPolicyExists() || isMgmtRequested {
// Process request
ctx := getReqContextWithAuthorization("", []string{}, request)
next.ServeHTTP(response, request.WithContext(ctx)) //nolint:contextcheck
return
}
}
username, passphrase, err := getUsernamePasswordBasicAuth(request)
//nolint: contextcheck
authenticated, cloneResp, cloneReq, err := amw.basicAuthn(ctlr, response, request)
if err != nil {
ctlr.Log.Error().Err(err).Msg("failed to parse authorization header")
authFail(response, realm, delay)
response.WriteHeader(http.StatusInternalServerError)
return
}
// some client tools might send Authorization: Basic Og== (decoded into ":")
// empty username and password
if username == "" && passphrase == "" {
if ctlr.Config.HTTP.AccessControl.AnonymousPolicyExists() || isMgmtRequested {
// Process request
ctx := getReqContextWithAuthorization("", []string{}, request)
next.ServeHTTP(response, request.WithContext(ctx)) //nolint:contextcheck
if authenticated && cloneResp != nil && cloneReq != nil {
next.ServeHTTP(cloneResp, cloneReq)
return
}
return
}
// first, HTTPPassword authN (which is local)
passphraseHash, ok := credMap[username]
if ok {
if err := bcrypt.CompareHashAndPassword([]byte(passphraseHash), []byte(passphrase)); err == nil {
// Process request
var userGroups []string
if ctlr.Config.HTTP.AccessControl != nil {
ac := NewAccessController(ctlr.Config)
userGroups = ac.getUserGroups(username)
}
ctx := getReqContextWithAuthorization(username, userGroups, request)
next.ServeHTTP(response, request.WithContext(ctx)) //nolint:contextcheck
return
}
}
// next, LDAP if configured (network-based which can lose connectivity)
if ctlr.Config.HTTP.Auth != nil && ctlr.Config.HTTP.Auth.LDAP != nil {
ok, _, ldapgroups, err := ldapClient.Authenticate(username, passphrase)
if ok && err == nil {
// Process request
var userGroups []string
if ctlr.Config.HTTP.AccessControl != nil {
ac := NewAccessController(ctlr.Config)
userGroups = ac.getUserGroups(username)
}
userGroups = append(userGroups, ldapgroups...)
ctx := getReqContextWithAuthorization(username, userGroups, request)
next.ServeHTTP(response, request.WithContext(ctx)) //nolint:contextcheck
return
}
}
authFail(response, realm, delay)
//nolint: contextcheck
amw.sessionAuthn(ctlr, next, response, request, delay)
})
}
}
func bearerAuthHandler(ctlr *Controller) mux.MiddlewareFunc {
authorizer, err := auth.NewAuthorizer(&auth.AuthorizerOptions{
Realm: ctlr.Config.HTTP.Auth.Bearer.Realm,
Service: ctlr.Config.HTTP.Auth.Bearer.Service,
PublicKeyPath: ctlr.Config.HTTP.Auth.Bearer.Cert,
AccessEntryType: bearerAuthDefaultAccessEntryType,
EmptyDefaultNamespace: true,
})
if err != nil {
ctlr.Log.Panic().Err(err).Msg("error creating bearer authorizer")
}
return func(next http.Handler) http.Handler {
return http.HandlerFunc(func(response http.ResponseWriter, request *http.Request) {
if request.Method == http.MethodOptions {
next.ServeHTTP(response, request)
response.WriteHeader(http.StatusNoContent)
return
}
acCtrlr := NewAccessController(ctlr.Config)
vars := mux.Vars(request)
name := vars["name"]
// we want to bypass auth for mgmt route
isMgmtRequested := request.RequestURI == constants.FullMgmtPrefix
header := request.Header.Get("Authorization")
if (header == "" || header == "Basic Og==") && isMgmtRequested {
next.ServeHTTP(response, request)
return
}
action := auth.PullAction
if m := request.Method; m != http.MethodGet && m != http.MethodHead {
action = auth.PushAction
}
permissions, err := authorizer.Authorize(header, action, name)
if err != nil {
ctlr.Log.Error().Err(err).Msg("issue parsing Authorization header")
response.Header().Set("Content-Type", "application/json")
common.WriteJSON(response, http.StatusInternalServerError, apiErr.NewErrorList(apiErr.NewError(apiErr.UNSUPPORTED)))
return
}
if !permissions.Allowed {
response.Header().Set("Content-Type", "application/json")
response.Header().Set("WWW-Authenticate", permissions.WWWAuthenticateHeader)
common.WriteJSON(response, http.StatusUnauthorized,
apiErr.NewErrorList(apiErr.NewError(apiErr.UNAUTHORIZED)))
return
}
amCtx := acCtrlr.getAuthnMiddlewareContext(BEARER, request)
next.ServeHTTP(response, request.WithContext(amCtx)) //nolint:contextcheck
})
}
}
func noPasswdAuth(config *config.Config) mux.MiddlewareFunc {
return func(next http.Handler) http.Handler {
return http.HandlerFunc(func(response http.ResponseWriter, request *http.Request) {
if request.Method == http.MethodOptions {
next.ServeHTTP(response, request)
response.WriteHeader(http.StatusNoContent)
return
}
ctx := getReqContextWithAuthorization("", []string{}, request)
// Process request
next.ServeHTTP(response, request.WithContext(ctx)) //nolint:contextcheck
})
}
}
func (rh *RouteHandler) AuthURLHandler() http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) {
query := r.URL.Query()
callbackUI := query.Get(constants.CallbackUIQueryParam)
provider := query.Get("provider")
client, ok := rh.c.RelyingParties[provider]
if !ok {
http.HandlerFunc(func(response http.ResponseWriter, request *http.Request) {
response.WriteHeader(http.StatusBadRequest)
})(w, r)
}
/* save cookie containing state to later verify it and
callback ui where we will redirect after openid/oauth2 logic is completed*/
session, _ := rh.c.CookieStore.Get(r, "statecookie")
session.Options.Secure = true
session.Options.HttpOnly = true
session.Options.SameSite = http.SameSiteDefaultMode
session.Options.Path = constants.CallbackBasePath
state := uuid.New().String()
session.Values["state"] = state
session.Values["callback"] = callbackUI
// let the session set its own id
err := session.Save(r, w)
if err != nil {
rh.c.Log.Error().Err(err).Msg("unable to save http session")
w.WriteHeader(http.StatusInternalServerError)
return
}
stateFunc := func() string {
return state
}
rp.AuthURLHandler(stateFunc, client)(w, r)
}
}
func NewRelyingPartyOIDC(config *config.Config, provider string) rp.RelyingParty {
issuer, clientID, clientSecret, redirectURI, scopes, options := getRelyingPartyArgs(config, provider)
relyingParty, err := rp.NewRelyingPartyOIDC(issuer, clientID, clientSecret, redirectURI, scopes, options...)
if err != nil {
panic(err)
}
return relyingParty
}
func NewRelyingPartyGithub(config *config.Config, provider string) rp.RelyingParty {
_, clientID, clientSecret, redirectURI, scopes, options := getRelyingPartyArgs(config, provider)
rpConfig := &oauth2.Config{
ClientID: clientID,
ClientSecret: clientSecret,
RedirectURL: redirectURI,
Scopes: scopes,
Endpoint: githubOAuth.Endpoint,
}
relyingParty, err := rp.NewRelyingPartyOAuth(rpConfig, options...)
if err != nil {
panic(err)
}
return relyingParty
}
func getRelyingPartyArgs(config *config.Config, provider string) (
string, string, string, string, []string, []rp.Option,
) {
if _, ok := config.HTTP.Auth.OpenID.Providers[provider]; !ok {
panic(zerr.ErrOpenIDProviderDoesNotExist)
}
scheme := "http"
if config.HTTP.TLS != nil {
scheme = "https"
}
clientID := config.HTTP.Auth.OpenID.Providers[provider].ClientID
clientSecret := config.HTTP.Auth.OpenID.Providers[provider].ClientSecret
scopes := config.HTTP.Auth.OpenID.Providers[provider].Scopes
// openid scope must be the first one in list
if !common.Contains(scopes, oidc.ScopeOpenID) && IsOpenIDSupported(provider) {
scopes = append([]string{oidc.ScopeOpenID}, scopes...)
}
port := config.HTTP.Port
issuer := config.HTTP.Auth.OpenID.Providers[provider].Issuer
keyPath := config.HTTP.Auth.OpenID.Providers[provider].KeyPath
baseURL := net.JoinHostPort(config.HTTP.Address, port)
redirectURI := fmt.Sprintf("%s://%s%s", scheme, baseURL, constants.CallbackBasePath+fmt.Sprintf("/%s", provider))
options := []rp.Option{
rp.WithVerifierOpts(rp.WithIssuedAtOffset(issuedAtOffset)),
}
key := securecookie.GenerateRandomKey(32) //nolint: gomnd
cookieHandler := httphelper.NewCookieHandler(key, key, httphelper.WithMaxAge(relyingPartyCookieMaxAge))
options = append(options, rp.WithCookieHandler(cookieHandler))
if clientSecret == "" {
options = append(options, rp.WithPKCE(cookieHandler))
}
if keyPath != "" {
options = append(options, rp.WithJWTProfile(rp.SignerFromKeyPath(keyPath)))
}
return issuer, clientID, clientSecret, redirectURI, scopes, options
}
func getReqContextWithAuthorization(username string, groups []string, request *http.Request) context.Context {
acCtx := localCtx.AccessControlContext{
Username: username,
@@ -314,9 +652,71 @@ func isBearerAuthEnabled(config *config.Config) bool {
return false
}
func authFail(w http.ResponseWriter, realm string, delay int) {
func isOpenIDAuthEnabled(config *config.Config) bool {
if config.HTTP.Auth != nil &&
config.HTTP.Auth.OpenID != nil {
for provider := range config.HTTP.Auth.OpenID.Providers {
if isOpenIDAuthProviderEnabled(config, provider) {
return true
}
}
}
return false
}
func isAPIKeyEnabled(config *config.Config) bool {
if config.Extensions != nil && config.Extensions.APIKey != nil &&
*config.Extensions.APIKey.Enable {
return true
}
return false
}
func isOpenIDAuthProviderEnabled(config *config.Config, provider string) bool {
if providerConfig, ok := config.HTTP.Auth.OpenID.Providers[provider]; ok {
if IsOpenIDSupported(provider) {
if providerConfig.ClientID != "" || providerConfig.Issuer != "" ||
len(providerConfig.Scopes) > 0 {
return true
}
} else if IsOauth2Supported(provider) {
if providerConfig.ClientID != "" || len(providerConfig.Scopes) > 0 {
return true
}
}
}
return false
}
func IsOpenIDSupported(provider string) bool {
supported := []string{"google", "gitlab", "dex"}
return common.Contains(supported, provider)
}
func IsOauth2Supported(provider string) bool {
supported := []string{"github"}
return common.Contains(supported, provider)
}
func authFail(w http.ResponseWriter, r *http.Request, realm string, delay int) {
time.Sleep(time.Duration(delay) * time.Second)
w.Header().Set("WWW-Authenticate", realm)
// don't send auth headers if request is coming from UI
if r.Header.Get(constants.SessionClientHeaderName) != constants.SessionClientHeaderValue {
if realm == "" {
realm = "Authorization Required"
}
realm = "Basic realm=" + strconv.Quote(realm)
w.Header().Set("WWW-Authenticate", realm)
}
w.Header().Set("Content-Type", "application/json")
common.WriteJSON(w, http.StatusUnauthorized, apiErr.NewErrorList(apiErr.NewError(apiErr.UNAUTHORIZED)))
}
@@ -325,12 +725,12 @@ func getUsernamePasswordBasicAuth(request *http.Request) (string, string, error)
basicAuth := request.Header.Get("Authorization")
if basicAuth == "" {
return "", "", errors.ErrParsingAuthHeader
return "", "", zerr.ErrParsingAuthHeader
}
splitStr := strings.SplitN(basicAuth, " ", 2) //nolint:gomnd
splitStr := strings.SplitN(basicAuth, " ", 2) //nolint: gomnd
if len(splitStr) != 2 || strings.ToLower(splitStr[0]) != "basic" {
return "", "", errors.ErrParsingAuthHeader
return "", "", zerr.ErrParsingAuthHeader
}
decodedStr, err := base64.StdEncoding.DecodeString(splitStr[1])
@@ -338,9 +738,9 @@ func getUsernamePasswordBasicAuth(request *http.Request) (string, string, error)
return "", "", err
}
pair := strings.SplitN(string(decodedStr), ":", 2) //nolint:gomnd
if len(pair) != 2 { //nolint:gomnd
return "", "", errors.ErrParsingAuthHeader
pair := strings.SplitN(string(decodedStr), ":", 2) //nolint: gomnd
if len(pair) != 2 { //nolint: gomnd
return "", "", zerr.ErrParsingAuthHeader
}
username := pair[0]
@@ -348,3 +748,118 @@ func getUsernamePasswordBasicAuth(request *http.Request) (string, string, error)
return username, passphrase, nil
}
func GetGithubUserInfo(ctx context.Context, client *github.Client, log log.Logger) (string, []string, error) {
var primaryEmail string
userEmails, _, err := client.Users.ListEmails(ctx, nil)
if err != nil {
log.Error().Msg("couldn't set user record for empty email value")
return "", []string{}, err
}
if len(userEmails) != 0 {
for _, email := range userEmails { // should have at least one primary email, if any
if email.GetPrimary() { // check if it's primary email
primaryEmail = email.GetEmail()
break
}
}
}
orgs, _, err := client.Organizations.List(ctx, "", nil)
if err != nil {
log.Error().Msg("couldn't set user record for empty email value")
return "", []string{}, err
}
groups := []string{}
for _, org := range orgs {
groups = append(groups, *org.Login)
}
return primaryEmail, groups, nil
}
func saveUserLoggedSession(cookieStore sessions.Store, response http.ResponseWriter,
request *http.Request, identity string, log log.Logger,
) error {
session, _ := cookieStore.Get(request, "session")
session.Options.Secure = true
session.Options.HttpOnly = true
session.Options.SameSite = http.SameSiteDefaultMode
session.Values["authStatus"] = true
session.Values["user"] = identity
// let the session set its own id
err := session.Save(request, response)
if err != nil {
log.Error().Err(err).Str("identity", identity).Msg("unable to save http session")
return err
}
userInfoCookie := sessions.NewCookie("user", identity, &sessions.Options{
Secure: true,
HttpOnly: false,
MaxAge: cookiesMaxAge,
SameSite: http.SameSiteDefaultMode,
Path: "/",
})
http.SetCookie(response, userInfoCookie)
return nil
}
// OAuth2Callback is the callback logic where openid/oauth2 will redirect back to our app.
func OAuth2Callback(ctlr *Controller, w http.ResponseWriter, r *http.Request, state, email string,
groups []string,
) (string, error) {
stateCookie, _ := ctlr.CookieStore.Get(r, "statecookie")
stateOrigin, ok := stateCookie.Values["state"].(string)
if !ok {
ctlr.Log.Error().Err(zerr.ErrInvalidStateCookie).Msg("openID: unable to get 'state' cookie from request")
return "", zerr.ErrInvalidStateCookie
}
if stateOrigin != state {
ctlr.Log.Error().Err(zerr.ErrInvalidStateCookie).Msg("openID: 'state' cookie differs from the actual one")
return "", zerr.ErrInvalidStateCookie
}
ctx := getReqContextWithAuthorization(email, groups, r)
// if this line has been reached, then a new session should be created
// if the `session` key is already on the cookie, it's not a valid one
if err := saveUserLoggedSession(ctlr.CookieStore, w, r, email, ctlr.Log); err != nil {
return "", err
}
if err := ctlr.RepoDB.SetUserGroups(ctx, groups); err != nil {
ctlr.Log.Error().Err(err).Str("identity", email).Msg("couldn't update the user profile")
return "", err
}
ctlr.Log.Info().Msgf("user profile set successfully for email %s", email)
// redirect to UI
callbackUI, _ := stateCookie.Values["callback"].(string)
return callbackUI, nil
}
func hashUUID(uuid string) string {
digester := sha256.New()
digester.Write([]byte(uuid))
return godigest.NewDigestFromEncoded(godigest.SHA256, fmt.Sprintf("%x", digester.Sum(nil))).Encoded()
}
+45 -11
View File
@@ -21,6 +21,9 @@ const (
Delete = "delete"
// behaviour actions.
DetectManifestCollision = "detectManifestCollision"
BASIC = "Basic"
BEARER = "Bearer"
OPENID = "OpenID"
)
// AccessController authorizes users to act on resources.
@@ -29,10 +32,17 @@ type AccessController struct {
Log log.Logger
}
func NewAccessController(config *config.Config) *AccessController {
func NewAccessController(conf *config.Config) *AccessController {
if conf.HTTP.AccessControl == nil {
return &AccessController{
Config: &config.AccessControlConfig{},
Log: log.NewLogger(conf.Log.Level, conf.Log.Output),
}
}
return &AccessController{
Config: config.HTTP.AccessControl,
Log: log.NewLogger(config.Log.Level, config.Log.Output),
Config: conf.HTTP.AccessControl,
Log: log.NewLogger(conf.Log.Level, conf.Log.Output),
}
}
@@ -171,6 +181,18 @@ func (ac *AccessController) getContext(acCtx *localCtx.AccessControlContext, req
return ctx
}
// getAuthnMiddlewareContext builds ac context(allowed to read repos and if user is admin) and returns it.
func (ac *AccessController) getAuthnMiddlewareContext(authnType string, request *http.Request) context.Context {
amwCtx := localCtx.AuthnMiddlewareContext{
AuthnType: authnType,
}
amwCtxKey := localCtx.GetAuthnMiddlewareCtxKey()
ctx := context.WithValue(request.Context(), amwCtxKey, amwCtx)
return ctx
}
// isPermitted returns true if username can do action on a repository policy.
func (ac *AccessController) isPermitted(userGroups []string, username, action string,
policyGroup config.PolicyGroup,
@@ -231,6 +253,14 @@ func BaseAuthzHandler(ctlr *Controller) mux.MiddlewareFunc {
return
}
// request comes from bearer authn, bypass it
authnMwCtx, err := localCtx.GetAuthnMiddlewareContext(request.Context())
if err != nil || (authnMwCtx != nil && authnMwCtx.AuthnType == BEARER) {
next.ServeHTTP(response, request)
return
}
// bypass authz for /v2/ route
if request.RequestURI == "/v2/" {
next.ServeHTTP(response, request)
@@ -242,8 +272,6 @@ func BaseAuthzHandler(ctlr *Controller) mux.MiddlewareFunc {
var identity string
var err error
// anonymous context
acCtx := &localCtx.AccessControlContext{}
@@ -252,7 +280,7 @@ func BaseAuthzHandler(ctlr *Controller) mux.MiddlewareFunc {
// get access control context made in authn.go if authn is enabled
acCtx, err = localCtx.GetAccessControlContext(request.Context())
if err != nil { // should never happen
authFail(response, ctlr.Config.HTTP.Realm, ctlr.Config.HTTP.Auth.FailDelay)
authFail(response, request, ctlr.Config.HTTP.Realm, ctlr.Config.HTTP.Auth.FailDelay)
return
}
@@ -272,7 +300,7 @@ func BaseAuthzHandler(ctlr *Controller) mux.MiddlewareFunc {
// if we still don't have an identity
if identity == "" {
acCtrlr.Log.Info().Msg("couldn't get identity from TLS certificate")
authFail(response, ctlr.Config.HTTP.Realm, ctlr.Config.HTTP.Auth.FailDelay)
authFail(response, request, ctlr.Config.HTTP.Realm, ctlr.Config.HTTP.Auth.FailDelay)
return
}
@@ -298,6 +326,14 @@ func DistSpecAuthzHandler(ctlr *Controller) mux.MiddlewareFunc {
return
}
// request comes from bearer authn, bypass it
authnMwCtx, err := localCtx.GetAuthnMiddlewareContext(request.Context())
if err != nil || (authnMwCtx != nil && authnMwCtx.AuthnType == BEARER) {
next.ServeHTTP(response, request)
return
}
vars := mux.Vars(request)
resource := vars["name"]
reference, ok := vars["reference"]
@@ -306,12 +342,10 @@ func DistSpecAuthzHandler(ctlr *Controller) mux.MiddlewareFunc {
var identity string
var err error
// get acCtx built in authn and previous authz middlewares
acCtx, err := localCtx.GetAccessControlContext(request.Context())
if err != nil { // should never happen
authFail(response, ctlr.Config.HTTP.Realm, ctlr.Config.HTTP.Auth.FailDelay)
authFail(response, request, ctlr.Config.HTTP.Realm, ctlr.Config.HTTP.Auth.FailDelay)
return
}
@@ -344,7 +378,7 @@ func DistSpecAuthzHandler(ctlr *Controller) mux.MiddlewareFunc {
can := acCtrlr.can(request.Context(), identity, action, resource) //nolint:contextcheck
if !can {
common.AuthzFail(response, ctlr.Config.HTTP.Realm, ctlr.Config.HTTP.Auth.FailDelay)
common.AuthzFail(response, request, ctlr.Config.HTTP.Realm, ctlr.Config.HTTP.Auth.FailDelay)
} else {
next.ServeHTTP(response, request) //nolint:contextcheck
}
+14
View File
@@ -45,6 +45,7 @@ type AuthConfig struct {
HTPasswd AuthHTPasswd
LDAP *LDAPConfig
Bearer *BearerConfig
OpenID *OpenIDConfig
}
type BearerConfig struct {
@@ -53,6 +54,18 @@ type BearerConfig struct {
Cert string
}
type OpenIDConfig struct {
Providers map[string]OpenIDProviderConfig
}
type OpenIDProviderConfig struct {
ClientID string
ClientSecret string
KeyPath string
Issuer string
Scopes []string
}
type MethodRatelimitConfig struct {
Method string
Rate int
@@ -63,6 +76,7 @@ type RatelimitConfig struct {
Methods []MethodRatelimitConfig `mapstructure:",omitempty"`
}
//nolint:maligned
type HTTPConfig struct {
Address string
Port string
+7
View File
@@ -12,4 +12,11 @@ const (
DefaultMediaType = "application/json"
BinaryMediaType = "application/octet-stream"
DefaultMetricsExtensionRoute = "/metrics"
CallbackBasePath = "/auth/callback"
LoginPath = "/auth/login"
LogoutPath = "/auth/logout"
SessionClientHeaderName = "X-ZOT-API-CLIENT"
SessionClientHeaderValue = "zot-ui"
APIKeysPrefix = "zak_"
CallbackUIQueryParam = "callback_ui"
)
+3
View File
@@ -18,4 +18,7 @@ const (
ExtUserPreferences = "/userprefs"
ExtUserPreferencesPrefix = ExtPrefix + ExtUserPreferences
FullUserPreferencesPrefix = RoutePrefix + ExtUserPreferencesPrefix
ExtAPIKey = "/apikey"
ExtAPIKeyPrefix = ExtPrefix + ExtAPIKey //nolint: gosec
FullAPIKeyPrefix = RoutePrefix + ExtAPIKeyPrefix
)
+8 -1
View File
@@ -16,6 +16,8 @@ import (
"github.com/gorilla/handlers"
"github.com/gorilla/mux"
"github.com/gorilla/sessions"
"github.com/zitadel/oidc/pkg/client/rp"
"zotregistry.io/zot/errors"
"zotregistry.io/zot/pkg/api/config"
@@ -31,6 +33,7 @@ import (
const (
idleTimeout = 120 * time.Second
readHeaderTimeout = 5 * time.Second
cookiesMaxAge = 86400 // seconds
)
type Controller struct {
@@ -44,6 +47,8 @@ type Controller struct {
Metrics monitoring.MetricServer
CveInfo ext.CveInfo
SyncOnDemand SyncOnDemand
RelyingParties map[string]rp.RelyingParty
CookieStore sessions.Store
// runtime params
chosenPort int // kernel-chosen port
}
@@ -254,7 +259,9 @@ func (c *Controller) InitImageStore() error {
}
func (c *Controller) InitRepoDB(reloadCtx context.Context) error {
if c.Config.Extensions != nil && c.Config.Extensions.Search != nil && *c.Config.Extensions.Search.Enable {
// init repoDB if search is enabled or authn enabled (need to store user profiles) or apikey ext is enabled
if (c.Config.Extensions != nil && c.Config.Extensions.Search != nil && *c.Config.Extensions.Search.Enable) ||
isAuthnEnabled(c.Config) || isOpenIDAuthEnabled(c.Config) || isAPIKeyEnabled(c.Config) {
driver, err := repodbfactory.New(c.Config.Storage.StorageConfig, c.Log) //nolint:contextcheck
if err != nil {
return err
+1213 -26
View File
File diff suppressed because it is too large Load Diff
+164 -18
View File
@@ -20,11 +20,14 @@ import (
"strconv"
"strings"
"github.com/google/go-github/v52/github"
"github.com/gorilla/mux"
"github.com/opencontainers/distribution-spec/specs-go/v1/extensions"
godigest "github.com/opencontainers/go-digest"
ispec "github.com/opencontainers/image-spec/specs-go/v1"
artifactspec "github.com/oras-project/artifacts-spec/specs-go/v1"
"github.com/zitadel/oidc/pkg/client/rp"
"github.com/zitadel/oidc/pkg/oidc"
zerr "zotregistry.io/zot/errors"
"zotregistry.io/zot/pkg/api/constants"
@@ -55,13 +58,38 @@ func NewRouteHandler(c *Controller) *RouteHandler {
}
func (rh *RouteHandler) SetupRoutes() {
// first get Auth middleware in order to first setup openid/ldap/htpasswd, before oidc provider routes are setup
authHandler := AuthHandler(rh.c)
applyCORSHeaders := getCORSHeadersHandler(rh.c.Config.HTTP.AllowOrigin)
if isOpenIDAuthEnabled(rh.c.Config) {
// login path for openID
rh.c.Router.HandleFunc(constants.LoginPath, rh.AuthURLHandler())
// logout path for openID
rh.c.Router.HandleFunc(constants.LogoutPath, applyCORSHeaders(rh.Logout)).
Methods(zcommon.AllowedMethods("POST")...)
// callback path for openID
for provider, relyingParty := range rh.c.RelyingParties {
if IsOauth2Supported(provider) {
rh.c.Router.HandleFunc(constants.CallbackBasePath+fmt.Sprintf("/%s", provider),
rp.CodeExchangeHandler(rh.GithubCodeExchangeCallback(), relyingParty))
} else if IsOpenIDSupported(provider) {
rh.c.Router.HandleFunc(constants.CallbackBasePath+fmt.Sprintf("/%s", provider),
rp.CodeExchangeHandler(rp.UserinfoCallback(rh.OpenIDCodeExchangeCallback()), relyingParty))
}
}
}
prefixedRouter := rh.c.Router.PathPrefix(constants.RoutePrefix).Subrouter()
prefixedRouter.Use(AuthHandler(rh.c))
prefixedRouter.Use(authHandler)
prefixedDistSpecRouter := prefixedRouter.NewRoute().Subrouter()
// authz is being enabled if AccessControl is specified
// if Authn is not present AccessControl will have only default policies
if rh.c.Config.HTTP.AccessControl != nil && !isBearerAuthEnabled(rh.c.Config) {
if rh.c.Config.HTTP.AccessControl != nil {
if isAuthnEnabled(rh.c.Config) {
rh.c.Log.Info().Msg("access control is being enabled")
} else {
@@ -72,8 +100,6 @@ func (rh *RouteHandler) SetupRoutes() {
prefixedDistSpecRouter.Use(DistSpecAuthzHandler(rh.c))
}
applyCORSHeaders := getCORSHeadersHandler(rh.c.Config.HTTP.AllowOrigin)
// https://github.com/opencontainers/distribution-spec/blob/main/spec.md#endpoints
{
prefixedDistSpecRouter.HandleFunc(fmt.Sprintf("/{name:%s}/tags/list", zreg.NameRegexp.String()),
@@ -118,7 +144,7 @@ func (rh *RouteHandler) SetupRoutes() {
constants.ArtifactSpecRoutePrefix, zreg.NameRegexp.String()), rh.GetOrasReferrers).Methods("GET")
// swagger
debug.SetupSwaggerRoutes(rh.c.Config, rh.c.Router, AuthHandler(rh.c), rh.c.Log)
debug.SetupSwaggerRoutes(rh.c.Config, rh.c.Router, authHandler, rh.c.Log)
// Setup Extensions Routes
if rh.c.Config != nil {
@@ -135,8 +161,8 @@ func (rh *RouteHandler) SetupRoutes() {
rh.c.Log)
ext.SetupUserPreferencesRoutes(rh.c.Config, prefixedExtensionsRouter, rh.c.StoreController, rh.c.RepoDB,
rh.c.CveInfo, rh.c.Log)
ext.SetupMetricsRoutes(rh.c.Config, rh.c.Router, rh.c.StoreController, AuthHandler(rh.c), rh.c.Log)
ext.SetupAPIKeyRoutes(rh.c.Config, prefixedExtensionsRouter, rh.c.RepoDB, rh.c.CookieStore, rh.c.Log)
ext.SetupMetricsRoutes(rh.c.Config, rh.c.Router, rh.c.StoreController, authHandler, rh.c.Log)
gqlPlayground.SetupGQLPlaygroundRoutes(rh.c.Config, prefixedRouter, rh.c.StoreController, rh.c.Log)
@@ -185,7 +211,8 @@ func addCORSHeaders(allowOrigin string, response http.ResponseWriter) {
// @Success 200 {string} string "ok".
func (rh *RouteHandler) CheckVersionSupport(response http.ResponseWriter, request *http.Request) {
response.Header().Set("Access-Control-Allow-Methods", "HEAD,GET,POST,OPTIONS")
response.Header().Set("Access-Control-Allow-Headers", "Authorization,content-type")
response.Header().Set("Access-Control-Allow-Headers", "Authorization,content-type,"+constants.SessionClientHeaderName)
response.Header().Set("Access-Control-Allow-Credentials", "true")
if request.Method == http.MethodOptions {
return
@@ -195,10 +222,13 @@ func (rh *RouteHandler) CheckVersionSupport(response http.ResponseWriter, reques
// NOTE: compatibility workaround - return this header in "allowed-read" mode to allow for clients to
// work correctly
if rh.c.Config.HTTP.Auth != nil {
if rh.c.Config.HTTP.Auth.Bearer != nil {
response.Header().Set("WWW-Authenticate", fmt.Sprintf("bearer realm=%s", rh.c.Config.HTTP.Auth.Bearer.Realm))
} else {
response.Header().Set("WWW-Authenticate", fmt.Sprintf("basic realm=%s", rh.c.Config.HTTP.Realm))
// don't send auth headers if request is coming from UI
if request.Header.Get(constants.SessionClientHeaderName) != constants.SessionClientHeaderValue {
if rh.c.Config.HTTP.Auth.Bearer != nil {
response.Header().Set("WWW-Authenticate", fmt.Sprintf("bearer realm=%s", rh.c.Config.HTTP.Auth.Bearer.Realm))
} else {
response.Header().Set("WWW-Authenticate", fmt.Sprintf("basic realm=%s", rh.c.Config.HTTP.Realm))
}
}
}
@@ -224,7 +254,8 @@ type ImageTags struct {
// @Failure 400 {string} string "bad request".
func (rh *RouteHandler) ListTags(response http.ResponseWriter, request *http.Request) {
response.Header().Set("Access-Control-Allow-Methods", "HEAD,GET,POST,OPTIONS")
response.Header().Set("Access-Control-Allow-Headers", "Authorization,content-type")
response.Header().Set("Access-Control-Allow-Headers", "Authorization,content-type,"+constants.SessionClientHeaderName)
response.Header().Set("Access-Control-Allow-Credentials", "true")
if request.Method == http.MethodOptions {
return
@@ -355,7 +386,8 @@ func (rh *RouteHandler) ListTags(response http.ResponseWriter, request *http.Req
// @Failure 500 {string} string "internal server error".
func (rh *RouteHandler) CheckManifest(response http.ResponseWriter, request *http.Request) {
response.Header().Set("Access-Control-Allow-Methods", "HEAD,GET,POST,OPTIONS")
response.Header().Set("Access-Control-Allow-Headers", "Authorization,content-type")
response.Header().Set("Access-Control-Allow-Headers", "Authorization,content-type,"+constants.SessionClientHeaderName)
response.Header().Set("Access-Control-Allow-Credentials", "true")
if request.Method == http.MethodOptions {
return
@@ -427,7 +459,8 @@ type ExtensionList struct {
// @Router /v2/{name}/manifests/{reference} [get].
func (rh *RouteHandler) GetManifest(response http.ResponseWriter, request *http.Request) {
response.Header().Set("Access-Control-Allow-Methods", "HEAD,GET,POST,OPTIONS")
response.Header().Set("Access-Control-Allow-Headers", "Authorization,content-type")
response.Header().Set("Access-Control-Allow-Headers", "Authorization,content-type,"+constants.SessionClientHeaderName)
response.Header().Set("Access-Control-Allow-Credentials", "true")
if request.Method == http.MethodOptions {
return
@@ -527,7 +560,8 @@ func getReferrers(routeHandler *RouteHandler,
// @Router /v2/{name}/referrers/{digest} [get].
func (rh *RouteHandler) GetReferrers(response http.ResponseWriter, request *http.Request) {
response.Header().Set("Access-Control-Allow-Methods", "HEAD,GET,POST,OPTIONS")
response.Header().Set("Access-Control-Allow-Headers", "Authorization,content-type")
response.Header().Set("Access-Control-Allow-Headers", "Authorization,content-type,"+constants.SessionClientHeaderName)
response.Header().Set("Access-Control-Allow-Credentials", "true")
if request.Method == http.MethodOptions {
return
@@ -1576,7 +1610,8 @@ type RepositoryList struct {
// @Router /v2/_catalog [get].
func (rh *RouteHandler) ListRepositories(response http.ResponseWriter, request *http.Request) {
response.Header().Set("Access-Control-Allow-Methods", "HEAD,GET,POST,OPTIONS")
response.Header().Set("Access-Control-Allow-Headers", "Authorization,content-type")
response.Header().Set("Access-Control-Allow-Headers", "Authorization,content-type,"+constants.SessionClientHeaderName)
response.Header().Set("Access-Control-Allow-Credentials", "true")
if request.Method == http.MethodOptions {
return
@@ -1642,7 +1677,8 @@ func (rh *RouteHandler) ListRepositories(response http.ResponseWriter, request *
// @Router /v2/_oci/ext/discover [get].
func (rh *RouteHandler) ListExtensions(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Access-Control-Allow-Methods", "HEAD,GET,POST,OPTIONS")
w.Header().Set("Access-Control-Allow-Headers", "Authorization,content-type")
w.Header().Set("Access-Control-Allow-Headers", "Authorization,content-type,"+constants.SessionClientHeaderName)
w.Header().Set("Access-Control-Allow-Credentials", "true")
if r.Method == http.MethodOptions {
return
@@ -1653,6 +1689,116 @@ func (rh *RouteHandler) ListExtensions(w http.ResponseWriter, r *http.Request) {
zcommon.WriteJSON(w, http.StatusOK, extensionList)
}
// The following routes are specific to zot and NOT part of the OCI dist-spec
// Logout godoc
// @Summary Logout by removing current session
// @Description Logout by removing current session
// @Router /openid/auth/logout [post]
// @Accept json
// @Produce json
// @Success 200 {string} string "ok".
// @Failure 500 {string} string "internal server error".
func (rh *RouteHandler) Logout(response http.ResponseWriter, request *http.Request) {
response.Header().Set("Access-Control-Allow-Methods", "HEAD,GET,POST,OPTIONS")
response.Header().Set("Access-Control-Allow-Headers", "Authorization,content-type,"+constants.SessionClientHeaderName)
response.Header().Set("Access-Control-Allow-Credentials", "true")
if request.Method == http.MethodOptions {
return
}
session, _ := rh.c.CookieStore.Get(request, "session")
session.Options.MaxAge = -1
err := session.Save(request, response)
if err != nil {
response.WriteHeader(http.StatusInternalServerError)
return
}
response.WriteHeader(http.StatusOK)
}
// github Oauth2 CodeExchange callback.
func (rh *RouteHandler) GithubCodeExchangeCallback() rp.CodeExchangeCallback {
return func(w http.ResponseWriter, r *http.Request,
tokens *oidc.Tokens, state string, relyingParty rp.RelyingParty,
) {
ctx := r.Context()
client := github.NewClient(relyingParty.OAuthConfig().Client(ctx, tokens.Token))
email, groups, err := GetGithubUserInfo(ctx, client, rh.c.Log)
if email == "" || err != nil {
w.WriteHeader(http.StatusUnauthorized)
return
}
callbackUI, err := OAuth2Callback(rh.c, w, r, state, email, groups) //nolint: contextcheck
if err != nil {
if errors.Is(err, zerr.ErrInvalidStateCookie) {
w.WriteHeader(http.StatusUnauthorized)
}
w.WriteHeader(http.StatusInternalServerError)
}
if callbackUI != "" {
http.Redirect(w, r, callbackUI, http.StatusFound)
return
}
w.WriteHeader(http.StatusCreated)
}
}
// Openid CodeExchange callback.
func (rh *RouteHandler) OpenIDCodeExchangeCallback() rp.CodeExchangeUserinfoCallback {
return func(w http.ResponseWriter, r *http.Request, tokens *oidc.Tokens, state string,
relyingParty rp.RelyingParty, info oidc.UserInfo,
) {
email := info.GetEmail()
if email == "" {
rh.c.Log.Error().Msg("couldn't set user record for empty email value")
w.WriteHeader(http.StatusUnauthorized)
return
}
var groups []string
val, ok := info.GetClaim("groups").([]interface{})
if !ok {
rh.c.Log.Info().Msgf("couldn't find any 'groups' claim for user %s", email)
}
for _, group := range val {
groups = append(groups, fmt.Sprint(group))
}
callbackUI, err := OAuth2Callback(rh.c, w, r, state, email, groups)
if err != nil {
if errors.Is(err, zerr.ErrInvalidStateCookie) {
w.WriteHeader(http.StatusUnauthorized)
}
w.WriteHeader(http.StatusInternalServerError)
}
if callbackUI != "" {
http.Redirect(w, r, callbackUI, http.StatusFound)
return
}
w.WriteHeader(http.StatusCreated)
}
}
func (rh *RouteHandler) GetMetrics(w http.ResponseWriter, r *http.Request) {
m := rh.c.Metrics.ReceiveMetrics()
zcommon.WriteJSON(w, http.StatusOK, m)
+175 -4
View File
@@ -1,26 +1,36 @@
//go:build sync && scrub && metrics && search && lint
// +build sync,scrub,metrics,search,lint
//go:build sync && scrub && metrics && search && lint && apikey
// +build sync,scrub,metrics,search,lint,apikey
package api_test
import (
"bytes"
"context"
"encoding/json"
"errors"
"io"
"net/http"
"net/http/httptest"
"os"
"testing"
"github.com/google/uuid"
"github.com/gorilla/mux"
godigest "github.com/opencontainers/go-digest"
ispec "github.com/opencontainers/image-spec/specs-go/v1"
"github.com/project-zot/mockoidc"
. "github.com/smartystreets/goconvey/convey"
"github.com/zitadel/oidc/pkg/client/rp"
"github.com/zitadel/oidc/pkg/oidc"
"golang.org/x/oauth2"
zerr "zotregistry.io/zot/errors"
"zotregistry.io/zot/pkg/api"
"zotregistry.io/zot/pkg/api/config"
"zotregistry.io/zot/pkg/api/constants"
"zotregistry.io/zot/pkg/extensions"
extconf "zotregistry.io/zot/pkg/extensions/config"
"zotregistry.io/zot/pkg/meta/repodb"
localCtx "zotregistry.io/zot/pkg/requestcontext"
storageTypes "zotregistry.io/zot/pkg/storage/types"
"zotregistry.io/zot/pkg/test"
@@ -29,6 +39,8 @@ import (
var ErrUnexpectedError = errors.New("error: unexpected error")
const sessionStr = "session"
func TestRoutes(t *testing.T) {
Convey("Make a new controller", t, func() {
port := test.GetFreePort()
@@ -36,6 +48,45 @@ func TestRoutes(t *testing.T) {
conf := config.New()
conf.HTTP.Port = port
htpasswdPath := test.MakeHtpasswdFile()
defer os.Remove(htpasswdPath)
mockOIDCServer, err := mockoidc.Run()
if err != nil {
panic(err)
}
defer func() {
err := mockOIDCServer.Shutdown()
if err != nil {
panic(err)
}
}()
mockOIDCConfig := mockOIDCServer.Config()
conf.HTTP.Auth = &config.AuthConfig{
HTPasswd: config.AuthHTPasswd{
Path: htpasswdPath,
},
OpenID: &config.OpenIDConfig{
Providers: map[string]config.OpenIDProviderConfig{
"dex": {
ClientID: mockOIDCConfig.ClientID,
ClientSecret: mockOIDCConfig.ClientSecret,
KeyPath: "",
Issuer: mockOIDCConfig.Issuer,
Scopes: []string{"openid", "email"},
},
},
},
}
defaultVal := true
apiKeyConfig := &extconf.APIKeyConfig{
BaseConfig: extconf.BaseConfig{Enable: &defaultVal},
}
conf.Extensions = &extconf.ExtensionConfig{
APIKey: apiKeyConfig,
}
ctlr := api.NewController(conf)
ctlr.Config.Storage.RootDirectory = t.TempDir()
@@ -50,6 +101,52 @@ func TestRoutes(t *testing.T) {
// NOTE: the url or method itself doesn't matter below since we are calling the handlers directly,
// so path routing is bypassed
Convey("Test GithubCodeExchangeCallback", func() {
callback := rthdlr.GithubCodeExchangeCallback()
ctx := context.TODO()
request, _ := http.NewRequestWithContext(ctx, http.MethodGet, baseURL, nil)
response := httptest.NewRecorder()
tokens := &oidc.Tokens{}
relyingParty, err := rp.NewRelyingPartyOAuth(&oauth2.Config{})
So(err, ShouldBeNil)
callback(response, request, tokens, "state", relyingParty)
resp := response.Result()
defer resp.Body.Close()
So(resp, ShouldNotBeNil)
So(resp.StatusCode, ShouldEqual, http.StatusUnauthorized)
})
Convey("Test OAuth2Callback errors", func() {
ctx := context.TODO()
request, _ := http.NewRequestWithContext(ctx, http.MethodGet, baseURL, nil)
response := httptest.NewRecorder()
_, err := api.OAuth2Callback(ctlr, response, request, "state", "email", []string{"group"})
So(err, ShouldEqual, zerr.ErrInvalidStateCookie)
session, _ := ctlr.CookieStore.Get(request, "statecookie")
session.Options.Secure = true
session.Options.HttpOnly = true
session.Options.SameSite = http.SameSiteDefaultMode
state := uuid.New().String()
session.Values["state"] = state
// let the session set its own id
err = session.Save(request, response)
So(err, ShouldBeNil)
_, err = api.OAuth2Callback(ctlr, response, request, "state", "email", []string{"group"})
So(err, ShouldEqual, zerr.ErrInvalidStateCookie)
})
Convey("List repositories authz error", func() {
var invalid struct{}
@@ -575,7 +672,7 @@ func TestRoutes(t *testing.T) {
},
&mocks.MockedImageStore{
FullBlobUploadFn: func(repo string, body io.Reader, digest godigest.Digest) (string, int64, error) {
return "session", 0, zerr.ErrBadBlobDigest
return sessionStr, 0, zerr.ErrBadBlobDigest
},
})
So(statusCode, ShouldEqual, http.StatusInternalServerError)
@@ -591,7 +688,7 @@ func TestRoutes(t *testing.T) {
},
&mocks.MockedImageStore{
FullBlobUploadFn: func(repo string, body io.Reader, digest godigest.Digest) (string, int64, error) {
return "session", 20, nil
return sessionStr, 20, nil
},
})
So(statusCode, ShouldEqual, http.StatusInternalServerError)
@@ -1327,6 +1424,80 @@ func TestRoutes(t *testing.T) {
So(resp.StatusCode, ShouldEqual, http.StatusOK)
})
Convey("Test API keys", func() {
var invalid struct{}
ctx := context.TODO()
key := localCtx.GetContextKey()
ctx = context.WithValue(ctx, key, invalid)
request, _ := http.NewRequestWithContext(ctx, http.MethodPost, baseURL, bytes.NewReader([]byte{}))
response := httptest.NewRecorder()
extensions.CreateAPIKey(response, request, ctlr.RepoDB, ctlr.CookieStore, ctlr.Log)
resp := response.Result()
defer resp.Body.Close()
So(resp.StatusCode, ShouldEqual, http.StatusInternalServerError)
acCtx := localCtx.AccessControlContext{
Username: username,
}
ctx = context.TODO()
key = localCtx.GetContextKey()
ctx = context.WithValue(ctx, key, acCtx)
request, _ = http.NewRequestWithContext(ctx, http.MethodPost, baseURL, bytes.NewReader([]byte{}))
response = httptest.NewRecorder()
extensions.CreateAPIKey(response, request, ctlr.RepoDB, ctlr.CookieStore, ctlr.Log)
resp = response.Result()
defer resp.Body.Close()
So(resp.StatusCode, ShouldEqual, http.StatusInternalServerError)
payload := extensions.APIKeyPayload{
Label: "test",
Scopes: []string{"test"},
}
reqBody, err := json.Marshal(payload)
So(err, ShouldBeNil)
request, _ = http.NewRequestWithContext(ctx, http.MethodPost, baseURL, bytes.NewReader(reqBody))
response = httptest.NewRecorder()
extensions.CreateAPIKey(response, request, mocks.RepoDBMock{
AddUserAPIKeyFn: func(ctx context.Context, hashedKey string, apiKeyDetails *repodb.APIKeyDetails) error {
return ErrUnexpectedError
},
}, ctlr.CookieStore, ctlr.Log)
resp = response.Result()
defer resp.Body.Close()
So(resp.StatusCode, ShouldEqual, http.StatusInternalServerError)
request, _ = http.NewRequestWithContext(ctx, http.MethodDelete, baseURL, bytes.NewReader([]byte{}))
response = httptest.NewRecorder()
q := request.URL.Query()
q.Add("id", "apikeyid")
request.URL.RawQuery = q.Encode()
extensions.RevokeAPIKey(response, request, mocks.RepoDBMock{
DeleteUserAPIKeyFn: func(ctx context.Context, id string) error {
return ErrUnexpectedError
},
}, ctlr.CookieStore, ctlr.Log)
resp = response.Result()
defer resp.Body.Close()
So(resp.StatusCode, ShouldEqual, http.StatusInternalServerError)
})
Convey("Helper functions", func() {
testUpdateBlobUpload := func(
query []struct{ k, v string },
+63 -2
View File
@@ -1,5 +1,5 @@
//go:build sync && scrub && metrics && search
// +build sync,scrub,metrics,search
//go:build sync && scrub && metrics && search && apikey
// +build sync,scrub,metrics,search,apikey
package cli_test
@@ -857,6 +857,67 @@ func TestServeMgmtExtension(t *testing.T) {
})
}
func TestServeAPIKeyExtension(t *testing.T) {
oldArgs := os.Args
defer func() { os.Args = oldArgs }()
Convey("apikey implicitly enabled", t, func(c C) {
content := `{
"storage": {
"rootDirectory": "%s"
},
"http": {
"address": "127.0.0.1",
"port": "%s"
},
"log": {
"level": "debug",
"output": "%s"
},
"extensions": {
"apikey": {
}
}
}`
logPath, err := runCLIWithConfig(t.TempDir(), content)
So(err, ShouldBeNil)
data, err := os.ReadFile(logPath)
So(err, ShouldBeNil)
defer os.Remove(logPath) // clean up
So(string(data), ShouldContainSubstring, "\"APIKey\":{\"Enable\":true}")
})
Convey("apikey disabled", t, func(c C) {
content := `{
"storage": {
"rootDirectory": "%s"
},
"http": {
"address": "127.0.0.1",
"port": "%s"
},
"log": {
"level": "debug",
"output": "%s"
},
"extensions": {
"apikey": {
"enable": "false"
}
}
}`
logPath, err := runCLIWithConfig(t.TempDir(), content)
So(err, ShouldBeNil)
data, err := os.ReadFile(logPath)
So(err, ShouldBeNil)
defer os.Remove(logPath) // clean up
So(string(data), ShouldContainSubstring, "\"APIKey\":{\"Enable\":false}")
})
}
func readLogFileAndSearchString(logPath string, stringToMatch string, timeout time.Duration) (bool, error) { //nolint:unparam,lll
ctx, cancelFunc := context.WithTimeout(context.Background(), timeout)
defer cancelFunc()
+52 -4
View File
@@ -361,6 +361,10 @@ func validateConfiguration(config *config.Config) error {
return err
}
if err := validateOpenIDConfig(config); err != nil {
return err
}
if err := validateSync(config); err != nil {
return err
}
@@ -377,7 +381,7 @@ func validateConfiguration(config *config.Config) error {
return err
}
// check authorization config, it should have basic auth enabled or ldap
// check authorization config, it should have basic auth enabled or ldap, api keys or OpenID
if config.HTTP.AccessControl != nil {
// checking for anonymous policy only authorization config: no users, no policies but anonymous policy
if err := validateAuthzPolicies(config); err != nil {
@@ -435,11 +439,42 @@ func validateConfiguration(config *config.Config) error {
return nil
}
func validateOpenIDConfig(config *config.Config) error {
if config.HTTP.Auth != nil && config.HTTP.Auth.OpenID != nil {
for provider, providerConfig := range config.HTTP.Auth.OpenID.Providers {
//nolint: gocritic
if api.IsOpenIDSupported(provider) {
if providerConfig.ClientID == "" || providerConfig.Issuer == "" ||
len(providerConfig.Scopes) == 0 {
log.Error().Err(errors.ErrBadConfig).
Msg("OpenID provider config requires clientid, issuer and scopes parameters")
return errors.ErrBadConfig
}
} else if api.IsOauth2Supported(provider) {
if providerConfig.ClientID == "" || len(providerConfig.Scopes) == 0 {
log.Error().Err(errors.ErrBadConfig).
Msg("OAuth2 provider config requires clientid and scopes parameters")
return errors.ErrBadConfig
}
} else {
log.Error().Err(errors.ErrBadConfig).
Msg("unsupported openid/oauth2 provider")
return errors.ErrBadConfig
}
}
}
return nil
}
func validateAuthzPolicies(config *config.Config) error {
if (config.HTTP.Auth == nil || (config.HTTP.Auth.HTPasswd.Path == "" && config.HTTP.Auth.LDAP == nil)) &&
!authzContainsOnlyAnonymousPolicy(config) {
if (config.HTTP.Auth == nil || (config.HTTP.Auth.HTPasswd.Path == "" && config.HTTP.Auth.LDAP == nil &&
config.HTTP.Auth.OpenID == nil)) && !authzContainsOnlyAnonymousPolicy(config) {
log.Error().Err(errors.ErrBadConfig).
Msg("access control config requires httpasswd, ldap authentication " +
Msg("access control config requires one of httpasswd, ldap or openid authentication " +
"or using only 'anonymousPolicy' policies")
return errors.ErrBadConfig
@@ -484,6 +519,13 @@ func applyDefaultValues(config *config.Config, viperInstance *viper.Viper) {
// Note: In case mgmt is not empty the config.Extensions will not be nil and we will not reach here
config.Extensions.Mgmt = &extconf.MgmtConfig{}
}
_, ok = extMap["apikey"]
if ok {
// we found a config like `"extensions": {"mgmt:": {}}`
// Note: In case mgmt is not empty the config.Extensions will not be nil and we will not reach here
config.Extensions.APIKey = &extconf.APIKeyConfig{}
}
}
if config.Extensions != nil {
@@ -550,6 +592,12 @@ func applyDefaultValues(config *config.Config, viperInstance *viper.Viper) {
}
}
if config.Extensions.APIKey != nil {
if config.Extensions.APIKey.Enable == nil {
config.Extensions.APIKey.Enable = &defaultVal
}
}
if config.Extensions.Scrub != nil {
if config.Extensions.Scrub.Enable == nil {
config.Extensions.Scrub.Enable = &defaultVal
+65
View File
@@ -952,6 +952,71 @@ func TestVerify(t *testing.T) {
So(func() { _ = cli.NewServerRootCmd().Execute() }, ShouldPanic)
})
Convey("Test verify openid config with missing parameter", t, func(c C) {
tmpfile, err := os.CreateTemp("", "zot-test*.json")
So(err, ShouldBeNil)
defer os.Remove(tmpfile.Name()) // clean up
content := []byte(`{"distSpecVersion":"1.1.0-dev","storage":{"rootDirectory":"/tmp/zot"},
"http":{"address":"127.0.0.1","port":"8080","realm":"zot",
"auth":{"openid":{"providers":{"dex":{"issuer":"http://127.0.0.1:5556/dex"}}}}},
"log":{"level":"debug"}}`)
_, err = tmpfile.Write(content)
So(err, ShouldBeNil)
err = tmpfile.Close()
So(err, ShouldBeNil)
os.Args = []string{"cli_test", "verify", tmpfile.Name()}
So(func() { _ = cli.NewServerRootCmd().Execute() }, ShouldPanic)
})
Convey("Test verify oauth2 config with missing parameter", t, func(c C) {
tmpfile, err := os.CreateTemp("", "zot-test*.json")
So(err, ShouldBeNil)
defer os.Remove(tmpfile.Name()) // clean up
content := []byte(`{"distSpecVersion":"1.1.0-dev","storage":{"rootDirectory":"/tmp/zot"},
"http":{"address":"127.0.0.1","port":"8080","realm":"zot",
"auth":{"openid":{"providers":{"github":{"clientid":"client_id"}}}}},
"log":{"level":"debug"}}`)
_, err = tmpfile.Write(content)
So(err, ShouldBeNil)
err = tmpfile.Close()
So(err, ShouldBeNil)
os.Args = []string{"cli_test", "verify", tmpfile.Name()}
So(func() { _ = cli.NewServerRootCmd().Execute() }, ShouldPanic)
})
Convey("Test verify openid config with unsupported provider", t, func(c C) {
tmpfile, err := os.CreateTemp("", "zot-test*.json")
So(err, ShouldBeNil)
defer os.Remove(tmpfile.Name()) // clean up
content := []byte(`{"distSpecVersion":"1.1.0-dev","storage":{"rootDirectory":"/tmp/zot"},
"http":{"address":"127.0.0.1","port":"8080","realm":"zot",
"auth":{"openid":{"providers":{"unsupported":{"issuer":"http://127.0.0.1:5556/dex"}}}}},
"log":{"level":"debug"}}`)
_, err = tmpfile.Write(content)
So(err, ShouldBeNil)
err = tmpfile.Close()
So(err, ShouldBeNil)
os.Args = []string{"cli_test", "verify", tmpfile.Name()}
So(func() { _ = cli.NewServerRootCmd().Execute() }, ShouldPanic)
})
Convey("Test verify openid config without apikey extension enabled", t, func(c C) {
tmpfile, err := os.CreateTemp("", "zot-test*.json")
So(err, ShouldBeNil)
defer os.Remove(tmpfile.Name()) // clean up
content := []byte(`{"distSpecVersion":"1.1.0-dev","storage":{"rootDirectory":"/tmp/zot"},
"http":{"address":"127.0.0.1","port":"8080","realm":"zot",
"auth":{"openid":{"providers":{"dex":{"issuer":"http://127.0.0.1:5556/dex",
"clientid":"client_id","scopes":["openid"]}}}}},
"log":{"level":"debug"}}`)
_, err = tmpfile.Write(content)
So(err, ShouldBeNil)
err = tmpfile.Close()
So(err, ShouldBeNil)
os.Args = []string{"cli_test", "verify", tmpfile.Name()}
So(func() { _ = cli.NewServerRootCmd().Execute() }, ShouldNotPanic)
})
Convey("Test verify config with missing basedn key", t, func(c C) {
tmpfile, err := os.CreateTemp("", "zot-test*.json")
So(err, ShouldBeNil)
+54 -3
View File
@@ -2,14 +2,17 @@ package common
import (
"net/http"
"strconv"
"strings"
"time"
"github.com/gorilla/mux"
"github.com/gorilla/sessions"
jsoniter "github.com/json-iterator/go"
"zotregistry.io/zot/pkg/api/constants"
apiErr "zotregistry.io/zot/pkg/api/errors"
"zotregistry.io/zot/pkg/log"
)
func AllowedMethods(methods ...string) []string {
@@ -32,7 +35,8 @@ func ACHeadersHandler(allowedMethods ...string) mux.MiddlewareFunc {
return func(next http.Handler) http.Handler {
return http.HandlerFunc(func(resp http.ResponseWriter, req *http.Request) {
resp.Header().Set("Access-Control-Allow-Methods", headerValue)
resp.Header().Set("Access-Control-Allow-Headers", "Authorization,content-type")
resp.Header().Set("Access-Control-Allow-Headers", "Authorization,content-type,"+constants.SessionClientHeaderName)
resp.Header().Set("Access-Control-Allow-Credentials", "true")
if req.Method == http.MethodOptions {
return
@@ -43,9 +47,20 @@ func ACHeadersHandler(allowedMethods ...string) mux.MiddlewareFunc {
}
}
func AuthzFail(w http.ResponseWriter, realm string, delay int) {
func AuthzFail(w http.ResponseWriter, r *http.Request, realm string, delay int) {
time.Sleep(time.Duration(delay) * time.Second)
w.Header().Set("WWW-Authenticate", realm)
// don't send auth headers if request is coming from UI
if r.Header.Get(constants.SessionClientHeaderName) != constants.SessionClientHeaderValue {
if realm == "" {
realm = "Authorization Required"
}
realm = "Basic realm=" + strconv.Quote(realm)
w.Header().Set("WWW-Authenticate", realm)
}
w.Header().Set("Content-Type", "application/json")
WriteJSON(w, http.StatusForbidden, apiErr.NewErrorList(apiErr.NewError(apiErr.DENIED)))
}
@@ -66,3 +81,39 @@ func WriteData(w http.ResponseWriter, status int, mediaType string, data []byte)
w.WriteHeader(status)
_, _ = w.Write(data)
}
/*
GetAuthUserFromRequestSession returns identity
and auth status if on the request's cookie session is a logged in user.
*/
func GetAuthUserFromRequestSession(cookieStore sessions.Store, request *http.Request, log log.Logger,
) (string, bool) {
session, err := cookieStore.Get(request, "session")
if err != nil {
log.Error().Err(err).Msg("can not decode existing session")
// expired cookie, no need to return err
return "", false
}
// at this point we should have a session set on cookie.
// if created in the earlier Get() call then user is not logged in with sessions.
if session.IsNew {
return "", false
}
authenticated := session.Values["authStatus"]
if authenticated != true {
log.Error().Msg("can not get `user` session value")
return "", false
}
identity, ok := session.Values["user"].(string)
if !ok {
log.Error().Msg("can not get `user` session value")
return "", false
}
return identity, true
}
+66
View File
@@ -0,0 +1,66 @@
# `API keys`
zot allows authentication for REST API calls using your API key as an alternative to your password.
* User can create/revoke his API key.
* Can not be retrieved, it is shown to the user only the first time is created.
* An API key has the same rights as the user who generated it.
## API keys REST API
### Create API Key
**Description**: Create an API key for the current user.
**Usage**: POST /v2/_zot/ext/apikey
**Produces**: application/json
**Sample input**:
```
POST /api/security/apiKey
Body: {"label": "git", "scopes": ["repo1", "repo2"]}'
```
**Example cURL**
```
curl -u user:password -X POST http://localhost:8080/v2/_zot/ext/apikey -d '{"label": "myLabel", "scopes": ["repo1", "repo2"]}'
```
**Sample output**:
```json
{
"createdAt": "2023-05-05T15:39:28.420926+03:00",
"creatorUa": "curl/7.68.0",
"generatedBy": "manual",
"lastUsed": "2023-05-05T15:39:28.4209282+03:00",
"label": "git",
"scopes": [
"repo1",
"repo2"
],
"uuid": "46a45ce7-5d92-498a-a9cb-9654b1da3da1",
"apiKey": "zak_e77bcb9e9f634f1581756abbf9ecd269"
}
```
**Using API keys cURL**
```
curl -u user:zak_e77bcb9e9f634f1581756abbf9ecd269 http://localhost:8080/v2/_catalog
```
### Revoke API Key
**Description**: Revokes one current user API key by api key UUID
**Usage**: DELETE /api/security/apiKey?id=$uuid
**Produces**: application/json
**Example cURL**
```
curl -u user:password -X DELETE http://localhost:8080/v2/_zot/ext/apikey?id=46a45ce7-5d92-498a-a9cb-9654b1da3da1
```
+1
View File
@@ -8,6 +8,7 @@ Component | Endpoint | Description
[`search`](search/search.md) | `/v2/_zot/ext/search` | efficient and enhanced registry search capabilities using graphQL backend
[`mgmt`](mgmt.md) | `/v2/_zot/ext/mgmt` | config management
[`userprefs`](userprefs.md) | `/v2/_zot/ext/userprefs` | change user preferences
[`apikey`](README_apikey.md) | `/v2/_zot/ext/apikey` | user api keys management
# References
+5
View File
@@ -19,6 +19,11 @@ type ExtensionConfig struct {
Lint *LintConfig
UI *UIConfig
Mgmt *MgmtConfig
APIKey *APIKeyConfig
}
type APIKeyConfig struct {
BaseConfig `mapstructure:",squash"`
}
type MgmtConfig struct {
+197
View File
@@ -0,0 +1,197 @@
//go:build apikey
// +build apikey
package extensions
import (
"crypto/sha256"
"encoding/json"
"fmt"
"io"
"net/http"
"strings"
"time"
guuid "github.com/gofrs/uuid"
"github.com/gorilla/mux"
"github.com/gorilla/sessions"
jsoniter "github.com/json-iterator/go"
godigest "github.com/opencontainers/go-digest"
"zotregistry.io/zot/pkg/api/config"
"zotregistry.io/zot/pkg/api/constants"
zcommon "zotregistry.io/zot/pkg/common"
"zotregistry.io/zot/pkg/log"
"zotregistry.io/zot/pkg/meta/repodb"
)
func SetupAPIKeyRoutes(config *config.Config, router *mux.Router, repoDB repodb.RepoDB,
cookieStore sessions.Store, log log.Logger,
) {
if config.Extensions.APIKey != nil && *config.Extensions.APIKey.Enable {
log.Info().Msg("setting up api key routes")
allowedMethods := zcommon.AllowedMethods(http.MethodPost, http.MethodDelete)
apiKeyRouter := router.PathPrefix(constants.ExtAPIKey).Subrouter()
apiKeyRouter.Use(zcommon.ACHeadersHandler(allowedMethods...))
apiKeyRouter.Use(zcommon.AddExtensionSecurityHeaders())
apiKeyRouter.Methods(allowedMethods...).Handler(HandleAPIKeyRequest(repoDB, cookieStore, log))
}
}
type APIKeyPayload struct { //nolint:revive
Label string `json:"label"`
Scopes []string `json:"scopes"`
}
func HandleAPIKeyRequest(repoDB repodb.RepoDB, cookieStore sessions.Store,
log log.Logger,
) http.Handler {
return http.HandlerFunc(func(resp http.ResponseWriter, req *http.Request) {
switch req.Method {
case http.MethodPost:
CreateAPIKey(resp, req, repoDB, cookieStore, log) //nolint:contextcheck
return
case http.MethodDelete:
RevokeAPIKey(resp, req, repoDB, cookieStore, log) //nolint:contextcheck
return
}
})
}
// CreateAPIKey godoc
// @Summary Create an API key for the current user
// @Description Can create an api key for a logged in user, based on the provided label and scopes.
// @Accept json
// @Produce json
// @Success 201 {string} string "created"
// @Failure 401 {string} string "unauthorized"
// @Failure 500 {string} string "internal server error"
// @Router /v2/_zot/ext/apikey [post].
func CreateAPIKey(resp http.ResponseWriter, req *http.Request, repoDB repodb.RepoDB,
cookieStore sessions.Store, log log.Logger,
) {
var payload APIKeyPayload
body, err := io.ReadAll(req.Body)
if err != nil {
log.Error().Msg("unable to read request body")
resp.WriteHeader(http.StatusInternalServerError)
return
}
err = json.Unmarshal(body, &payload)
if err != nil {
log.Error().Err(err).Msg("unable to unmarshal body")
resp.WriteHeader(http.StatusInternalServerError)
return
}
apiKeyBase, err := guuid.NewV4()
if err != nil {
log.Error().Err(err).Msg("unable to generate uuid")
resp.WriteHeader(http.StatusInternalServerError)
return
}
apiKey := strings.ReplaceAll(apiKeyBase.String(), "-", "")
hashedAPIKey := hashUUID(apiKey)
// will be used for identifying a specific api key
apiKeyID, err := guuid.NewV4()
if err != nil {
log.Error().Err(err).Msg("unable to generate uuid")
resp.WriteHeader(http.StatusInternalServerError)
return
}
apiKeyDetails := &repodb.APIKeyDetails{
CreatedAt: time.Now(),
LastUsed: time.Now(),
CreatorUA: req.UserAgent(),
GeneratedBy: "manual",
Label: payload.Label,
Scopes: payload.Scopes,
UUID: apiKeyID.String(),
}
err = repoDB.AddUserAPIKey(req.Context(), hashedAPIKey, apiKeyDetails)
if err != nil {
log.Error().Err(err).Msg("error storing API key")
resp.WriteHeader(http.StatusInternalServerError)
return
}
apiKeyResponse := struct {
repodb.APIKeyDetails
APIKey string `json:"apiKey"`
}{
APIKey: fmt.Sprintf("%s%s", constants.APIKeysPrefix, apiKey),
APIKeyDetails: *apiKeyDetails,
}
json := jsoniter.ConfigCompatibleWithStandardLibrary
data, err := json.Marshal(apiKeyResponse)
if err != nil {
log.Error().Err(err).Msg("unable to marshal api key response")
resp.WriteHeader(http.StatusInternalServerError)
return
}
resp.Header().Set("Content-Type", constants.DefaultMediaType)
resp.WriteHeader(http.StatusCreated)
_, _ = resp.Write(data)
}
// RevokeAPIKey godoc
// @Summary Revokes one current user API key
// @Description Revokes one current user API key based on given key ID
// @Accept json
// @Produce json
// @Param id path string true "api token id (UUID)"
// @Success 200 {string} string "ok"
// @Failure 500 {string} string "internal server error"
// @Failure 401 {string} string "unauthorized"
// @Failure 400 {string} string "bad request"
// @Router /v2/_zot/ext/apikey?id=UUID [delete].
func RevokeAPIKey(resp http.ResponseWriter, req *http.Request, repoDB repodb.RepoDB,
cookieStore sessions.Store, log log.Logger,
) {
ids, ok := req.URL.Query()["id"]
if !ok || len(ids) != 1 {
resp.WriteHeader(http.StatusBadRequest)
return
}
keyID := ids[0]
err := repoDB.DeleteUserAPIKey(req.Context(), keyID)
if err != nil {
log.Error().Err(err).Str("keyID", keyID).Msg("error deleting API key")
resp.WriteHeader(http.StatusInternalServerError)
return
}
resp.WriteHeader(http.StatusOK)
}
func hashUUID(uuid string) string {
digester := sha256.New()
digester.Write([]byte(uuid))
return godigest.NewDigestFromEncoded(godigest.SHA256, fmt.Sprintf("%x", digester.Sum(nil))).Encoded()
}
@@ -0,0 +1,20 @@
//go:build !apikey
// +build !apikey
package extensions
import (
"github.com/gorilla/mux"
"github.com/gorilla/sessions"
"zotregistry.io/zot/pkg/api/config"
"zotregistry.io/zot/pkg/log"
"zotregistry.io/zot/pkg/meta/repodb"
)
func SetupAPIKeyRoutes(config *config.Config, router *mux.Router, repoDB repodb.RepoDB,
cookieStore sessions.Store, log log.Logger,
) {
log.Warn().Msg("skipping setting up API key routes because given zot binary doesn't include this feature," +
"please build a binary that does so")
}
+531
View File
@@ -0,0 +1,531 @@
//go:build apikey
// +build apikey
package extensions_test
import (
"context"
"encoding/json"
"errors"
"net/http"
"os"
"testing"
"github.com/project-zot/mockoidc"
. "github.com/smartystreets/goconvey/convey"
"gopkg.in/resty.v1"
"zotregistry.io/zot/pkg/api"
"zotregistry.io/zot/pkg/api/config"
"zotregistry.io/zot/pkg/api/constants"
"zotregistry.io/zot/pkg/extensions"
extconf "zotregistry.io/zot/pkg/extensions/config"
"zotregistry.io/zot/pkg/meta/repodb"
localCtx "zotregistry.io/zot/pkg/requestcontext"
"zotregistry.io/zot/pkg/test"
"zotregistry.io/zot/pkg/test/mocks"
)
type (
apiKeyResponse struct {
repodb.APIKeyDetails
APIKey string `json:"apiKey"`
}
)
var ErrUnexpectedError = errors.New("unexpected err")
func TestAPIKeys(t *testing.T) {
Convey("Make a new controller", t, func() {
port := test.GetFreePort()
baseURL := test.GetBaseURL(port)
conf := config.New()
conf.HTTP.Port = port
htpasswdPath := test.MakeHtpasswdFile()
defer os.Remove(htpasswdPath)
mockOIDCServer, err := test.MockOIDCRun()
if err != nil {
panic(err)
}
defer func() {
err := mockOIDCServer.Shutdown()
if err != nil {
panic(err)
}
}()
mockOIDCConfig := mockOIDCServer.Config()
conf.HTTP.Auth = &config.AuthConfig{
HTPasswd: config.AuthHTPasswd{
Path: htpasswdPath,
},
OpenID: &config.OpenIDConfig{
Providers: map[string]config.OpenIDProviderConfig{
"dex": {
ClientID: mockOIDCConfig.ClientID,
ClientSecret: mockOIDCConfig.ClientSecret,
KeyPath: "",
Issuer: mockOIDCConfig.Issuer,
Scopes: []string{"openid", "email", "groups"},
},
},
},
}
conf.HTTP.AccessControl = &config.AccessControlConfig{}
defaultVal := true
apiKeyConfig := &extconf.APIKeyConfig{
BaseConfig: extconf.BaseConfig{Enable: &defaultVal},
}
mgmtConfg := &extconf.MgmtConfig{
BaseConfig: extconf.BaseConfig{Enable: &defaultVal},
}
conf.Extensions = &extconf.ExtensionConfig{
APIKey: apiKeyConfig,
Mgmt: mgmtConfg,
}
ctlr := api.NewController(conf)
dir := t.TempDir()
ctlr.Config.Storage.RootDirectory = dir
cm := test.NewControllerManager(ctlr)
cm.StartServer()
defer cm.StopServer()
test.WaitTillServerReady(baseURL)
payload := extensions.APIKeyPayload{
Label: "test",
Scopes: []string{"test"},
}
reqBody, err := json.Marshal(payload)
So(err, ShouldBeNil)
Convey("API key retrieved with basic auth", func() {
// call endpoint with session ( added to client after previous request)
resp, err := resty.R().
SetBody(reqBody).
SetBasicAuth("test", "test").
Post(baseURL + constants.FullAPIKeyPrefix)
So(err, ShouldBeNil)
So(resp, ShouldNotBeNil)
So(resp.StatusCode(), ShouldEqual, http.StatusCreated)
user := mockoidc.DefaultUser()
// get API key and email from apikey route response
var apiKeyResponse apiKeyResponse
err = json.Unmarshal(resp.Body(), &apiKeyResponse)
So(err, ShouldBeNil)
email := user.Email
So(email, ShouldNotBeEmpty)
resp, err = resty.R().
SetBasicAuth("test", apiKeyResponse.APIKey).
Get(baseURL + "/v2/_catalog")
So(err, ShouldBeNil)
So(resp, ShouldNotBeNil)
So(resp.StatusCode(), ShouldEqual, http.StatusOK)
// add another one
resp, err = resty.R().
SetBody(reqBody).
SetBasicAuth("test", "test").
Post(baseURL + constants.FullAPIKeyPrefix)
So(err, ShouldBeNil)
So(resp, ShouldNotBeNil)
So(resp.StatusCode(), ShouldEqual, http.StatusCreated)
err = json.Unmarshal(resp.Body(), &apiKeyResponse)
So(err, ShouldBeNil)
resp, err = resty.R().
SetBasicAuth("test", apiKeyResponse.APIKey).
Get(baseURL + "/v2/_catalog")
So(err, ShouldBeNil)
So(resp, ShouldNotBeNil)
So(resp.StatusCode(), ShouldEqual, http.StatusOK)
})
Convey("API key retrieved with openID", func() {
client := resty.New()
client.SetRedirectPolicy(test.CustomRedirectPolicy(20))
// first login user
resp, err := client.R().
SetHeader(constants.SessionClientHeaderName, constants.SessionClientHeaderValue).
SetQueryParam("provider", "dex").
Get(baseURL + constants.LoginPath)
So(err, ShouldBeNil)
So(resp, ShouldNotBeNil)
cookies := resp.Cookies()
// call endpoint without session
resp, err = client.R().
SetBody(reqBody).
SetHeader(constants.SessionClientHeaderName, constants.SessionClientHeaderValue).
Post(baseURL + constants.FullAPIKeyPrefix)
So(err, ShouldBeNil)
So(resp, ShouldNotBeNil)
So(resp.StatusCode(), ShouldEqual, http.StatusUnauthorized)
client.SetCookies(cookies)
// call endpoint with session ( added to client after previous request)
resp, err = client.R().
SetBody(reqBody).
SetHeader(constants.SessionClientHeaderName, constants.SessionClientHeaderValue).
Post(baseURL + constants.FullAPIKeyPrefix)
So(err, ShouldBeNil)
So(resp, ShouldNotBeNil)
So(resp.StatusCode(), ShouldEqual, http.StatusCreated)
user := mockoidc.DefaultUser()
// get API key and email from apikey route response
var apiKeyResponse apiKeyResponse
err = json.Unmarshal(resp.Body(), &apiKeyResponse)
So(err, ShouldBeNil)
email := user.Email
So(email, ShouldNotBeEmpty)
resp, err = client.R().
SetBasicAuth(email, apiKeyResponse.APIKey).
Get(baseURL + "/v2/_catalog")
So(err, ShouldBeNil)
So(resp, ShouldNotBeNil)
So(resp.StatusCode(), ShouldEqual, http.StatusOK)
// trigger errors
ctlr.RepoDB = mocks.RepoDBMock{
GetUserAPIKeyInfoFn: func(hashedKey string) (string, error) {
return "", ErrUnexpectedError
},
}
resp, err = client.R().
SetBasicAuth(email, apiKeyResponse.APIKey).
Get(baseURL + "/v2/_catalog")
So(err, ShouldBeNil)
So(resp, ShouldNotBeNil)
So(resp.StatusCode(), ShouldEqual, http.StatusInternalServerError)
ctlr.RepoDB = mocks.RepoDBMock{
GetUserAPIKeyInfoFn: func(hashedKey string) (string, error) {
return user.Email, nil
},
GetUserGroupsFn: func(ctx context.Context) ([]string, error) {
return []string{}, ErrUnexpectedError
},
}
resp, err = client.R().
SetBasicAuth(email, apiKeyResponse.APIKey).
Get(baseURL + "/v2/_catalog")
So(err, ShouldBeNil)
So(resp, ShouldNotBeNil)
So(resp.StatusCode(), ShouldEqual, http.StatusInternalServerError)
ctlr.RepoDB = mocks.RepoDBMock{
GetUserAPIKeyInfoFn: func(hashedKey string) (string, error) {
return user.Email, nil
},
UpdateUserAPIKeyLastUsedFn: func(ctx context.Context, hashedKey string) error {
return ErrUnexpectedError
},
}
resp, err = client.R().
SetBasicAuth(email, apiKeyResponse.APIKey).
Get(baseURL + "/v2/_catalog")
So(err, ShouldBeNil)
So(resp, ShouldNotBeNil)
So(resp.StatusCode(), ShouldEqual, http.StatusInternalServerError)
client = resty.New()
// call endpoint without session
resp, err = client.R().
SetBody(reqBody).
SetHeader(constants.SessionClientHeaderName, constants.SessionClientHeaderValue).
Post(baseURL + constants.FullAPIKeyPrefix)
So(err, ShouldBeNil)
So(resp, ShouldNotBeNil)
So(resp.StatusCode(), ShouldEqual, http.StatusUnauthorized)
})
Convey("Login with openid and create API key", func() {
client := resty.New()
// mgmt should work both unauthenticated and authenticated
resp, err := client.R().
Get(baseURL + constants.FullMgmtPrefix)
So(err, ShouldBeNil)
So(resp, ShouldNotBeNil)
So(resp.StatusCode(), ShouldEqual, http.StatusOK)
client.SetRedirectPolicy(test.CustomRedirectPolicy(20))
// first login user
resp, err = client.R().
SetHeader(constants.SessionClientHeaderName, constants.SessionClientHeaderValue).
SetQueryParam("provider", "dex").
Get(baseURL + constants.LoginPath)
So(err, ShouldBeNil)
So(resp, ShouldNotBeNil)
So(resp.StatusCode(), ShouldEqual, http.StatusCreated)
client.SetCookies(resp.Cookies())
// call endpoint with session ( added to client after previous request)
resp, err = client.R().
SetBody(reqBody).
SetHeader(constants.SessionClientHeaderName, constants.SessionClientHeaderValue).
Post(baseURL + constants.FullAPIKeyPrefix)
So(err, ShouldBeNil)
So(resp, ShouldNotBeNil)
So(resp.StatusCode(), ShouldEqual, http.StatusCreated)
var apiKeyResponse apiKeyResponse
err = json.Unmarshal(resp.Body(), &apiKeyResponse)
So(err, ShouldBeNil)
user := mockoidc.DefaultUser()
email := user.Email
So(email, ShouldNotBeEmpty)
resp, err = client.R().
SetBasicAuth(email, apiKeyResponse.APIKey).
Get(baseURL + "/v2/_catalog")
So(err, ShouldBeNil)
So(resp, ShouldNotBeNil)
So(resp.StatusCode(), ShouldEqual, http.StatusOK)
// auth with API key
// we need new client without session cookie set
client = resty.New()
client.SetRedirectPolicy(test.CustomRedirectPolicy(20))
resp, err = client.R().
SetBasicAuth(email, apiKeyResponse.APIKey).
Get(baseURL + "/v2/_catalog")
So(err, ShouldBeNil)
So(resp, ShouldNotBeNil)
So(resp.StatusCode(), ShouldEqual, http.StatusOK)
resp, err = client.R().
SetBasicAuth(email, apiKeyResponse.APIKey).
Get(baseURL + constants.FullMgmtPrefix)
So(err, ShouldBeNil)
So(resp, ShouldNotBeNil)
So(resp.StatusCode(), ShouldEqual, http.StatusOK)
// invalid api keys
resp, err = client.R().
SetBasicAuth("invalidEmail", apiKeyResponse.APIKey).
Get(baseURL + constants.FullMgmtPrefix)
So(err, ShouldBeNil)
So(resp, ShouldNotBeNil)
So(resp.StatusCode(), ShouldEqual, http.StatusUnauthorized)
resp, err = client.R().
SetBasicAuth(email, "noprefixAPIKey").
Get(baseURL + "/v2/_catalog")
So(err, ShouldBeNil)
So(resp, ShouldNotBeNil)
So(resp.StatusCode(), ShouldEqual, http.StatusUnauthorized)
resp, err = client.R().
SetBasicAuth(email, "zak_notworkingAPIKey").
Get(baseURL + "/v2/_catalog")
So(err, ShouldBeNil)
So(resp, ShouldNotBeNil)
So(resp.StatusCode(), ShouldEqual, http.StatusUnauthorized)
authzCtxKey := localCtx.GetContextKey()
acCtx := localCtx.AccessControlContext{
Username: email,
}
ctx := context.WithValue(context.Background(), authzCtxKey, acCtx)
err = ctlr.RepoDB.DeleteUserData(ctx)
So(err, ShouldBeNil)
resp, err = client.R().
SetBasicAuth(email, apiKeyResponse.APIKey).
Get(baseURL + constants.FullMgmtPrefix)
So(err, ShouldBeNil)
So(resp, ShouldNotBeNil)
So(resp.StatusCode(), ShouldEqual, http.StatusInternalServerError)
client = resty.New()
client.SetRedirectPolicy(test.CustomRedirectPolicy(20))
// without creds should work
resp, err = client.R().
Get(baseURL + constants.FullMgmtPrefix)
So(err, ShouldBeNil)
So(resp, ShouldNotBeNil)
So(resp.StatusCode(), ShouldEqual, http.StatusOK)
// login again
resp, err = client.R().
SetHeader(constants.SessionClientHeaderName, constants.SessionClientHeaderValue).
SetQueryParam("provider", "dex").
Get(baseURL + constants.LoginPath)
So(err, ShouldBeNil)
So(resp, ShouldNotBeNil)
So(resp.StatusCode(), ShouldEqual, http.StatusCreated)
client.SetCookies(resp.Cookies())
resp, err = client.R().
SetBody(reqBody).
SetHeader(constants.SessionClientHeaderName, constants.SessionClientHeaderValue).
Post(baseURL + constants.FullAPIKeyPrefix)
So(err, ShouldBeNil)
So(resp, ShouldNotBeNil)
So(resp.StatusCode(), ShouldEqual, http.StatusCreated)
err = json.Unmarshal(resp.Body(), &apiKeyResponse)
So(err, ShouldBeNil)
// should work with session
resp, err = client.R().
SetHeader(constants.SessionClientHeaderName, constants.SessionClientHeaderValue).
Get(baseURL + constants.FullMgmtPrefix)
So(err, ShouldBeNil)
So(resp, ShouldNotBeNil)
So(resp.StatusCode(), ShouldEqual, http.StatusOK)
// should work with api key
resp, err = client.R().
SetBasicAuth(email, apiKeyResponse.APIKey).
Get(baseURL + constants.FullMgmtPrefix)
So(err, ShouldBeNil)
So(resp, ShouldNotBeNil)
So(resp.StatusCode(), ShouldEqual, http.StatusOK)
resp, err = client.R().
SetBasicAuth(email, apiKeyResponse.APIKey).
Get(baseURL + "/v2/_catalog")
So(err, ShouldBeNil)
So(resp, ShouldNotBeNil)
So(resp.StatusCode(), ShouldEqual, http.StatusOK)
err = json.Unmarshal(resp.Body(), &apiKeyResponse)
So(err, ShouldBeNil)
// delete api key
resp, err = client.R().
SetHeader(constants.SessionClientHeaderName, constants.SessionClientHeaderValue).
SetQueryParam("id", apiKeyResponse.UUID).
Delete(baseURL + constants.FullAPIKeyPrefix)
So(err, ShouldBeNil)
So(resp, ShouldNotBeNil)
So(resp.StatusCode(), ShouldEqual, http.StatusOK)
resp, err = client.R().
SetHeader(constants.SessionClientHeaderName, constants.SessionClientHeaderValue).
Delete(baseURL + constants.FullAPIKeyPrefix)
So(err, ShouldBeNil)
So(resp, ShouldNotBeNil)
So(resp.StatusCode(), ShouldEqual, http.StatusBadRequest)
resp, err = client.R().
SetBasicAuth(email, apiKeyResponse.APIKey).
Get(baseURL + "/v2/_catalog")
So(err, ShouldBeNil)
So(resp, ShouldNotBeNil)
So(resp.StatusCode(), ShouldEqual, http.StatusUnauthorized)
resp, err = client.R().
SetBasicAuth("test", "test").
SetQueryParam("id", apiKeyResponse.UUID).
Delete(baseURL + constants.FullAPIKeyPrefix)
So(err, ShouldBeNil)
So(resp, ShouldNotBeNil)
So(resp.StatusCode(), ShouldEqual, http.StatusOK)
// unsupported method
resp, err = client.R().
Put(baseURL + constants.FullAPIKeyPrefix)
So(err, ShouldBeNil)
So(resp, ShouldNotBeNil)
So(resp.StatusCode(), ShouldEqual, http.StatusMethodNotAllowed)
})
})
}
func TestAPIKeysOpenDBError(t *testing.T) {
Convey("Test API keys - unable to create database", t, func() {
conf := config.New()
htpasswdPath := test.MakeHtpasswdFile()
defer os.Remove(htpasswdPath)
mockOIDCServer, err := test.MockOIDCRun()
if err != nil {
panic(err)
}
defer func() {
err := mockOIDCServer.Shutdown()
if err != nil {
panic(err)
}
}()
mockOIDCConfig := mockOIDCServer.Config()
conf.HTTP.Auth = &config.AuthConfig{
HTPasswd: config.AuthHTPasswd{
Path: htpasswdPath,
},
OpenID: &config.OpenIDConfig{
Providers: map[string]config.OpenIDProviderConfig{
"dex": {
ClientID: mockOIDCConfig.ClientID,
ClientSecret: mockOIDCConfig.ClientSecret,
KeyPath: "",
Issuer: mockOIDCConfig.Issuer,
Scopes: []string{"openid", "email"},
},
},
},
}
defaultVal := true
apiKeyConfig := &extconf.APIKeyConfig{
BaseConfig: extconf.BaseConfig{Enable: &defaultVal},
}
conf.Extensions = &extconf.ExtensionConfig{
APIKey: apiKeyConfig,
}
ctlr := api.NewController(conf)
dir := t.TempDir()
err = os.Chmod(dir, 0o000)
So(err, ShouldBeNil)
ctlr.Config.Storage.RootDirectory = dir
cm := test.NewControllerManager(ctlr)
So(func() {
cm.StartServer()
}, ShouldPanic)
})
}
+14 -1
View File
@@ -36,12 +36,19 @@ type BearerConfig struct {
Service string `json:"service,omitempty"`
}
type OpenIDProviderConfig struct{}
type OpenIDConfig struct {
Providers map[string]OpenIDProviderConfig `json:"providers,omitempty" mapstructure:"providers"`
}
type Auth struct {
HTPasswd *HTPasswd `json:"htpasswd,omitempty" mapstructure:"htpasswd"`
Bearer *BearerConfig `json:"bearer,omitempty" mapstructure:"bearer"`
LDAP *struct {
Address string `json:"address,omitempty" mapstructure:"address"`
} `json:"ldap,omitempty" mapstructure:"ldap"`
OpenID *OpenIDConfig `json:"openid,omitempty" mapstructure:"openid"`
}
type StrippedConfig struct {
@@ -60,8 +67,10 @@ func (auth Auth) MarshalJSON() ([]byte, error) {
type localAuth Auth
if auth.Bearer == nil && auth.LDAP == nil &&
auth.HTPasswd.Path == "" {
auth.HTPasswd.Path == "" &&
(auth.OpenID == nil || len(auth.OpenID.Providers) == 0) {
auth.HTPasswd = nil
auth.OpenID = nil
return json.Marshal((localAuth)(auth))
}
@@ -72,6 +81,10 @@ func (auth Auth) MarshalJSON() ([]byte, error) {
auth.HTPasswd.Path = ""
}
if auth.OpenID != nil && len(auth.OpenID.Providers) == 0 {
auth.OpenID = nil
}
auth.LDAP = nil
return json.Marshal((localAuth)(auth))
+1
View File
@@ -39,6 +39,7 @@ func SetupUserPreferencesRoutes(config *config.Config, router *mux.Router, store
userprefsRouter := router.PathPrefix(constants.ExtUserPreferences).Subrouter()
userprefsRouter.Use(zcommon.ACHeadersHandler(allowedMethods...))
userprefsRouter.Use(zcommon.AddExtensionSecurityHeaders())
userprefsRouter.HandleFunc("", HandleUserPrefs(repoDB, log)).Methods(allowedMethods...)
}
}
+166 -24
View File
@@ -1,5 +1,5 @@
//go:build sync || metrics || mgmt
// +build sync metrics mgmt
//go:build sync || metrics || mgmt || apikey
// +build sync metrics mgmt apikey
package extensions_test
@@ -128,6 +128,20 @@ func TestMgmtExtension(t *testing.T) {
defaultValue := true
mockOIDCServer, err := test.MockOIDCRun()
if err != nil {
panic(err)
}
defer func() {
err := mockOIDCServer.Shutdown()
if err != nil {
panic(err)
}
}()
mockOIDCConfig := mockOIDCServer.Config()
Convey("Verify mgmt route enabled with htpasswd", t, func() {
htpasswdPath := test.MakeHtpasswdFile()
conf.HTTP.Auth.HTPasswd.Path = htpasswdPath
@@ -145,7 +159,7 @@ func TestMgmtExtension(t *testing.T) {
ctlr := api.NewController(conf)
subPaths := make(map[string]config.StorageConfig)
subPaths["/a"] = config.StorageConfig{}
subPaths["/a"] = config.StorageConfig{RootDirectory: t.TempDir()}
ctlr.Config.Storage.RootDirectory = globalDir
ctlr.Config.Storage.SubPaths = subPaths
@@ -158,6 +172,13 @@ func TestMgmtExtension(t *testing.T) {
So(string(data), ShouldContainSubstring, "setting up mgmt routes")
Convey("unsupported http method call", func() {
// without credentials
resp, err := resty.R().Patch(baseURL + constants.FullMgmtPrefix)
So(err, ShouldBeNil)
So(resp.StatusCode(), ShouldEqual, http.StatusMethodNotAllowed)
})
// without credentials
resp, err := resty.R().Get(baseURL + constants.FullMgmtPrefix)
So(err, ShouldBeNil)
@@ -210,9 +231,9 @@ func TestMgmtExtension(t *testing.T) {
ctlr := api.NewController(conf)
subPaths := make(map[string]config.StorageConfig)
subPaths["/a"] = config.StorageConfig{}
subPaths["/a"] = config.StorageConfig{RootDirectory: t.TempDir()}
ctlr.Config.Storage.RootDirectory = globalDir
ctlr.Config.Storage.RootDirectory = t.TempDir()
ctlr.Config.Storage.SubPaths = subPaths
ctlrManager := test.NewControllerManager(ctlr)
@@ -259,9 +280,9 @@ func TestMgmtExtension(t *testing.T) {
ctlr := api.NewController(conf)
subPaths := make(map[string]config.StorageConfig)
subPaths["/a"] = config.StorageConfig{}
subPaths["/a"] = config.StorageConfig{RootDirectory: t.TempDir()}
ctlr.Config.Storage.RootDirectory = globalDir
ctlr.Config.Storage.RootDirectory = t.TempDir()
ctlr.Config.Storage.SubPaths = subPaths
ctlrManager := test.NewControllerManager(ctlr)
@@ -325,11 +346,7 @@ func TestMgmtExtension(t *testing.T) {
ctlr := api.NewController(conf)
subPaths := make(map[string]config.StorageConfig)
subPaths["/a"] = config.StorageConfig{}
ctlr.Config.Storage.RootDirectory = globalDir
ctlr.Config.Storage.SubPaths = subPaths
ctlr.Config.Storage.RootDirectory = t.TempDir()
ctlrManager := test.NewControllerManager(ctlr)
ctlrManager.StartAndWait(port)
@@ -396,9 +413,9 @@ func TestMgmtExtension(t *testing.T) {
ctlr := api.NewController(conf)
subPaths := make(map[string]config.StorageConfig)
subPaths["/a"] = config.StorageConfig{}
subPaths["/a"] = config.StorageConfig{RootDirectory: t.TempDir()}
ctlr.Config.Storage.RootDirectory = globalDir
ctlr.Config.Storage.RootDirectory = t.TempDir()
ctlr.Config.Storage.SubPaths = subPaths
ctlrManager := test.NewControllerManager(ctlr)
@@ -445,11 +462,7 @@ func TestMgmtExtension(t *testing.T) {
ctlr := api.NewController(conf)
subPaths := make(map[string]config.StorageConfig)
subPaths["/a"] = config.StorageConfig{}
ctlr.Config.Storage.RootDirectory = globalDir
ctlr.Config.Storage.SubPaths = subPaths
ctlr.Config.Storage.RootDirectory = t.TempDir()
ctlrManager := test.NewControllerManager(ctlr)
ctlrManager.StartAndWait(port)
@@ -474,6 +487,110 @@ func TestMgmtExtension(t *testing.T) {
So(mgmtResp.HTTP.Auth.Bearer.Service, ShouldEqual, "service")
})
Convey("Verify mgmt route enabled with openID", t, func() {
conf.HTTP.Auth.HTPasswd.Path = ""
conf.HTTP.Auth.LDAP = nil
conf.HTTP.Auth.Bearer = nil
openIDProviders := make(map[string]config.OpenIDProviderConfig)
openIDProviders["dex"] = config.OpenIDProviderConfig{
ClientID: mockOIDCConfig.ClientID,
ClientSecret: mockOIDCConfig.ClientSecret,
Issuer: mockOIDCConfig.Issuer,
}
conf.HTTP.Auth.OpenID = &config.OpenIDConfig{
Providers: openIDProviders,
}
conf.Extensions = &extconf.ExtensionConfig{}
conf.Extensions.Mgmt = &extconf.MgmtConfig{
BaseConfig: extconf.BaseConfig{
Enable: &defaultValue,
},
}
conf.Log.Output = logFile.Name()
defer os.Remove(logFile.Name()) // cleanup
ctlr := api.NewController(conf)
ctlr.Config.Storage.RootDirectory = t.TempDir()
ctlrManager := test.NewControllerManager(ctlr)
ctlrManager.StartAndWait(port)
defer ctlrManager.StopServer()
data, _ := os.ReadFile(logFile.Name())
So(string(data), ShouldContainSubstring, "setting up mgmt routes")
// without credentials
resp, err := resty.R().Get(baseURL + constants.FullMgmtPrefix)
So(err, ShouldBeNil)
So(resp.StatusCode(), ShouldEqual, http.StatusOK)
mgmtResp := extensions.StrippedConfig{}
err = json.Unmarshal(resp.Body(), &mgmtResp)
t.Logf("resp: %v", mgmtResp.HTTP.Auth.OpenID)
So(err, ShouldBeNil)
So(mgmtResp.HTTP.Auth.HTPasswd, ShouldBeNil)
So(mgmtResp.HTTP.Auth.LDAP, ShouldBeNil)
So(mgmtResp.HTTP.Auth.Bearer, ShouldBeNil)
So(mgmtResp.HTTP.Auth.OpenID, ShouldNotBeNil)
So(mgmtResp.HTTP.Auth.OpenID.Providers, ShouldNotBeEmpty)
})
Convey("Verify mgmt route enabled with empty openID provider list", t, func() {
htpasswdPath := test.MakeHtpasswdFile()
conf.HTTP.Auth.HTPasswd.Path = htpasswdPath
conf.HTTP.Auth.LDAP = nil
conf.HTTP.Auth.Bearer = nil
openIDProviders := make(map[string]config.OpenIDProviderConfig)
conf.HTTP.Auth.OpenID = &config.OpenIDConfig{
Providers: openIDProviders,
}
conf.Extensions = &extconf.ExtensionConfig{}
conf.Extensions.Mgmt = &extconf.MgmtConfig{
BaseConfig: extconf.BaseConfig{
Enable: &defaultValue,
},
}
conf.Log.Output = logFile.Name()
defer os.Remove(logFile.Name()) // cleanup
ctlr := api.NewController(conf)
ctlr.Config.Storage.RootDirectory = t.TempDir()
ctlrManager := test.NewControllerManager(ctlr)
ctlrManager.StartAndWait(port)
defer ctlrManager.StopServer()
data, _ := os.ReadFile(logFile.Name())
So(string(data), ShouldContainSubstring, "setting up mgmt routes")
// without credentials
resp, err := resty.R().Get(baseURL + constants.FullMgmtPrefix)
So(err, ShouldBeNil)
So(resp.StatusCode(), ShouldEqual, http.StatusOK)
mgmtResp := extensions.StrippedConfig{}
err = json.Unmarshal(resp.Body(), &mgmtResp)
t.Logf("resp: %v", mgmtResp.HTTP.Auth.OpenID)
So(err, ShouldBeNil)
So(mgmtResp.HTTP.Auth.HTPasswd, ShouldNotBeNil)
So(mgmtResp.HTTP.Auth.LDAP, ShouldBeNil)
So(mgmtResp.HTTP.Auth.Bearer, ShouldBeNil)
So(mgmtResp.HTTP.Auth.OpenID, ShouldBeNil)
})
Convey("Verify mgmt route enabled without any auth", t, func() {
globalDir := t.TempDir()
conf := config.New()
@@ -499,11 +616,7 @@ func TestMgmtExtension(t *testing.T) {
ctlr := api.NewController(conf)
subPaths := make(map[string]config.StorageConfig)
subPaths["/a"] = config.StorageConfig{}
ctlr.Config.Storage.RootDirectory = globalDir
ctlr.Config.Storage.SubPaths = subPaths
ctlr.Config.Storage.RootDirectory = t.TempDir()
ctlrManager := test.NewControllerManager(ctlr)
ctlrManager.StartAndWait(port)
@@ -856,3 +969,32 @@ func TestAllowedMethodsHeaderMgmt(t *testing.T) {
So(resp.StatusCode(), ShouldEqual, http.StatusNoContent)
})
}
func TestAllowedMethodsHeaderAPIKey(t *testing.T) {
defaultVal := true
Convey("Test http options response", t, func() {
conf := config.New()
port := test.GetFreePort()
conf.HTTP.Port = port
conf.Extensions = &extconf.ExtensionConfig{
APIKey: &extconf.APIKeyConfig{
BaseConfig: extconf.BaseConfig{Enable: &defaultVal},
},
}
baseURL := test.GetBaseURL(port)
ctlr := api.NewController(conf)
ctlr.Config.Storage.RootDirectory = t.TempDir()
ctrlManager := test.NewControllerManager(ctlr)
ctrlManager.StartAndWait(port)
defer ctrlManager.StopServer()
resp, _ := resty.R().Options(baseURL + constants.FullAPIKeyPrefix)
So(resp, ShouldNotBeNil)
So(resp.Header().Get("Access-Control-Allow-Methods"), ShouldResemble, "POST,DELETE,OPTIONS")
So(resp.StatusCode(), ShouldEqual, http.StatusNoContent)
})
}
+1 -2
View File
@@ -7,6 +7,5 @@ const (
RepoMetadataBucket = "RepoMetadata"
UserDataBucket = "UserData"
VersionBucket = "Version"
StarredReposKey = "StarredReposKey"
BookmarkedReposKey = "BookmarkedReposKey"
UserAPIKeysBucket = "UserAPIKeys"
)
+1 -1
View File
@@ -10,7 +10,7 @@ import (
type DBDriverParameters struct {
Endpoint, Region, RepoMetaTablename, ManifestDataTablename, IndexDataTablename,
VersionTablename, UserDataTablename string
UserDataTablename, APIKeyTablename, VersionTablename string
}
func GetDynamoClient(params DBDriverParameters) (*dynamodb.Client, error) {
+334 -151
View File
@@ -3,6 +3,7 @@ package bolt
import (
"context"
"encoding/json"
"errors"
"fmt"
"strings"
"time"
@@ -60,6 +61,11 @@ func NewBoltDBWrapper(boltDB *bbolt.DB, log log.Logger) (*DBWrapper, error) {
return err
}
_, err = transaction.CreateBucketIfNotExists([]byte(bolt.UserAPIKeysBucket))
if err != nil {
return err
}
return nil
})
if err != nil {
@@ -1680,39 +1686,26 @@ func (bdw *DBWrapper) ToggleStarRepo(ctx context.Context, repo string) (repodb.T
var res repodb.ToggleState
if err := bdw.DB.Update(func(tx *bbolt.Tx) error { //nolint:varnamelen
userdb := tx.Bucket([]byte(bolt.UserDataBucket))
userBucket, err := userdb.CreateBucketIfNotExists([]byte(userid))
if err != nil {
// this is a serious failure
return zerr.ErrUnableToCreateUserBucket
var userData repodb.UserData
err := bdw.getUserData(userid, tx, &userData)
if err != nil && !errors.Is(err, zerr.ErrUserDataNotFound) {
return err
}
mdata := userBucket.Get([]byte(bolt.StarredReposKey))
unpacked := []string{}
if mdata != nil {
if err = json.Unmarshal(mdata, &unpacked); err != nil {
return zerr.ErrInvalidOldUserStarredRepos
}
}
isRepoStarred := zcommon.Contains(unpacked, repo)
isRepoStarred := zcommon.Contains(userData.StarredRepos, repo)
if isRepoStarred {
res = repodb.Removed
unpacked = zcommon.RemoveFrom(unpacked, repo)
userData.StarredRepos = zcommon.RemoveFrom(userData.StarredRepos, repo)
} else {
res = repodb.Added
unpacked = append(unpacked, repo)
userData.StarredRepos = append(userData.StarredRepos, repo)
}
var repacked []byte
if repacked, err = json.Marshal(unpacked); err != nil {
return zerr.ErrCouldNotMarshalStarredRepos
}
err = userBucket.Put([]byte(bolt.StarredReposKey), repacked)
err = bdw.setUserData(userid, tx, userData)
if err != nil {
return zerr.ErrCouldNotPersistData
return err
}
repoBuck := tx.Bucket([]byte(bolt.RepoMetadataBucket))
@@ -1755,46 +1748,12 @@ func (bdw *DBWrapper) ToggleStarRepo(ctx context.Context, repo string) (repodb.T
}
func (bdw *DBWrapper) GetStarredRepos(ctx context.Context) ([]string, error) {
starredRepos := make([]string, 0)
acCtx, err := localCtx.GetAccessControlContext(ctx)
if err != nil {
return starredRepos, err
userData, err := bdw.GetUserData(ctx)
if errors.Is(err, zerr.ErrUserDataNotFound) || errors.Is(err, zerr.ErrUserDataNotAllowed) {
return []string{}, nil
}
userid := localCtx.GetUsernameFromContext(acCtx)
err = bdw.DB.View(func(tx *bbolt.Tx) error { //nolint:dupl
if userid == "" {
return nil
}
userdb := tx.Bucket([]byte(bolt.UserDataBucket))
userBucket := userdb.Bucket([]byte(userid))
if userBucket == nil {
return nil
}
mdata := userBucket.Get([]byte(bolt.StarredReposKey))
if mdata == nil {
return nil
}
if err := json.Unmarshal(mdata, &starredRepos); err != nil {
bdw.Log.Info().Str("user", userid).Err(err).Msg("unmarshal error")
return zerr.ErrInvalidOldUserStarredRepos
}
if starredRepos == nil {
starredRepos = make([]string, 0)
}
return nil
})
return starredRepos, err
return userData.StarredRepos, err
}
func (bdw *DBWrapper) ToggleBookmarkRepo(ctx context.Context, repo string) (repodb.ToggleState, error) {
@@ -1815,43 +1774,25 @@ func (bdw *DBWrapper) ToggleBookmarkRepo(ctx context.Context, repo string) (repo
var res repodb.ToggleState
if err := bdw.DB.Update(func(tx *bbolt.Tx) error { //nolint:dupl
userdb := tx.Bucket([]byte(bolt.UserDataBucket))
userBucket, err := userdb.CreateBucketIfNotExists([]byte(userid))
if err != nil {
// this is a serious failure
return zerr.ErrUnableToCreateUserBucket
if err := bdw.DB.Update(func(transaction *bbolt.Tx) error { //nolint:dupl
var userData repodb.UserData
err := bdw.getUserData(userid, transaction, &userData)
if err != nil && !errors.Is(err, zerr.ErrUserDataNotFound) {
return err
}
mdata := userBucket.Get([]byte(bolt.BookmarkedReposKey))
unpacked := []string{}
if mdata != nil {
if err = json.Unmarshal(mdata, &unpacked); err != nil {
return zerr.ErrInvalidOldUserBookmarkedRepos
}
}
isRepoBookmarked := zcommon.Contains(unpacked, repo)
isRepoBookmarked := zcommon.Contains(userData.BookmarkedRepos, repo)
if isRepoBookmarked {
res = repodb.Removed
unpacked = zcommon.RemoveFrom(unpacked, repo)
userData.BookmarkedRepos = zcommon.RemoveFrom(userData.BookmarkedRepos, repo)
} else {
res = repodb.Added
unpacked = append(unpacked, repo)
userData.BookmarkedRepos = append(userData.BookmarkedRepos, repo)
}
var repacked []byte
if repacked, err = json.Marshal(unpacked); err != nil {
return zerr.ErrCouldNotMarshalBookmarkedRepos
}
err = userBucket.Put([]byte(bolt.BookmarkedReposKey), repacked)
if err != nil {
return zerr.ErrUnableToCreateUserBucket
}
return nil
return bdw.setUserData(userid, transaction, userData)
}); err != nil {
return repodb.NotChanged, err
}
@@ -1860,46 +1801,12 @@ func (bdw *DBWrapper) ToggleBookmarkRepo(ctx context.Context, repo string) (repo
}
func (bdw *DBWrapper) GetBookmarkedRepos(ctx context.Context) ([]string, error) {
bookmarkedRepos := []string{}
acCtx, err := localCtx.GetAccessControlContext(ctx)
if err != nil {
return bookmarkedRepos, err
userData, err := bdw.GetUserData(ctx)
if errors.Is(err, zerr.ErrUserDataNotFound) || errors.Is(err, zerr.ErrUserDataNotAllowed) {
return []string{}, nil
}
userid := localCtx.GetUsernameFromContext(acCtx)
err = bdw.DB.View(func(tx *bbolt.Tx) error { //nolint:dupl
if userid == "" {
return nil
}
userdb := tx.Bucket([]byte(bolt.UserDataBucket))
userBucket := userdb.Bucket([]byte(userid))
if userBucket == nil {
return nil
}
mdata := userBucket.Get([]byte(bolt.BookmarkedReposKey))
if mdata == nil {
return nil
}
if err := json.Unmarshal(mdata, &bookmarkedRepos); err != nil {
bdw.Log.Info().Str("user", userid).Err(err).Msg("unmarshal error")
return zerr.ErrInvalidOldUserBookmarkedRepos
}
if bookmarkedRepos == nil {
bookmarkedRepos = make([]string, 0)
}
return nil
})
return bookmarkedRepos, err
return userData.BookmarkedRepos, err
}
func (bdw *DBWrapper) PatchDB() error {
@@ -1940,30 +1847,25 @@ func getUserStars(ctx context.Context, transaction *bbolt.Tx) []string {
}
var (
userid = localCtx.GetUsernameFromContext(acCtx)
starredRepos = []string{}
userdb = transaction.Bucket([]byte(bolt.UserDataBucket))
userBucket = userdb.Bucket([]byte(userid))
userData repodb.UserData
userid = localCtx.GetUsernameFromContext(acCtx)
userdb = transaction.Bucket([]byte(bolt.UserDataBucket))
)
if userid == "" {
if userid == "" || userdb == nil {
return []string{}
}
if userBucket == nil {
return []string{}
}
mdata := userBucket.Get([]byte(bolt.StarredReposKey))
mdata := userdb.Get([]byte(userid))
if mdata == nil {
return []string{}
}
if err := json.Unmarshal(mdata, &starredRepos); err != nil {
if err := json.Unmarshal(mdata, &userData); err != nil {
return []string{}
}
return starredRepos
return userData.StarredRepos
}
func getUserBookmarks(ctx context.Context, transaction *bbolt.Tx) []string {
@@ -1973,28 +1875,309 @@ func getUserBookmarks(ctx context.Context, transaction *bbolt.Tx) []string {
}
var (
userid = localCtx.GetUsernameFromContext(acCtx)
bookmarkedRepos = []string{}
userdb = transaction.Bucket([]byte(bolt.UserDataBucket))
userBucket = userdb.Bucket([]byte(userid))
userData repodb.UserData
userid = localCtx.GetUsernameFromContext(acCtx)
userdb = transaction.Bucket([]byte(bolt.UserDataBucket))
)
if userid == "" {
if userid == "" || userdb == nil {
return []string{}
}
if userBucket == nil {
return []string{}
}
mdata := userBucket.Get([]byte(bolt.BookmarkedReposKey))
mdata := userdb.Get([]byte(userid))
if mdata == nil {
return []string{}
}
if err := json.Unmarshal(mdata, &bookmarkedRepos); err != nil {
if err := json.Unmarshal(mdata, &userData); err != nil {
return []string{}
}
return bookmarkedRepos
return userData.BookmarkedRepos
}
func (bdw *DBWrapper) SetUserGroups(ctx context.Context, groups []string) error {
acCtx, err := localCtx.GetAccessControlContext(ctx)
if err != nil {
return err
}
userid := localCtx.GetUsernameFromContext(acCtx)
if userid == "" {
// empty user is anonymous
return zerr.ErrUserDataNotAllowed
}
err = bdw.DB.Update(func(tx *bbolt.Tx) error { //nolint:varnamelen
var userData repodb.UserData
err := bdw.getUserData(userid, tx, &userData)
if err != nil && !errors.Is(err, zerr.ErrUserDataNotFound) {
return err
}
userData.Groups = append(userData.Groups, groups...)
err = bdw.setUserData(userid, tx, userData)
return err
})
return err
}
func (bdw *DBWrapper) GetUserGroups(ctx context.Context) ([]string, error) {
userData, err := bdw.GetUserData(ctx)
return userData.Groups, err
}
func (bdw *DBWrapper) UpdateUserAPIKeyLastUsed(ctx context.Context, hashedKey string) error {
acCtx, err := localCtx.GetAccessControlContext(ctx)
if err != nil {
return err
}
userid := localCtx.GetUsernameFromContext(acCtx)
if userid == "" {
// empty user is anonymous
return zerr.ErrUserDataNotAllowed
}
err = bdw.DB.Update(func(tx *bbolt.Tx) error { //nolint:varnamelen
var userData repodb.UserData
err := bdw.getUserData(userid, tx, &userData)
if err != nil {
return err
}
apiKeyDetails := userData.APIKeys[hashedKey]
apiKeyDetails.LastUsed = time.Now()
userData.APIKeys[hashedKey] = apiKeyDetails
err = bdw.setUserData(userid, tx, userData)
return err
})
return err
}
func (bdw *DBWrapper) AddUserAPIKey(ctx context.Context, hashedKey string, apiKeyDetails *repodb.APIKeyDetails) error {
acCtx, err := localCtx.GetAccessControlContext(ctx)
if err != nil {
return err
}
userid := localCtx.GetUsernameFromContext(acCtx)
if userid == "" {
// empty user is anonymous
return zerr.ErrUserDataNotAllowed
}
err = bdw.DB.Update(func(transaction *bbolt.Tx) error {
var userData repodb.UserData
apiKeysbuck := transaction.Bucket([]byte(bolt.UserAPIKeysBucket))
if apiKeysbuck == nil {
return zerr.ErrBucketDoesNotExist
}
err := apiKeysbuck.Put([]byte(hashedKey), []byte(userid))
if err != nil {
return fmt.Errorf("repoDB: error while setting userData for identity %s %w", userid, err)
}
err = bdw.getUserData(userid, transaction, &userData)
if err != nil && !errors.Is(err, zerr.ErrUserDataNotFound) {
return err
}
if userData.APIKeys == nil {
userData.APIKeys = make(map[string]repodb.APIKeyDetails)
}
userData.APIKeys[hashedKey] = *apiKeyDetails
err = bdw.setUserData(userid, transaction, userData)
return err
})
return err
}
func (bdw *DBWrapper) DeleteUserAPIKey(ctx context.Context, keyID string) error {
acCtx, err := localCtx.GetAccessControlContext(ctx)
if err != nil {
return err
}
userid := localCtx.GetUsernameFromContext(acCtx)
if userid == "" {
// empty user is anonymous
return zerr.ErrUserDataNotAllowed
}
err = bdw.DB.Update(func(transaction *bbolt.Tx) error {
var userData repodb.UserData
apiKeysbuck := transaction.Bucket([]byte(bolt.UserAPIKeysBucket))
if apiKeysbuck == nil {
return zerr.ErrBucketDoesNotExist
}
err := bdw.getUserData(userid, transaction, &userData)
if err != nil {
return err
}
for hash, apiKeyDetails := range userData.APIKeys {
if apiKeyDetails.UUID == keyID {
delete(userData.APIKeys, hash)
err := apiKeysbuck.Delete([]byte(hash))
if err != nil {
return fmt.Errorf("userDB: error while deleting userAPIKey entry for hash %s %w", hash, err)
}
}
}
return bdw.setUserData(userid, transaction, userData)
})
return err
}
func (bdw *DBWrapper) GetUserAPIKeyInfo(hashedKey string) (string, error) {
var userid string
err := bdw.DB.View(func(tx *bbolt.Tx) error {
buck := tx.Bucket([]byte(bolt.UserAPIKeysBucket))
if buck == nil {
return zerr.ErrBucketDoesNotExist
}
uiBlob := buck.Get([]byte(hashedKey))
if len(uiBlob) == 0 {
return zerr.ErrUserAPIKeyNotFound
}
userid = string(uiBlob)
return nil
})
return userid, err
}
func (bdw *DBWrapper) GetUserData(ctx context.Context) (repodb.UserData, error) {
var userData repodb.UserData
acCtx, err := localCtx.GetAccessControlContext(ctx)
if err != nil {
return userData, err
}
userid := localCtx.GetUsernameFromContext(acCtx)
if userid == "" {
// empty user is anonymous
return userData, zerr.ErrUserDataNotAllowed
}
err = bdw.DB.View(func(tx *bbolt.Tx) error {
return bdw.getUserData(userid, tx, &userData)
})
return userData, err
}
func (bdw *DBWrapper) getUserData(userid string, transaction *bbolt.Tx, res *repodb.UserData) error {
buck := transaction.Bucket([]byte(bolt.UserDataBucket))
if buck == nil {
return zerr.ErrBucketDoesNotExist
}
upBlob := buck.Get([]byte(userid))
if len(upBlob) == 0 {
return zerr.ErrUserDataNotFound
}
err := json.Unmarshal(upBlob, res)
if err != nil {
return err
}
return nil
}
func (bdw *DBWrapper) SetUserData(ctx context.Context, userData repodb.UserData) error {
acCtx, err := localCtx.GetAccessControlContext(ctx)
if err != nil {
return err
}
userid := localCtx.GetUsernameFromContext(acCtx)
if userid == "" {
// empty user is anonymous
return zerr.ErrUserDataNotAllowed
}
err = bdw.DB.Update(func(tx *bbolt.Tx) error {
return bdw.setUserData(userid, tx, userData)
})
return err
}
func (bdw *DBWrapper) setUserData(userid string, transaction *bbolt.Tx, userData repodb.UserData) error {
buck := transaction.Bucket([]byte(bolt.UserDataBucket))
if buck == nil {
return zerr.ErrBucketDoesNotExist
}
upBlob, err := json.Marshal(userData)
if err != nil {
return err
}
err = buck.Put([]byte(userid), upBlob)
if err != nil {
return fmt.Errorf("repoDB: error while setting userData for identity %s %w", userid, err)
}
return nil
}
func (bdw *DBWrapper) DeleteUserData(ctx context.Context) error {
acCtx, err := localCtx.GetAccessControlContext(ctx)
if err != nil {
return err
}
userid := localCtx.GetUsernameFromContext(acCtx)
if userid == "" {
// empty user is anonymous
return zerr.ErrUserDataNotAllowed
}
err = bdw.DB.Update(func(tx *bbolt.Tx) error {
buck := tx.Bucket([]byte(bolt.UserDataBucket))
if buck == nil {
return zerr.ErrBucketDoesNotExist
}
err := buck.Delete([]byte(userid))
if err != nil {
return fmt.Errorf("repoDB: error while deleting userData for identity %s %w", userid, err)
}
return nil
})
return err
}
@@ -2,7 +2,10 @@ package bolt_test
import (
"context"
"crypto/rand"
"encoding/base64"
"encoding/json"
"math"
"testing"
"github.com/opencontainers/go-digest"
@@ -10,6 +13,7 @@ import (
. "github.com/smartystreets/goconvey/convey"
"go.etcd.io/bbolt"
zerr "zotregistry.io/zot/errors"
"zotregistry.io/zot/pkg/log"
"zotregistry.io/zot/pkg/meta/bolt"
"zotregistry.io/zot/pkg/meta/repodb"
@@ -21,7 +25,6 @@ import (
func TestWrapperErrors(t *testing.T) {
Convey("Errors", t, func() {
ctx := context.Background()
tmpDir := t.TempDir()
boltDBParams := bolt.DBParameters{RootDir: tmpDir}
boltDriver, err := bolt.GetBoltDriver(boltDBParams)
@@ -41,6 +44,231 @@ func TestWrapperErrors(t *testing.T) {
repoMetaBlob, err := json.Marshal(repoMeta)
So(err, ShouldBeNil)
authzCtxKey := localCtx.GetContextKey()
acCtx := localCtx.AccessControlContext{
Username: "test",
}
ctx := context.WithValue(context.Background(), authzCtxKey, acCtx)
Convey("AddUserAPIKey", func() {
Convey("no userid found", func() {
acCtx := localCtx.AccessControlContext{
Username: "",
}
ctx := context.WithValue(context.Background(), authzCtxKey, acCtx)
err = boltdbWrapper.AddUserAPIKey(ctx, "", &repodb.APIKeyDetails{})
So(err, ShouldNotBeNil)
})
err = boltdbWrapper.AddUserAPIKey(ctx, "", &repodb.APIKeyDetails{})
So(err, ShouldNotBeNil)
err = boltdbWrapper.DB.Update(func(tx *bbolt.Tx) error {
return tx.DeleteBucket([]byte(bolt.UserDataBucket))
})
So(err, ShouldBeNil)
err = boltdbWrapper.AddUserAPIKey(ctx, "test", &repodb.APIKeyDetails{})
So(err, ShouldNotBeNil)
err = boltdbWrapper.DB.Update(func(tx *bbolt.Tx) error {
return tx.DeleteBucket([]byte(bolt.UserAPIKeysBucket))
})
So(err, ShouldBeNil)
err = boltdbWrapper.AddUserAPIKey(ctx, "", &repodb.APIKeyDetails{})
So(err, ShouldEqual, zerr.ErrBucketDoesNotExist)
})
Convey("UpdateUserAPIKey", func() {
err = boltdbWrapper.UpdateUserAPIKeyLastUsed(ctx, "")
So(err, ShouldNotBeNil)
acCtx := localCtx.AccessControlContext{
Username: "",
}
ctx := context.WithValue(context.Background(), authzCtxKey, acCtx)
err = boltdbWrapper.UpdateUserAPIKeyLastUsed(ctx, "") //nolint: contextcheck
So(err, ShouldNotBeNil)
})
Convey("DeleteUserAPIKey", func() {
err = boltdbWrapper.SetUserData(ctx, repodb.UserData{})
So(err, ShouldBeNil)
err = boltdbWrapper.AddUserAPIKey(ctx, "hashedKey", &repodb.APIKeyDetails{})
So(err, ShouldBeNil)
Convey("no such bucket", func() {
err = boltdbWrapper.DB.Update(func(tx *bbolt.Tx) error {
return tx.DeleteBucket([]byte(bolt.UserAPIKeysBucket))
})
So(err, ShouldBeNil)
authzCtxKey := localCtx.GetContextKey()
acCtx := localCtx.AccessControlContext{
Username: "test",
}
ctx := context.WithValue(context.Background(), authzCtxKey, acCtx)
err = boltdbWrapper.DeleteUserAPIKey(ctx, "")
So(err, ShouldEqual, zerr.ErrBucketDoesNotExist)
})
Convey("userdata not found", func() {
authzCtxKey := localCtx.GetContextKey()
acCtx := localCtx.AccessControlContext{
Username: "test",
}
ctx := context.WithValue(context.Background(), authzCtxKey, acCtx)
err := boltdbWrapper.DeleteUserData(ctx)
So(err, ShouldBeNil)
err = boltdbWrapper.DeleteUserAPIKey(ctx, "")
So(err, ShouldNotBeNil)
})
authzCtxKey := localCtx.GetContextKey()
acCtx := localCtx.AccessControlContext{
Username: "",
}
ctx := context.WithValue(context.Background(), authzCtxKey, acCtx) //nolint: contextcheck
err = boltdbWrapper.DeleteUserAPIKey(ctx, "test") //nolint: contextcheck
So(err, ShouldNotBeNil)
err = boltdbWrapper.DB.Update(func(tx *bbolt.Tx) error {
return tx.DeleteBucket([]byte(bolt.UserDataBucket))
})
So(err, ShouldBeNil)
err = boltdbWrapper.DeleteUserAPIKey(ctx, "") //nolint: contextcheck
So(err, ShouldNotBeNil)
})
Convey("GetUserAPIKeyInfo", func() {
err = boltdbWrapper.DB.Update(func(tx *bbolt.Tx) error {
return tx.DeleteBucket([]byte(bolt.UserAPIKeysBucket))
})
So(err, ShouldBeNil)
_, err = boltdbWrapper.GetUserAPIKeyInfo("")
So(err, ShouldNotBeNil)
})
Convey("GetUserData", func() {
err = boltdbWrapper.DB.Update(func(tx *bbolt.Tx) error {
buck := tx.Bucket([]byte(bolt.UserDataBucket))
So(buck, ShouldNotBeNil)
return buck.Put([]byte("test"), []byte("dsa8"))
})
So(err, ShouldBeNil)
_, err = boltdbWrapper.GetUserData(ctx)
So(err, ShouldNotBeNil)
err = boltdbWrapper.DB.Update(func(tx *bbolt.Tx) error {
return tx.DeleteBucket([]byte(bolt.UserAPIKeysBucket))
})
So(err, ShouldBeNil)
_, err = boltdbWrapper.GetUserData(ctx)
So(err, ShouldNotBeNil)
})
Convey("SetUserData", func() {
acCtx = localCtx.AccessControlContext{
Username: "",
}
ctx = context.WithValue(context.Background(), authzCtxKey, acCtx)
err = boltdbWrapper.SetUserData(ctx, repodb.UserData{})
So(err, ShouldNotBeNil)
buff := make([]byte, int(math.Ceil(float64(1000000)/float64(1.33333333333))))
_, err := rand.Read(buff)
So(err, ShouldBeNil)
longString := base64.RawURLEncoding.EncodeToString(buff)
authzCtxKey := localCtx.GetContextKey()
acCtx := localCtx.AccessControlContext{
Username: longString,
}
ctx := context.WithValue(context.Background(), authzCtxKey, acCtx)
err = boltdbWrapper.SetUserData(ctx, repodb.UserData{}) //nolint: contextcheck
So(err, ShouldNotBeNil)
err = boltdbWrapper.DB.Update(func(tx *bbolt.Tx) error {
return tx.DeleteBucket([]byte(bolt.UserDataBucket))
})
So(err, ShouldBeNil)
acCtx = localCtx.AccessControlContext{
Username: "test",
}
ctx = context.WithValue(context.Background(), authzCtxKey, acCtx)
err = boltdbWrapper.SetUserData(ctx, repodb.UserData{}) //nolint: contextcheck
So(err, ShouldNotBeNil)
})
Convey("DeleteUserData", func() {
acCtx = localCtx.AccessControlContext{
Username: "",
}
ctx = context.WithValue(context.Background(), authzCtxKey, acCtx)
err = boltdbWrapper.DeleteUserData(ctx)
So(err, ShouldNotBeNil)
err = boltdbWrapper.DB.Update(func(tx *bbolt.Tx) error {
return tx.DeleteBucket([]byte(bolt.UserDataBucket))
})
So(err, ShouldBeNil)
acCtx = localCtx.AccessControlContext{
Username: "test",
}
ctx = context.WithValue(context.Background(), authzCtxKey, acCtx)
err = boltdbWrapper.DeleteUserData(ctx)
So(err, ShouldNotBeNil)
})
Convey("GetUserGroups and SetUserGroups", func() {
acCtx = localCtx.AccessControlContext{
Username: "",
}
ctx = context.WithValue(context.Background(), authzCtxKey, acCtx)
_, err := boltdbWrapper.GetUserGroups(ctx)
So(err, ShouldNotBeNil)
err = boltdbWrapper.SetUserGroups(ctx, []string{})
So(err, ShouldNotBeNil)
})
Convey("GetManifestData", func() {
err := boltdbWrapper.DB.Update(func(tx *bbolt.Tx) error {
dataBuck := tx.Bucket([]byte(bolt.ManifestDataBucket))
@@ -732,60 +960,6 @@ func TestWrapperErrors(t *testing.T) {
So(err, ShouldNotBeNil)
})
Convey("ToggleStarRepo, getting StarredRepoKey from bucket fails", func() {
acCtx := localCtx.AccessControlContext{
ReadGlobPatterns: map[string]bool{
"repo": true,
},
Username: "username",
}
authzCtxKey := localCtx.GetContextKey()
ctx := context.WithValue(context.Background(), authzCtxKey, acCtx)
err := boltdbWrapper.DB.Update(func(tx *bbolt.Tx) error {
userdb, err := tx.CreateBucketIfNotExists([]byte(bolt.UserDataBucket))
So(err, ShouldBeNil)
userBucket, err := userdb.CreateBucketIfNotExists([]byte(acCtx.Username))
So(err, ShouldBeNil)
err = userBucket.Put([]byte(bolt.StarredReposKey), []byte("bad array"))
So(err, ShouldBeNil)
return nil
})
So(err, ShouldBeNil)
_, err = boltdbWrapper.ToggleStarRepo(ctx, "repo")
So(err, ShouldNotBeNil)
})
Convey("ToggleBookmarkRepo, unmarshal error", func() {
acCtx := localCtx.AccessControlContext{
ReadGlobPatterns: map[string]bool{
"repo": true,
},
Username: "username",
}
authzCtxKey := localCtx.GetContextKey()
ctx := context.WithValue(context.Background(), authzCtxKey, acCtx)
err := boltdbWrapper.DB.Update(func(tx *bbolt.Tx) error {
userdb, err := tx.CreateBucketIfNotExists([]byte(bolt.UserDataBucket))
So(err, ShouldBeNil)
userBucket, err := userdb.CreateBucketIfNotExists([]byte(acCtx.Username))
So(err, ShouldBeNil)
err = userBucket.Put([]byte(bolt.BookmarkedReposKey), []byte("bad array"))
So(err, ShouldBeNil)
return nil
})
So(err, ShouldBeNil)
_, err = boltdbWrapper.ToggleBookmarkRepo(ctx, "repo")
So(err, ShouldNotBeNil)
})
Convey("ToggleStarRepo, no repoMeta found", func() {
acCtx := localCtx.AccessControlContext{
ReadGlobPatterns: map[string]bool{
@@ -832,6 +1006,73 @@ func TestWrapperErrors(t *testing.T) {
So(err, ShouldNotBeNil)
})
Convey("GetUserData bad context errors", func() {
authzCtxKey := localCtx.GetContextKey()
ctx := context.WithValue(context.Background(), authzCtxKey, "bad context")
_, err := boltdbWrapper.GetUserData(ctx)
So(err, ShouldNotBeNil)
})
Convey("SetUserData bad context errors", func() {
authzCtxKey := localCtx.GetContextKey()
ctx := context.WithValue(context.Background(), authzCtxKey, "bad context")
err := boltdbWrapper.SetUserData(ctx, repodb.UserData{})
So(err, ShouldNotBeNil)
})
Convey("GetUserGroups bad context errors", func() {
_, err := boltdbWrapper.GetUserGroups(ctx)
So(err, ShouldNotBeNil)
authzCtxKey := localCtx.GetContextKey()
ctx := context.WithValue(context.Background(), authzCtxKey, "bad context")
_, err = boltdbWrapper.GetUserGroups(ctx) //nolint: contextcheck
So(err, ShouldNotBeNil)
})
Convey("SetUserGroups bad context errors", func() {
authzCtxKey := localCtx.GetContextKey()
ctx := context.WithValue(context.Background(), authzCtxKey, "bad context")
err := boltdbWrapper.SetUserGroups(ctx, []string{})
So(err, ShouldNotBeNil)
})
Convey("AddUserAPIKey bad context errors", func() {
authzCtxKey := localCtx.GetContextKey()
ctx := context.WithValue(context.Background(), authzCtxKey, "bad context")
err := boltdbWrapper.AddUserAPIKey(ctx, "", &repodb.APIKeyDetails{})
So(err, ShouldNotBeNil)
})
Convey("DeleteUserAPIKey bad context errors", func() {
authzCtxKey := localCtx.GetContextKey()
ctx := context.WithValue(context.Background(), authzCtxKey, "bad context")
err := boltdbWrapper.DeleteUserAPIKey(ctx, "")
So(err, ShouldNotBeNil)
})
Convey("UpdateUserAPIKeyLastUsed bad context errors", func() {
authzCtxKey := localCtx.GetContextKey()
ctx := context.WithValue(context.Background(), authzCtxKey, "bad context")
err := boltdbWrapper.UpdateUserAPIKeyLastUsed(ctx, "")
So(err, ShouldNotBeNil)
})
Convey("DeleteUserData bad context errors", func() {
authzCtxKey := localCtx.GetContextKey()
ctx := context.WithValue(context.Background(), authzCtxKey, "bad context")
err := boltdbWrapper.DeleteUserData(ctx)
So(err, ShouldNotBeNil)
})
Convey("GetStarredRepos bad context errors", func() {
authzCtxKey := localCtx.GetContextKey()
ctx := context.WithValue(context.Background(), authzCtxKey, "bad context")
@@ -840,60 +1081,6 @@ func TestWrapperErrors(t *testing.T) {
So(err, ShouldNotBeNil)
})
Convey("GetStarredRepos user data unmarshal error", func() {
acCtx := localCtx.AccessControlContext{
ReadGlobPatterns: map[string]bool{
"repo": true,
},
Username: "username",
}
authzCtxKey := localCtx.GetContextKey()
ctx := context.WithValue(context.Background(), authzCtxKey, acCtx)
err := boltdbWrapper.DB.Update(func(tx *bbolt.Tx) error {
userdb, err := tx.CreateBucketIfNotExists([]byte(bolt.UserDataBucket))
So(err, ShouldBeNil)
userBucket, err := userdb.CreateBucketIfNotExists([]byte(acCtx.Username))
So(err, ShouldBeNil)
err = userBucket.Put([]byte(bolt.StarredReposKey), []byte("bad array"))
So(err, ShouldBeNil)
return nil
})
So(err, ShouldBeNil)
_, err = boltdbWrapper.GetStarredRepos(ctx)
So(err, ShouldNotBeNil)
})
Convey("GetBookmarkedRepos user data unmarshal error", func() {
acCtx := localCtx.AccessControlContext{
ReadGlobPatterns: map[string]bool{
"repo": true,
},
Username: "username",
}
authzCtxKey := localCtx.GetContextKey()
ctx := context.WithValue(context.Background(), authzCtxKey, acCtx)
err := boltdbWrapper.DB.Update(func(tx *bbolt.Tx) error {
userdb, err := tx.CreateBucketIfNotExists([]byte(bolt.UserDataBucket))
So(err, ShouldBeNil)
userBucket, err := userdb.CreateBucketIfNotExists([]byte(acCtx.Username))
So(err, ShouldBeNil)
err = userBucket.Put([]byte(bolt.BookmarkedReposKey), []byte("bad array"))
So(err, ShouldBeNil)
return nil
})
So(err, ShouldBeNil)
_, err = boltdbWrapper.GetBookmarkedRepos(ctx)
So(err, ShouldNotBeNil)
})
Convey("GetBookmarkedRepos bad context errors", func() {
authzCtxKey := localCtx.GetContextKey()
ctx := context.WithValue(context.Background(), authzCtxKey, "bad context")
@@ -31,6 +31,7 @@ func TestWrapperErrors(t *testing.T) {
manifestDataTablename := "ManifestDataTable" + uuid.String()
indexDataTablename := "IndexDataTable" + uuid.String()
userDataTablename := "UserDataTable" + uuid.String()
apiKeyTablename := "ApiKeyTable" + uuid.String()
versionTablename := "Version" + uuid.String()
@@ -58,6 +59,7 @@ func TestWrapperErrors(t *testing.T) {
IndexDataTablename: indexDataTablename,
VersionTablename: versionTablename,
UserDataTablename: userDataTablename,
APIKeyTablename: apiKeyTablename,
Patches: version.GetDynamoDBPatches(),
Log: log.Logger{Logger: zerolog.New(os.Stdout)},
}
@@ -74,6 +76,9 @@ func TestWrapperErrors(t *testing.T) {
err = dynamoWrapper.createVersionTable()
So(err, ShouldNotBeNil)
err = dynamoWrapper.createAPIKeyTable()
So(err, ShouldNotBeNil)
})
Convey("Delete table errors", t, func() {
+237 -14
View File
@@ -43,6 +43,7 @@ func TestIterator(t *testing.T) {
versionTablename := "Version" + uuid.String()
indexDataTablename := "IndexDataTable" + uuid.String()
userDataTablename := "UserDataTable" + uuid.String()
apiKeyTablename := "ApiKeyTable" + uuid.String()
log := log.NewLogger("debug", "")
@@ -54,6 +55,7 @@ func TestIterator(t *testing.T) {
ManifestDataTablename: manifestDataTablename,
IndexDataTablename: indexDataTablename,
VersionTablename: versionTablename,
APIKeyTablename: apiKeyTablename,
UserDataTablename: userDataTablename,
}
client, err := dynamo.GetDynamoClient(params)
@@ -144,8 +146,8 @@ func TestWrapperErrors(t *testing.T) {
versionTablename := "Version" + uuid.String()
indexDataTablename := "IndexDataTable" + uuid.String()
userDataTablename := "UserDataTable" + uuid.String()
ctx := context.Background()
apiKeyTablename := "ApiKeyTable" + uuid.String()
wrongTableName := "WRONG Tables"
log := log.NewLogger("debug", "")
@@ -157,6 +159,7 @@ func TestWrapperErrors(t *testing.T) {
ManifestDataTablename: manifestDataTablename,
IndexDataTablename: indexDataTablename,
UserDataTablename: userDataTablename,
APIKeyTablename: apiKeyTablename,
VersionTablename: versionTablename,
}
client, err := dynamo.GetDynamoClient(params) //nolint:contextcheck
@@ -168,6 +171,61 @@ func TestWrapperErrors(t *testing.T) {
So(dynamoWrapper.ResetManifestDataTable(), ShouldBeNil) //nolint:contextcheck
So(dynamoWrapper.ResetRepoMetaTable(), ShouldBeNil) //nolint:contextcheck
authzCtxKey := localCtx.GetContextKey()
acCtx := localCtx.AccessControlContext{
Username: "test",
}
ctx := context.WithValue(context.Background(), authzCtxKey, acCtx)
Convey("SetUserData", func() {
hashKey := "id"
apiKeys := make(map[string]repodb.APIKeyDetails)
apiKeyDetails := repodb.APIKeyDetails{
Label: "apiKey",
Scopes: []string{"repo"},
UUID: hashKey,
}
apiKeys[hashKey] = apiKeyDetails
userProfileSrc := repodb.UserData{
Groups: []string{"group1", "group2"},
APIKeys: apiKeys,
}
err := dynamoWrapper.SetUserData(ctx, userProfileSrc)
So(err, ShouldBeNil)
authzCtxKey := localCtx.GetContextKey()
acCtx := localCtx.AccessControlContext{
Username: "",
}
ctx := context.WithValue(context.Background(), authzCtxKey, acCtx)
err = dynamoWrapper.SetUserData(ctx, repodb.UserData{}) //nolint: contextcheck
So(err, ShouldNotBeNil)
})
Convey("DeleteUserData", func() {
err := dynamoWrapper.DeleteUserData(ctx)
So(err, ShouldBeNil)
authzCtxKey := localCtx.GetContextKey()
acCtx := localCtx.AccessControlContext{
Username: "",
}
ctx := context.WithValue(context.Background(), authzCtxKey, acCtx)
err = dynamoWrapper.DeleteUserData(ctx) //nolint: contextcheck
So(err, ShouldNotBeNil)
})
Convey("ToggleBookmarkRepo no access", func() {
acCtx := localCtx.AccessControlContext{
ReadGlobPatterns: map[string]bool{
@@ -290,17 +348,17 @@ func TestWrapperErrors(t *testing.T) {
So(err, ShouldNotBeNil)
})
Convey("GetUserMeta bad context", func() {
Convey("GetUserData bad context", func() {
authzCtxKey := localCtx.GetContextKey()
ctx := context.WithValue(context.Background(), authzCtxKey, "bad context")
userData, err := dynamoWrapper.GetUserMeta(ctx)
userData, err := dynamoWrapper.GetUserData(ctx)
So(err, ShouldNotBeNil)
So(userData.BookmarkedRepos, ShouldBeEmpty)
So(userData.StarredRepos, ShouldBeEmpty)
})
Convey("GetUserMeta client error", func() {
Convey("GetUserData client error", func() {
acCtx := localCtx.AccessControlContext{
ReadGlobPatterns: map[string]bool{
"repo": true,
@@ -312,7 +370,7 @@ func TestWrapperErrors(t *testing.T) {
dynamoWrapper.UserDataTablename = badTablename
_, err := dynamoWrapper.GetUserMeta(ctx)
_, err := dynamoWrapper.GetUserData(ctx)
So(err, ShouldNotBeNil)
})
@@ -329,27 +387,155 @@ func TestWrapperErrors(t *testing.T) {
err := setBadUserData(dynamoWrapper.Client, userDataTablename, acCtx.Username)
So(err, ShouldBeNil)
_, err = dynamoWrapper.GetUserMeta(ctx)
_, err = dynamoWrapper.GetUserData(ctx)
So(err, ShouldNotBeNil)
})
Convey("SetUserMeta bad context", func() {
Convey("SetUserData bad context", func() {
authzCtxKey := localCtx.GetContextKey()
ctx := context.WithValue(context.Background(), authzCtxKey, "bad context")
err := dynamoWrapper.SetUserMeta(ctx, repodb.UserData{})
err := dynamoWrapper.SetUserData(ctx, repodb.UserData{})
So(err, ShouldNotBeNil)
})
Convey("GetUserData bad context errors", func() {
authzCtxKey := localCtx.GetContextKey()
ctx := context.WithValue(context.Background(), authzCtxKey, "bad context")
_, err := dynamoWrapper.GetUserData(ctx)
So(err, ShouldNotBeNil)
})
Convey("SetUserData bad context errors", func() {
authzCtxKey := localCtx.GetContextKey()
ctx := context.WithValue(context.Background(), authzCtxKey, "bad context")
err := dynamoWrapper.SetUserData(ctx, repodb.UserData{})
So(err, ShouldNotBeNil)
})
Convey("AddUserAPIKey bad context errors", func() {
authzCtxKey := localCtx.GetContextKey()
ctx := context.WithValue(context.Background(), authzCtxKey, "bad context")
err := dynamoWrapper.AddUserAPIKey(ctx, "", &repodb.APIKeyDetails{})
So(err, ShouldNotBeNil)
})
Convey("DeleteUserAPIKey bad context errors", func() {
authzCtxKey := localCtx.GetContextKey()
ctx := context.WithValue(context.Background(), authzCtxKey, "bad context")
err := dynamoWrapper.DeleteUserAPIKey(ctx, "")
So(err, ShouldNotBeNil)
})
Convey("UpdateUserAPIKeyLastUsed bad context errors", func() {
authzCtxKey := localCtx.GetContextKey()
ctx := context.WithValue(context.Background(), authzCtxKey, "bad context")
err := dynamoWrapper.UpdateUserAPIKeyLastUsed(ctx, "")
So(err, ShouldNotBeNil)
})
Convey("DeleteUserData bad context errors", func() {
authzCtxKey := localCtx.GetContextKey()
ctx := context.WithValue(context.Background(), authzCtxKey, "bad context")
err := dynamoWrapper.DeleteUserData(ctx)
So(err, ShouldNotBeNil)
})
Convey("DeleteUserAPIKey returns nil", func() {
authzCtxKey := localCtx.GetContextKey()
acCtx := localCtx.AccessControlContext{
Username: "email",
}
ctx := context.WithValue(context.Background(), authzCtxKey, acCtx)
apiKeyDetails := make(map[string]repodb.APIKeyDetails)
apiKeyDetails["id"] = repodb.APIKeyDetails{
UUID: "id",
}
err := dynamoWrapper.SetUserData(ctx, repodb.UserData{
APIKeys: apiKeyDetails,
})
So(err, ShouldBeNil)
dynamoWrapper.APIKeyTablename = wrongTableName
err = dynamoWrapper.DeleteUserAPIKey(ctx, "id")
So(err, ShouldNotBeNil)
})
Convey("AddUserAPIKey", func() {
Convey("no userid found", func() {
authzCtxKey := localCtx.GetContextKey()
acCtx := localCtx.AccessControlContext{
Username: "",
}
ctx := context.WithValue(context.Background(), authzCtxKey, acCtx)
err = dynamoWrapper.AddUserAPIKey(ctx, "key", &repodb.APIKeyDetails{})
So(err, ShouldNotBeNil)
})
authzCtxKey := localCtx.GetContextKey()
acCtx := localCtx.AccessControlContext{
Username: "email",
}
ctx := context.WithValue(context.Background(), authzCtxKey, acCtx)
err := dynamoWrapper.AddUserAPIKey(ctx, "key", &repodb.APIKeyDetails{})
So(err, ShouldBeNil)
dynamoWrapper.APIKeyTablename = wrongTableName
err = dynamoWrapper.AddUserAPIKey(ctx, "key", &repodb.APIKeyDetails{})
So(err, ShouldNotBeNil)
})
Convey("GetUserAPIKeyInfo", func() {
dynamoWrapper.APIKeyTablename = wrongTableName
_, err := dynamoWrapper.GetUserAPIKeyInfo("key")
So(err, ShouldNotBeNil)
})
Convey("GetUserData", func() {
authzCtxKey := localCtx.GetContextKey()
acCtx := localCtx.AccessControlContext{
Username: "",
}
ctx := context.WithValue(context.Background(), authzCtxKey, acCtx)
_, err := dynamoWrapper.GetUserData(ctx)
So(err, ShouldNotBeNil)
acCtx = localCtx.AccessControlContext{
Username: "email",
}
ctx = context.WithValue(context.Background(), authzCtxKey, acCtx)
dynamoWrapper.UserDataTablename = wrongTableName
_, err = dynamoWrapper.GetUserData(ctx)
So(err, ShouldNotBeNil)
})
Convey("SetManifestData", func() {
dynamoWrapper.ManifestDataTablename = "WRONG tables"
dynamoWrapper.ManifestDataTablename = wrongTableName
err := dynamoWrapper.SetManifestData("dig", repodb.ManifestData{})
So(err, ShouldNotBeNil)
})
Convey("GetManifestData", func() {
dynamoWrapper.ManifestDataTablename = "WRONG table"
dynamoWrapper.ManifestDataTablename = wrongTableName
_, err := dynamoWrapper.GetManifestData("dig")
So(err, ShouldNotBeNil)
@@ -364,7 +550,7 @@ func TestWrapperErrors(t *testing.T) {
})
Convey("GetIndexData", func() {
dynamoWrapper.IndexDataTablename = "WRONG table"
dynamoWrapper.IndexDataTablename = wrongTableName
_, err := dynamoWrapper.GetIndexData("dig")
So(err, ShouldNotBeNil)
@@ -1091,6 +1277,7 @@ func TestWrapperErrors(t *testing.T) {
ManifestDataTablename: manifestDataTablename,
IndexDataTablename: indexDataTablename,
UserDataTablename: userDataTablename,
APIKeyTablename: apiKeyTablename,
VersionTablename: versionTablename,
}
client, err := dynamo.GetDynamoClient(params)
@@ -1106,6 +1293,7 @@ func TestWrapperErrors(t *testing.T) {
ManifestDataTablename: "",
IndexDataTablename: indexDataTablename,
UserDataTablename: userDataTablename,
APIKeyTablename: apiKeyTablename,
VersionTablename: versionTablename,
}
client, err = dynamo.GetDynamoClient(params)
@@ -1121,6 +1309,7 @@ func TestWrapperErrors(t *testing.T) {
ManifestDataTablename: manifestDataTablename,
IndexDataTablename: "",
UserDataTablename: userDataTablename,
APIKeyTablename: apiKeyTablename,
VersionTablename: versionTablename,
}
client, err = dynamo.GetDynamoClient(params)
@@ -1136,6 +1325,7 @@ func TestWrapperErrors(t *testing.T) {
ManifestDataTablename: manifestDataTablename,
IndexDataTablename: indexDataTablename,
UserDataTablename: userDataTablename,
APIKeyTablename: apiKeyTablename,
VersionTablename: "",
}
client, err = dynamo.GetDynamoClient(params)
@@ -1150,8 +1340,41 @@ func TestWrapperErrors(t *testing.T) {
RepoMetaTablename: repoMetaTablename,
ManifestDataTablename: manifestDataTablename,
IndexDataTablename: indexDataTablename,
UserDataTablename: "",
VersionTablename: versionTablename,
UserDataTablename: userDataTablename,
APIKeyTablename: apiKeyTablename,
}
client, err = dynamo.GetDynamoClient(params)
So(err, ShouldBeNil)
_, err = dynamoWrapper.NewDynamoDBWrapper(client, params, log)
So(err, ShouldBeNil)
params = dynamo.DBDriverParameters{ //nolint:contextcheck
Endpoint: endpoint,
Region: region,
RepoMetaTablename: repoMetaTablename,
ManifestDataTablename: manifestDataTablename,
IndexDataTablename: indexDataTablename,
VersionTablename: versionTablename,
UserDataTablename: "",
APIKeyTablename: apiKeyTablename,
}
client, err = dynamo.GetDynamoClient(params)
So(err, ShouldBeNil)
_, err = dynamoWrapper.NewDynamoDBWrapper(client, params, log)
So(err, ShouldNotBeNil)
params = dynamo.DBDriverParameters{ //nolint:contextcheck
Endpoint: endpoint,
Region: region,
RepoMetaTablename: repoMetaTablename,
ManifestDataTablename: manifestDataTablename,
IndexDataTablename: indexDataTablename,
VersionTablename: versionTablename,
UserDataTablename: userDataTablename,
APIKeyTablename: "",
}
client, err = dynamo.GetDynamoClient(params)
So(err, ShouldBeNil)
@@ -1250,7 +1473,7 @@ func setBadUserData(client *dynamodb.Client, userDataTablename, userID string) e
":UserData": userAttributeValue,
},
Key: map[string]types.AttributeValue{
"UserID": &types.AttributeValueMemberS{
"Identity": &types.AttributeValueMemberS{
Value: userID,
},
},
@@ -30,6 +30,7 @@ var errRepodb = errors.New("repodb: error while constructing manifest meta")
type DBWrapper struct {
Client *dynamodb.Client
APIKeyTablename string
RepoMetaTablename string
IndexDataTablename string
ManifestDataTablename string
@@ -47,6 +48,7 @@ func NewDynamoDBWrapper(client *dynamodb.Client, params dynamo.DBDriverParameter
IndexDataTablename: params.IndexDataTablename,
VersionTablename: params.VersionTablename,
UserDataTablename: params.UserDataTablename,
APIKeyTablename: params.APIKeyTablename,
Patches: version.GetDynamoDBPatches(),
Log: log,
}
@@ -76,6 +78,11 @@ func NewDynamoDBWrapper(client *dynamodb.Client, params dynamo.DBDriverParameter
return nil, err
}
err = dynamoWrapper.createAPIKeyTable()
if err != nil {
return nil, err
}
// Using the Config value, create the DynamoDB client
return &dynamoWrapper, nil
}
@@ -580,13 +587,13 @@ func (dwr *DBWrapper) GetUserRepoMeta(ctx context.Context, repo string) (repodb.
return repodb.RepoMetadata{}, err
}
userMeta, err := dwr.GetUserMeta(ctx)
userData, err := dwr.GetUserData(ctx)
if err != nil {
return repodb.RepoMetadata{}, err
}
repoMeta.IsBookmarked = zcommon.Contains(userMeta.BookmarkedRepos, repo)
repoMeta.IsStarred = zcommon.Contains(userMeta.StarredRepos, repo)
repoMeta.IsBookmarked = zcommon.Contains(userData.BookmarkedRepos, repo)
repoMeta.IsStarred = zcommon.Contains(userData.StarredRepos, repo)
return repoMeta, nil
}
@@ -1802,7 +1809,7 @@ func (dwr *DBWrapper) ToggleBookmarkRepo(ctx context.Context, repo string) (
return res, zerr.ErrUserDataNotAllowed
}
userMeta, err := dwr.GetUserMeta(ctx)
userData, err := dwr.GetUserData(ctx)
if err != nil {
if errors.Is(err, zerr.ErrUserDataNotFound) {
return repodb.NotChanged, nil
@@ -1811,16 +1818,16 @@ func (dwr *DBWrapper) ToggleBookmarkRepo(ctx context.Context, repo string) (
return res, err
}
if !zcommon.Contains(userMeta.BookmarkedRepos, repo) {
userMeta.BookmarkedRepos = append(userMeta.BookmarkedRepos, repo)
if !zcommon.Contains(userData.BookmarkedRepos, repo) {
userData.BookmarkedRepos = append(userData.BookmarkedRepos, repo)
res = repodb.Added
} else {
userMeta.BookmarkedRepos = zcommon.RemoveFrom(userMeta.BookmarkedRepos, repo)
userData.BookmarkedRepos = zcommon.RemoveFrom(userData.BookmarkedRepos, repo)
res = repodb.Removed
}
if res != repodb.NotChanged {
err = dwr.SetUserMeta(ctx, userMeta)
err = dwr.SetUserData(ctx, userData)
}
if err != nil {
@@ -1833,9 +1840,9 @@ func (dwr *DBWrapper) ToggleBookmarkRepo(ctx context.Context, repo string) (
}
func (dwr *DBWrapper) GetBookmarkedRepos(ctx context.Context) ([]string, error) {
userMeta, err := dwr.GetUserMeta(ctx)
userMeta, err := dwr.GetUserData(ctx)
if errors.Is(err, zerr.ErrUserDataNotFound) {
if errors.Is(err, zerr.ErrUserDataNotFound) || errors.Is(err, zerr.ErrUserDataNotAllowed) {
return []string{}, nil
}
@@ -1863,7 +1870,7 @@ func (dwr *DBWrapper) ToggleStarRepo(ctx context.Context, repo string) (
return res, zerr.ErrUserDataNotAllowed
}
userData, err := dwr.GetUserMeta(ctx)
userData, err := dwr.GetUserData(ctx)
if err != nil && !errors.Is(err, zerr.ErrUserDataNotFound) {
return res, err
}
@@ -1902,21 +1909,21 @@ func (dwr *DBWrapper) ToggleStarRepo(ctx context.Context, repo string) (
_, err = dwr.Client.TransactWriteItems(ctx, &dynamodb.TransactWriteItemsInput{
TransactItems: []types.TransactWriteItem{
{
// Update User Meta
// Update User Profile
Update: &types.Update{
ExpressionAttributeNames: map[string]string{
"#UM": "UserData",
"#UP": "UserData",
},
ExpressionAttributeValues: map[string]types.AttributeValue{
":UserData": userAttributeValue,
},
Key: map[string]types.AttributeValue{
"UserID": &types.AttributeValueMemberS{
"Identity": &types.AttributeValueMemberS{
Value: userid,
},
},
TableName: aws.String(dwr.UserDataTablename),
UpdateExpression: aws.String("SET #UM = :UserData"),
UpdateExpression: aws.String("SET #UP = :UserData"),
},
},
{
@@ -1948,64 +1955,27 @@ func (dwr *DBWrapper) ToggleStarRepo(ctx context.Context, repo string) (
}
func (dwr *DBWrapper) GetStarredRepos(ctx context.Context) ([]string, error) {
userMeta, err := dwr.GetUserMeta(ctx)
userMeta, err := dwr.GetUserData(ctx)
if errors.Is(err, zerr.ErrUserDataNotFound) {
if errors.Is(err, zerr.ErrUserDataNotFound) || errors.Is(err, zerr.ErrUserDataNotAllowed) {
return []string{}, nil
}
return userMeta.StarredRepos, err
}
func (dwr *DBWrapper) GetUserMeta(ctx context.Context) (repodb.UserData, error) {
acCtx, err := localCtx.GetAccessControlContext(ctx)
if err != nil {
return repodb.UserData{}, err
}
userid := localCtx.GetUsernameFromContext(acCtx)
if userid == "" {
// empty user is anonymous, it has no data
return repodb.UserData{}, nil
}
resp, err := dwr.Client.GetItem(ctx, &dynamodb.GetItemInput{
TableName: aws.String(dwr.UserDataTablename),
Key: map[string]types.AttributeValue{
"UserID": &types.AttributeValueMemberS{Value: userid},
},
})
if err != nil {
return repodb.UserData{}, err
}
if resp.Item == nil {
return repodb.UserData{}, zerr.ErrUserDataNotFound
}
var userMeta repodb.UserData
err = attributevalue.Unmarshal(resp.Item["UserData"], &userMeta)
if err != nil {
return repodb.UserData{}, err
}
return userMeta, nil
}
func (dwr *DBWrapper) createUserDataTable() error {
_, err := dwr.Client.CreateTable(context.Background(), &dynamodb.CreateTableInput{
TableName: aws.String(dwr.UserDataTablename),
AttributeDefinitions: []types.AttributeDefinition{
{
AttributeName: aws.String("UserID"),
AttributeName: aws.String("Identity"),
AttributeType: types.ScalarAttributeTypeS,
},
},
KeySchema: []types.KeySchemaElement{
{
AttributeName: aws.String("UserID"),
AttributeName: aws.String("Identity"),
KeyType: types.KeyTypeHash,
},
},
@@ -2019,38 +1989,279 @@ func (dwr *DBWrapper) createUserDataTable() error {
return dwr.waitTableToBeCreated(dwr.UserDataTablename)
}
func (dwr *DBWrapper) SetUserMeta(ctx context.Context, userMeta repodb.UserData) error {
func (dwr DBWrapper) createAPIKeyTable() error {
_, err := dwr.Client.CreateTable(context.Background(), &dynamodb.CreateTableInput{
TableName: aws.String(dwr.APIKeyTablename),
AttributeDefinitions: []types.AttributeDefinition{
{
AttributeName: aws.String("HashedKey"),
AttributeType: types.ScalarAttributeTypeS,
},
},
KeySchema: []types.KeySchemaElement{
{
AttributeName: aws.String("HashedKey"),
KeyType: types.KeyTypeHash,
},
},
BillingMode: types.BillingModePayPerRequest,
})
if err != nil && !strings.Contains(err.Error(), "Table already exists") {
return err
}
return dwr.waitTableToBeCreated(dwr.APIKeyTablename)
}
func (dwr DBWrapper) SetUserGroups(ctx context.Context, groups []string) error {
userData, err := dwr.GetUserData(ctx)
if err != nil && !errors.Is(err, zerr.ErrUserDataNotFound) {
return err
}
userData.Groups = append(userData.Groups, groups...)
return dwr.SetUserData(ctx, userData)
}
func (dwr DBWrapper) GetUserGroups(ctx context.Context) ([]string, error) {
userData, err := dwr.GetUserData(ctx)
return userData.Groups, err
}
func (dwr DBWrapper) UpdateUserAPIKeyLastUsed(ctx context.Context, hashedKey string) error {
userData, err := dwr.GetUserData(ctx)
if err != nil {
return err
}
apiKeyDetails := userData.APIKeys[hashedKey]
apiKeyDetails.LastUsed = time.Now()
userData.APIKeys[hashedKey] = apiKeyDetails
err = dwr.SetUserData(ctx, userData)
return err
}
func (dwr DBWrapper) AddUserAPIKey(ctx context.Context, hashedKey string, apiKeyDetails *repodb.APIKeyDetails) error {
acCtx, err := localCtx.GetAccessControlContext(ctx)
if err != nil {
return err
}
userid := localCtx.GetUsernameFromContext(acCtx)
if userid == "" {
// empty user is anonymous, it has no data
// empty user is anonymous
return zerr.ErrUserDataNotAllowed
}
userAttributeValue, err := attributevalue.Marshal(userMeta)
userData, err := dwr.GetUserData(ctx)
if err != nil && !errors.Is(err, zerr.ErrUserDataNotFound) {
return fmt.Errorf("repoDB: error while getting userData for identity %s %w", userid, err)
}
if userData.APIKeys == nil {
userData.APIKeys = make(map[string]repodb.APIKeyDetails)
}
userData.APIKeys[hashedKey] = *apiKeyDetails
userAttributeValue, err := attributevalue.Marshal(userData)
if err != nil {
return err
}
_, err = dwr.Client.TransactWriteItems(ctx, &dynamodb.TransactWriteItemsInput{
TransactItems: []types.TransactWriteItem{
{
// Update UserData
Update: &types.Update{
ExpressionAttributeNames: map[string]string{
"#UP": "UserData",
},
ExpressionAttributeValues: map[string]types.AttributeValue{
":UserData": userAttributeValue,
},
Key: map[string]types.AttributeValue{
"Identity": &types.AttributeValueMemberS{
Value: userid,
},
},
TableName: aws.String(dwr.UserDataTablename),
UpdateExpression: aws.String("SET #UP = :UserData"),
},
},
{
// Update APIKeyInfo
Update: &types.Update{
ExpressionAttributeNames: map[string]string{
"#EM": "Identity",
},
ExpressionAttributeValues: map[string]types.AttributeValue{
":Identity": &types.AttributeValueMemberS{Value: userid},
},
Key: map[string]types.AttributeValue{
"HashedKey": &types.AttributeValueMemberS{
Value: hashedKey,
},
},
TableName: aws.String(dwr.APIKeyTablename),
UpdateExpression: aws.String("SET #EM = :Identity"),
},
},
},
})
return err
}
func (dwr DBWrapper) DeleteUserAPIKey(ctx context.Context, keyID string) error {
userData, err := dwr.GetUserData(ctx)
if err != nil {
return fmt.Errorf("repoDB: error while getting userData %w", err)
}
for hash, apiKeyDetails := range userData.APIKeys {
if apiKeyDetails.UUID == keyID {
delete(userData.APIKeys, hash)
_, err = dwr.Client.DeleteItem(ctx, &dynamodb.DeleteItemInput{
TableName: aws.String(dwr.APIKeyTablename),
Key: map[string]types.AttributeValue{
"HashedKey": &types.AttributeValueMemberS{Value: hash},
},
})
if err != nil {
return fmt.Errorf("repoDB: error while deleting userAPIKey entry for hash %s %w", hash, err)
}
err := dwr.SetUserData(ctx, userData)
return err
}
}
return nil
}
func (dwr DBWrapper) GetUserAPIKeyInfo(hashedKey string) (string, error) {
var userid string
resp, err := dwr.Client.GetItem(context.Background(), &dynamodb.GetItemInput{
TableName: aws.String(dwr.APIKeyTablename),
Key: map[string]types.AttributeValue{
"HashedKey": &types.AttributeValueMemberS{Value: hashedKey},
},
})
if err != nil {
return "", err
}
if resp.Item == nil {
return "", zerr.ErrUserAPIKeyNotFound
}
err = attributevalue.Unmarshal(resp.Item["Identity"], &userid)
if err != nil {
return "", err
}
return userid, nil
}
func (dwr DBWrapper) GetUserData(ctx context.Context) (repodb.UserData, error) {
var userData repodb.UserData
acCtx, err := localCtx.GetAccessControlContext(ctx)
if err != nil {
return userData, err
}
userid := localCtx.GetUsernameFromContext(acCtx)
if userid == "" {
// empty user is anonymous
return userData, zerr.ErrUserDataNotAllowed
}
resp, err := dwr.Client.GetItem(ctx, &dynamodb.GetItemInput{
TableName: aws.String(dwr.UserDataTablename),
Key: map[string]types.AttributeValue{
"Identity": &types.AttributeValueMemberS{Value: userid},
},
})
if err != nil {
return repodb.UserData{}, err
}
if resp.Item == nil {
return repodb.UserData{}, zerr.ErrUserDataNotFound
}
err = attributevalue.Unmarshal(resp.Item["UserData"], &userData)
if err != nil {
return repodb.UserData{}, err
}
return userData, nil
}
func (dwr DBWrapper) SetUserData(ctx context.Context, userData repodb.UserData) error {
acCtx, err := localCtx.GetAccessControlContext(ctx)
if err != nil {
return err
}
userid := localCtx.GetUsernameFromContext(acCtx)
if userid == "" {
// empty user is anonymous
return zerr.ErrUserDataNotAllowed
}
userAttributeValue, err := attributevalue.Marshal(userData)
if err != nil {
return err
}
_, err = dwr.Client.UpdateItem(ctx, &dynamodb.UpdateItemInput{
ExpressionAttributeNames: map[string]string{
"#UM": "UserData",
"#UP": "UserData",
},
ExpressionAttributeValues: map[string]types.AttributeValue{
":UserData": userAttributeValue,
},
Key: map[string]types.AttributeValue{
"UserID": &types.AttributeValueMemberS{
"Identity": &types.AttributeValueMemberS{
Value: userid,
},
},
TableName: aws.String(dwr.UserDataTablename),
UpdateExpression: aws.String("SET #UM = :UserData"),
UpdateExpression: aws.String("SET #UP = :UserData"),
})
return err
}
func (dwr DBWrapper) DeleteUserData(ctx context.Context) error {
acCtx, err := localCtx.GetAccessControlContext(ctx)
if err != nil {
return err
}
userid := localCtx.GetUsernameFromContext(acCtx)
if userid == "" {
// empty user is anonymous
return zerr.ErrUserDataNotAllowed
}
_, err = dwr.Client.DeleteItem(ctx, &dynamodb.DeleteItemInput{
TableName: aws.String(dwr.UserDataTablename),
Key: map[string]types.AttributeValue{
"Identity": &types.AttributeValueMemberS{Value: userid},
},
})
return err
+40 -7
View File
@@ -24,6 +24,7 @@ type (
)
type RepoDB interface { //nolint:interfacebloat
UserDB
// IncrementRepoStars adds 1 to the star count of an image
IncrementRepoStars(repo string) error
@@ -111,6 +112,10 @@ type RepoDB interface { //nolint:interfacebloat
FilterTags(ctx context.Context, filter FilterFunc,
requestedPage PageInput) ([]RepoMetadata, map[string]ManifestMetadata, map[string]IndexData, common.PageInfo, error)
PatchDB() error
}
type UserDB interface { //nolint:interfacebloat
// GetStarredRepos returns starred repos and takes current user in consideration
GetStarredRepos(ctx context.Context) ([]string, error)
@@ -123,7 +128,24 @@ type RepoDB interface { //nolint:interfacebloat
// ToggleBookmarkRepo adds/removes bookmarks on repos
ToggleBookmarkRepo(ctx context.Context, reponame string) (ToggleState, error)
PatchDB() error
// UserDB profile/api key CRUD
GetUserData(ctx context.Context) (UserData, error)
SetUserData(ctx context.Context, userData UserData) error
SetUserGroups(ctx context.Context, groups []string) error
GetUserGroups(ctx context.Context) ([]string, error)
DeleteUserData(ctx context.Context) error
GetUserAPIKeyInfo(hashedKey string) (identity string, err error)
AddUserAPIKey(ctx context.Context, hashedKey string, apiKeyDetails *APIKeyDetails) error
UpdateUserAPIKeyLastUsed(ctx context.Context, hashedKey string) error
DeleteUserAPIKey(ctx context.Context, id string) error
}
type ManifestMetadata struct {
@@ -195,12 +217,6 @@ type SignatureMetadata struct {
LayersInfo []LayerInfo
}
type UserData struct {
// data for each user.
StarredRepos []string
BookmarkedRepos []string
}
type SortCriteria string
const (
@@ -235,3 +251,20 @@ type FilterData struct {
IsStarred bool
IsBookmarked bool
}
type UserData struct {
StarredRepos []string
BookmarkedRepos []string
Groups []string
APIKeys map[string]APIKeyDetails
}
type APIKeyDetails struct {
CreatedAt time.Time `json:"createdAt"`
CreatorUA string `json:"creatorUa"`
GeneratedBy string `json:"generatedBy"`
LastUsed time.Time `json:"lastUsed"`
Label string `json:"label"`
Scopes []string `json:"scopes"`
UUID string `json:"uuid"`
}
+114
View File
@@ -92,6 +92,7 @@ func TestDynamoDBWrapper(t *testing.T) {
versionTablename := "Version" + uuid.String()
indexDataTablename := "IndexDataTable" + uuid.String()
userDataTablename := "UserDataTable" + uuid.String()
apiKeyTablename := "ApiKeyTable" + uuid.String()
Convey("DynamoDB Wrapper", t, func() {
dynamoDBDriverParams := dynamo.DBDriverParameters{
@@ -101,6 +102,7 @@ func TestDynamoDBWrapper(t *testing.T) {
IndexDataTablename: indexDataTablename,
VersionTablename: versionTablename,
UserDataTablename: userDataTablename,
APIKeyTablename: apiKeyTablename,
Region: "us-east-2",
}
@@ -137,6 +139,118 @@ func RunRepoDBTests(t *testing.T, repoDB repodb.RepoDB, preparationFuncs ...func
So(err, ShouldBeNil)
}
Convey("Test CRUD operations on UserData and API keys", func() {
hashKey1 := "id"
hashKey2 := "key"
apiKeys := make(map[string]repodb.APIKeyDetails)
apiKeyDetails := repodb.APIKeyDetails{
Label: "apiKey",
Scopes: []string{"repo"},
UUID: hashKey1,
}
apiKeys[hashKey1] = apiKeyDetails
userProfileSrc := repodb.UserData{
Groups: []string{"group1", "group2"},
APIKeys: apiKeys,
}
authzCtxKey := localCtx.GetContextKey()
acCtx := localCtx.AccessControlContext{
Username: "test",
}
ctx := context.WithValue(context.Background(), authzCtxKey, acCtx)
err := repoDB.AddUserAPIKey(ctx, hashKey1, &apiKeyDetails)
So(err, ShouldBeNil)
err = repoDB.SetUserData(ctx, userProfileSrc)
So(err, ShouldBeNil)
userProfile, err := repoDB.GetUserData(ctx)
So(err, ShouldBeNil)
So(userProfile.Groups, ShouldResemble, userProfileSrc.Groups)
So(userProfile.APIKeys, ShouldContainKey, hashKey1)
So(userProfile.APIKeys[hashKey1].Label, ShouldEqual, apiKeyDetails.Label)
So(userProfile.APIKeys[hashKey1].Scopes, ShouldResemble, apiKeyDetails.Scopes)
lastUsed := userProfile.APIKeys[hashKey1].LastUsed
err = repoDB.UpdateUserAPIKeyLastUsed(ctx, hashKey1)
So(err, ShouldBeNil)
userProfile, err = repoDB.GetUserData(ctx)
So(err, ShouldBeNil)
So(userProfile.APIKeys[hashKey1].LastUsed, ShouldHappenAfter, lastUsed)
userGroups, err := repoDB.GetUserGroups(ctx)
So(err, ShouldBeNil)
So(userGroups, ShouldResemble, userProfileSrc.Groups)
apiKeyDetails.UUID = hashKey2
err = repoDB.AddUserAPIKey(ctx, hashKey2, &apiKeyDetails)
So(err, ShouldBeNil)
userProfile, err = repoDB.GetUserData(ctx)
So(err, ShouldBeNil)
So(userProfile.Groups, ShouldResemble, userProfileSrc.Groups)
So(userProfile.APIKeys, ShouldContainKey, hashKey2)
So(userProfile.APIKeys[hashKey2].Label, ShouldEqual, apiKeyDetails.Label)
So(userProfile.APIKeys[hashKey2].Scopes, ShouldResemble, apiKeyDetails.Scopes)
email, err := repoDB.GetUserAPIKeyInfo(hashKey2)
So(err, ShouldBeNil)
So(email, ShouldEqual, "test")
err = repoDB.DeleteUserAPIKey(ctx, hashKey1)
So(err, ShouldBeNil)
userProfile, err = repoDB.GetUserData(ctx)
So(err, ShouldBeNil)
So(len(userProfile.APIKeys), ShouldEqual, 1)
So(userProfile.APIKeys, ShouldNotContainKey, hashKey1)
err = repoDB.DeleteUserAPIKey(ctx, hashKey2)
So(err, ShouldBeNil)
userProfile, err = repoDB.GetUserData(ctx)
So(err, ShouldBeNil)
So(len(userProfile.APIKeys), ShouldEqual, 0)
So(userProfile.APIKeys, ShouldNotContainKey, hashKey2)
// delete non existent api key
err = repoDB.DeleteUserAPIKey(ctx, hashKey2)
So(err, ShouldBeNil)
err = repoDB.DeleteUserData(ctx)
So(err, ShouldBeNil)
email, err = repoDB.GetUserAPIKeyInfo(hashKey2)
So(err, ShouldNotBeNil)
So(email, ShouldBeEmpty)
email, err = repoDB.GetUserAPIKeyInfo(hashKey1)
So(err, ShouldNotBeNil)
So(email, ShouldBeEmpty)
_, err = repoDB.GetUserData(ctx)
So(err, ShouldNotBeNil)
userGroups, err = repoDB.GetUserGroups(ctx)
So(err, ShouldNotBeNil)
So(userGroups, ShouldBeEmpty)
err = repoDB.SetUserGroups(ctx, userProfileSrc.Groups)
So(err, ShouldBeNil)
userGroups, err = repoDB.GetUserGroups(ctx)
So(err, ShouldBeNil)
So(userGroups, ShouldResemble, userProfileSrc.Groups)
})
Convey("Test SetManifestData and GetManifestData", func() {
configBlob, manifestBlob, err := generateTestImage()
So(err, ShouldBeNil)
@@ -95,6 +95,9 @@ func getDynamoParams(cacheDriverConfig map[string]interface{}, log log.Logger) d
indexDataTablename, ok := toStringIfOk(cacheDriverConfig, "indexdatatablename", log)
allParametersOk = allParametersOk && ok
apiKeyTablename, ok := toStringIfOk(cacheDriverConfig, "apikeytablename", log)
allParametersOk = allParametersOk && ok
versionTablename, ok := toStringIfOk(cacheDriverConfig, "versiontablename", log)
allParametersOk = allParametersOk && ok
@@ -112,6 +115,7 @@ func getDynamoParams(cacheDriverConfig map[string]interface{}, log log.Logger) d
ManifestDataTablename: manifestDataTablename,
IndexDataTablename: indexDataTablename,
UserDataTablename: userDataTablename,
APIKeyTablename: apiKeyTablename,
VersionTablename: versionTablename,
}
}
@@ -25,6 +25,7 @@ func TestCreateDynamo(t *testing.T) {
ManifestDataTablename: "ManifestDataTable",
IndexDataTablename: "IndexDataTable",
UserDataTablename: "UserDataTable",
APIKeyTablename: "ApiKeyTable",
VersionTablename: "Version",
Region: "us-east-2",
}
+1
View File
@@ -390,6 +390,7 @@ func TestParseStorageDynamoWrapper(t *testing.T) {
ManifestDataTablename: "ManifestDataTable",
IndexDataTablename: "IndexDataTable",
UserDataTablename: "UserDataTable",
APIKeyTablename: "ApiKeyTable",
VersionTablename: "Version",
}
+1
View File
@@ -127,6 +127,7 @@ func TestVersioningDynamoDB(t *testing.T) {
ManifestDataTablename: "ManifestDataTable",
IndexDataTablename: "IndexDataTable",
UserDataTablename: "UserDataTable",
APIKeyTablename: "ApiKeyTable",
VersionTablename: "Version",
}
+26
View File
@@ -92,3 +92,29 @@ func (acCtx *AccessControlContext) matchesRepo(globPatterns map[string]bool, rep
return allowed
}
// request-local context key.
var amwCtxKey = Key(1) //nolint: gochecknoglobals
// pointer needed for use in context.WithValue.
func GetAuthnMiddlewareCtxKey() *Key {
return &amwCtxKey
}
type AuthnMiddlewareContext struct {
AuthnType string
}
func GetAuthnMiddlewareContext(ctx context.Context) (*AuthnMiddlewareContext, error) {
authnMiddlewareCtxKey := GetAuthnMiddlewareCtxKey()
if authnMiddlewareCtx := ctx.Value(authnMiddlewareCtxKey); authnMiddlewareCtx != nil {
amCtx, ok := authnMiddlewareCtx.(AuthnMiddlewareContext)
if !ok {
return nil, errors.ErrBadType
}
return &amCtx, nil
}
return nil, nil //nolint: nilnil
}
+54
View File
@@ -15,6 +15,7 @@ import (
"log"
"math"
"math/big"
"net"
"net/http"
"net/url"
"os"
@@ -38,6 +39,7 @@ import (
ispec "github.com/opencontainers/image-spec/specs-go/v1"
"github.com/opencontainers/umoci"
"github.com/phayes/freeport"
"github.com/project-zot/mockoidc"
"github.com/sigstore/cosign/v2/cmd/cosign/cli/generate"
"github.com/sigstore/cosign/v2/cmd/cosign/cli/options"
"github.com/sigstore/cosign/v2/cmd/cosign/cli/sign"
@@ -1967,3 +1969,55 @@ func GetIndexBlobWithManifests(manifestDigests []godigest.Digest) ([]byte, error
return json.Marshal(indexContent)
}
func MockOIDCRun() (*mockoidc.MockOIDC, error) {
// Create a fresh RSA Private Key for token signing
rsaKey, _ := rsa.GenerateKey(rand.Reader, 2048) //nolint: gomnd
// Create an unstarted MockOIDC server
mockServer, _ := mockoidc.NewServer(rsaKey)
// Create the net.Listener, kernel will chose a valid port
listener, _ := net.Listen("tcp", "127.0.0.1:0")
bearerMiddleware := func(next http.Handler) http.Handler {
return http.HandlerFunc(func(response http.ResponseWriter, req *http.Request) {
// stateVal := req.Form.Get("state")
header := req.Header.Get("Authorization")
parts := strings.SplitN(header, " ", 2) //nolint: gomnd
if header != "" {
if strings.ToLower(parts[0]) == "bearer" {
req.Header.Set("Authorization", strings.Join([]string{"Bearer", parts[1]}, " "))
}
}
next.ServeHTTP(response, req)
})
}
err := mockServer.AddMiddleware(bearerMiddleware)
if err != nil {
return mockServer, err
}
// tlsConfig can be nil if you want HTTP
return mockServer, mockServer.Start(listener, nil)
}
func CustomRedirectPolicy(noOfRedirect int) resty.RedirectPolicy {
return resty.RedirectPolicyFunc(func(req *http.Request, via []*http.Request) error {
if len(via) >= noOfRedirect {
return fmt.Errorf("stopped after %d redirects", noOfRedirect) //nolint: goerr113
}
for key, val := range via[len(via)-1].Header {
req.Header[key] = val
}
respCookies := req.Response.Cookies()
for _, cookie := range respCookies {
req.AddCookie(cookie)
}
return nil
})
}
+90
View File
@@ -95,6 +95,24 @@ type RepoDBMock struct {
ToggleBookmarkRepoFn func(ctx context.Context, repo string) (repodb.ToggleState, error)
GetUserDataFn func(ctx context.Context) (repodb.UserData, error)
SetUserDataFn func(ctx context.Context, userProfile repodb.UserData) error
SetUserGroupsFn func(ctx context.Context, groups []string) error
GetUserGroupsFn func(ctx context.Context) ([]string, error)
DeleteUserDataFn func(ctx context.Context) error
GetUserAPIKeyInfoFn func(hashedKey string) (string, error)
AddUserAPIKeyFn func(ctx context.Context, hashedKey string, apiKeyDetails *repodb.APIKeyDetails) error
UpdateUserAPIKeyLastUsedFn func(ctx context.Context, hashedKey string) error
DeleteUserAPIKeyFn func(ctx context.Context, id string) error
PatchDBFn func() error
}
@@ -414,3 +432,75 @@ func (sdm RepoDBMock) ToggleBookmarkRepo(ctx context.Context, repo string) (repo
return repodb.NotChanged, nil
}
func (sdm RepoDBMock) GetUserData(ctx context.Context) (repodb.UserData, error) {
if sdm.GetUserDataFn != nil {
return sdm.GetUserDataFn(ctx)
}
return repodb.UserData{}, nil
}
func (sdm RepoDBMock) SetUserData(ctx context.Context, userProfile repodb.UserData) error {
if sdm.SetUserDataFn != nil {
return sdm.SetUserDataFn(ctx, userProfile)
}
return nil
}
func (sdm RepoDBMock) SetUserGroups(ctx context.Context, groups []string) error {
if sdm.SetUserGroupsFn != nil {
return sdm.SetUserGroupsFn(ctx, groups)
}
return nil
}
func (sdm RepoDBMock) GetUserGroups(ctx context.Context) ([]string, error) {
if sdm.GetUserGroupsFn != nil {
return sdm.GetUserGroupsFn(ctx)
}
return []string{}, nil
}
func (sdm RepoDBMock) DeleteUserData(ctx context.Context) error {
if sdm.DeleteUserDataFn != nil {
return sdm.DeleteUserDataFn(ctx)
}
return nil
}
func (sdm RepoDBMock) GetUserAPIKeyInfo(hashedKey string) (string, error) {
if sdm.GetUserAPIKeyInfoFn != nil {
return sdm.GetUserAPIKeyInfoFn(hashedKey)
}
return "", nil
}
func (sdm RepoDBMock) AddUserAPIKey(ctx context.Context, hashedKey string, apiKeyDetails *repodb.APIKeyDetails) error {
if sdm.AddUserAPIKeyFn != nil {
return sdm.AddUserAPIKeyFn(ctx, hashedKey, apiKeyDetails)
}
return nil
}
func (sdm RepoDBMock) UpdateUserAPIKeyLastUsed(ctx context.Context, hashedKey string) error {
if sdm.UpdateUserAPIKeyLastUsedFn != nil {
return sdm.UpdateUserAPIKeyLastUsedFn(ctx, hashedKey)
}
return nil
}
func (sdm RepoDBMock) DeleteUserAPIKey(ctx context.Context, id string) error {
if sdm.DeleteUserAPIKeyFn != nil {
return sdm.DeleteUserAPIKeyFn(ctx, id)
}
return nil
}