diff --git a/examples/README.md b/examples/README.md index 04dcf359..16efa6ce 100644 --- a/examples/README.md +++ b/examples/README.md @@ -936,7 +936,7 @@ Configure each registry sync: ] }, { - "urls": ["https://docker.io/library"], + "urls": ["https://index.docker.io"], "onDemand": true, # doesn't have content, don't periodically pull, pull just on demand. "tlsVerify": true, "maxRetries": 3, diff --git a/pkg/api/authn.go b/pkg/api/authn.go index 1653148c..47724198 100644 --- a/pkg/api/authn.go +++ b/pkg/api/authn.go @@ -482,7 +482,7 @@ func bearerAuthHandler(ctlr *Controller) mux.MiddlewareFunc { if err != nil { ctlr.Log.Error().Err(err).Msg("failed to parse Authorization header") response.Header().Set("Content-Type", "application/json") - zcommon.WriteJSON(response, http.StatusInternalServerError, apiErr.NewError(apiErr.UNSUPPORTED)) + zcommon.WriteJSON(response, http.StatusUnauthorized, apiErr.NewError(apiErr.UNSUPPORTED)) return } diff --git a/pkg/api/controller_test.go b/pkg/api/controller_test.go index ddbd60e8..1f78dfa3 100644 --- a/pkg/api/controller_test.go +++ b/pkg/api/controller_test.go @@ -3114,7 +3114,7 @@ func TestBearerAuth(t *testing.T) { Get(baseURL + "/v2/") So(err, ShouldBeNil) So(resp, ShouldNotBeNil) - So(resp.StatusCode(), ShouldEqual, http.StatusInternalServerError) + So(resp.StatusCode(), ShouldEqual, http.StatusUnauthorized) resp, err = resty.R().SetHeader("Authorization", fmt.Sprintf("Bearer %s", goodToken.AccessToken)).Options(baseURL + "/v2/") diff --git a/pkg/api/routes.go b/pkg/api/routes.go index 6e09a4af..7aa5dde0 100644 --- a/pkg/api/routes.go +++ b/pkg/api/routes.go @@ -240,7 +240,7 @@ func (rh *RouteHandler) CheckVersionSupport(response http.ResponseWriter, reques response.Header().Set(constants.DistAPIVersion, "registry/2.0") // NOTE: compatibility workaround - return this header in "allowed-read" mode to allow for clients to // work correctly - if rh.c.Config.HTTP.Auth != nil { + if rh.c.Config.IsBasicAuthnEnabled() || rh.c.Config.IsBearerAuthEnabled() { // don't send auth headers if request is coming from UI if request.Header.Get(constants.SessionClientHeaderName) != constants.SessionClientHeaderValue { if rh.c.Config.HTTP.Auth.Bearer != nil { diff --git a/pkg/common/http_client.go b/pkg/common/http_client.go index d67035d1..a4c62712 100644 --- a/pkg/common/http_client.go +++ b/pkg/common/http_client.go @@ -1,18 +1,12 @@ package common import ( - "context" "crypto/tls" "crypto/x509" - "encoding/json" - "errors" - "io" "net/http" "os" "path" "path/filepath" - - "zotregistry.dev/zot/pkg/log" ) func GetTLSConfig(certsPath string, caCertPool *x509.CertPool) (*tls.Config, error) { @@ -107,57 +101,7 @@ func CreateHTTPClient(verifyTLS bool, host string, certDir string) (*http.Client } return &http.Client{ - Timeout: httpTimeout, Transport: htr, + Timeout: httpTimeout, }, nil } - -func MakeHTTPGetRequest(ctx context.Context, httpClient *http.Client, - username string, password string, resultPtr interface{}, - blobURL string, mediaType string, log log.Logger, -) ([]byte, string, int, error) { - req, err := http.NewRequestWithContext(ctx, http.MethodGet, blobURL, nil) //nolint - if err != nil { - return nil, "", 0, err - } - - if mediaType != "" { - req.Header.Set("Accept", mediaType) - } - - if username != "" && password != "" { - req.SetBasicAuth(username, password) - } - - resp, err := httpClient.Do(req) - if err != nil { - log.Error().Str("errorType", TypeOf(err)). - Err(err).Str("blobURL", blobURL).Msg("couldn't get blob") - - return nil, "", -1, err - } - - defer resp.Body.Close() - - body, err := io.ReadAll(resp.Body) - if err != nil { - log.Error().Str("errorType", TypeOf(err)). - Err(err).Str("blobURL", blobURL).Msg("couldn't get blob") - - return nil, "", resp.StatusCode, err - } - - if resp.StatusCode != http.StatusOK { - return nil, "", resp.StatusCode, errors.New(string(body)) //nolint:goerr113 - } - - // read blob - if len(body) > 0 { - err = json.Unmarshal(body, &resultPtr) - if err != nil { - return body, "", resp.StatusCode, err - } - } - - return body, resp.Header.Get("Content-Type"), resp.StatusCode, err -} diff --git a/pkg/common/http_client_test.go b/pkg/common/http_client_test.go index 63bcf320..058e7505 100644 --- a/pkg/common/http_client_test.go +++ b/pkg/common/http_client_test.go @@ -1,19 +1,14 @@ package common_test import ( - "context" "crypto/x509" "os" "path" "testing" - ispec "github.com/opencontainers/image-spec/specs-go/v1" . "github.com/smartystreets/goconvey/convey" - "zotregistry.dev/zot/pkg/api" - "zotregistry.dev/zot/pkg/api/config" "zotregistry.dev/zot/pkg/common" - "zotregistry.dev/zot/pkg/log" test "zotregistry.dev/zot/pkg/test/common" ) @@ -54,30 +49,4 @@ func TestHTTPClient(t *testing.T) { _, err = common.CreateHTTPClient(true, "localhost", tempDir) So(err, ShouldNotBeNil) }) - - Convey("test MakeHTTPGetRequest() no permissions on key", t, func() { - port := test.GetFreePort() - baseURL := test.GetBaseURL(port) - - conf := config.New() - conf.HTTP.Port = port - - ctlr := api.NewController(conf) - tempDir := t.TempDir() - err := test.CopyTestKeysAndCerts(tempDir) - So(err, ShouldBeNil) - ctlr.Config.Storage.RootDirectory = tempDir - - cm := test.NewControllerManager(ctlr) - cm.StartServer() - defer cm.StopServer() - test.WaitTillServerReady(baseURL) - - var resultPtr interface{} - httpClient, err := common.CreateHTTPClient(true, "localhost", tempDir) - So(err, ShouldBeNil) - _, _, _, err = common.MakeHTTPGetRequest(context.Background(), httpClient, "", "", - resultPtr, baseURL+"/v2/", ispec.MediaTypeImageManifest, log.NewLogger("", "")) - So(err, ShouldBeNil) - }) } diff --git a/pkg/extensions/extension_sync.go b/pkg/extensions/extension_sync.go index 7732a126..04d17640 100644 --- a/pkg/extensions/extension_sync.go +++ b/pkg/extensions/extension_sync.go @@ -50,6 +50,8 @@ func EnableSyncExtension(config *config.Config, metaDB mTypes.MetaDB, service, err := sync.New(registryConfig, credsPath, tmpDir, storeController, metaDB, log) if err != nil { + log.Error().Err(err).Msg("failed to initialize sync extension") + return nil, err } diff --git a/pkg/extensions/sync/httpclient/cache.go b/pkg/extensions/sync/httpclient/cache.go new file mode 100644 index 00000000..e0aa56de --- /dev/null +++ b/pkg/extensions/sync/httpclient/cache.go @@ -0,0 +1,58 @@ +package client + +import ( + "sync" +) + +// Key:Value store for bearer tokens, key is namespace, value is token. +// We are storing only pull scoped tokens, the http client is for pulling only. +type TokenCache struct { + entries sync.Map +} + +func NewTokenCache() *TokenCache { + return &TokenCache{ + entries: sync.Map{}, + } +} + +func (c *TokenCache) Set(namespace string, token *bearerToken) { + if c == nil || token == nil { + return + } + + defer c.prune() + + c.entries.Store(namespace, token) +} + +func (c *TokenCache) Get(namespace string) *bearerToken { + if c == nil { + return nil + } + + val, ok := c.entries.Load(namespace) + if !ok { + return nil + } + + bearerToken, ok := val.(*bearerToken) + if !ok { + return nil + } + + return bearerToken +} + +func (c *TokenCache) prune() { + c.entries.Range(func(key, val any) bool { + bearerToken, ok := val.(*bearerToken) + if ok { + if bearerToken.isExpired() { + c.entries.Delete(key) + } + } + + return true + }) +} diff --git a/pkg/extensions/sync/httpclient/client.go b/pkg/extensions/sync/httpclient/client.go index f93b3ef1..4b968d61 100644 --- a/pkg/extensions/sync/httpclient/client.go +++ b/pkg/extensions/sync/httpclient/client.go @@ -2,15 +2,56 @@ package client import ( "context" + "encoding/json" + "errors" "io" "net/http" "net/url" + "strings" "sync" + "time" + zerr "zotregistry.dev/zot/errors" "zotregistry.dev/zot/pkg/common" "zotregistry.dev/zot/pkg/log" ) +const ( + minimumTokenLifetimeSeconds = 60 // in seconds + pingTimeout = 5 * time.Second + // tokenBuffer is used to renew a token before it actually expires + // to account for the time to process requests on the server. + tokenBuffer = 5 * time.Second +) + +type authType int + +const ( + noneAuth authType = iota + basicAuth + tokenAuth +) + +type challengeParams struct { + realm string + service string + scope string + err string +} + +type bearerToken struct { + Token string `json:"token"` //nolint: tagliatelle + AccessToken string `json:"access_token"` //nolint: tagliatelle + ExpiresIn int `json:"expires_in"` //nolint: tagliatelle + IssuedAt time.Time `json:"issued_at"` //nolint: tagliatelle + expirationTime time.Time +} + +func (token *bearerToken) isExpired() bool { + // use tokenBuffer to expire it a bit earlier + return time.Now().After(token.expirationTime.Add(-1 * tokenBuffer)) +} + type Config struct { URL string Username string @@ -20,15 +61,20 @@ type Config struct { } type Client struct { - config *Config - client *http.Client - url *url.URL - lock *sync.RWMutex - log log.Logger + config *Config + client *http.Client + url *url.URL + authType authType + cache *TokenCache + lock *sync.RWMutex + log log.Logger } func New(config Config, log log.Logger) (*Client, error) { client := &Client{log: log, lock: new(sync.RWMutex)} + + client.cache = NewTokenCache() + if err := client.SetConfig(config); err != nil { return nil, err } @@ -50,6 +96,13 @@ func (httpClient *Client) GetHostname() string { return httpClient.url.Host } +func (httpClient *Client) GetBaseURL() string { + httpClient.lock.RLock() + defer httpClient.lock.RUnlock() + + return httpClient.url.String() +} + func (httpClient *Client) SetConfig(config Config) error { httpClient.lock.Lock() defer httpClient.lock.Unlock() @@ -73,41 +126,30 @@ func (httpClient *Client) SetConfig(config Config) error { } func (httpClient *Client) Ping() bool { - httpClient.lock.RLock() - defer httpClient.lock.RUnlock() + httpClient.lock.Lock() + defer httpClient.lock.Unlock() pingURL := *httpClient.url pingURL = *pingURL.JoinPath("/v2/") - req, err := http.NewRequest(http.MethodGet, pingURL.String(), nil) //nolint + // for the ping function we want to timeout fast + ctx, cancel := context.WithTimeout(context.Background(), pingTimeout) + defer cancel() + + //nolint: bodyclose + resp, _, err := httpClient.get(ctx, pingURL.String(), false) if err != nil { return false } - resp, err := httpClient.client.Do(req) - if err != nil { - httpClient.log.Error().Err(err).Str("url", pingURL.String()).Str("component", "sync"). - Msg("failed to ping registry") + httpClient.getAuthType(resp) - return false - } - - defer resp.Body.Close() - - if resp.StatusCode == http.StatusOK || resp.StatusCode == http.StatusUnauthorized { + if resp.StatusCode >= http.StatusOK && resp.StatusCode <= http.StatusForbidden { return true } - body, err := io.ReadAll(resp.Body) - if err != nil { - httpClient.log.Error().Err(err).Str("url", pingURL.String()). - Msg("failed to read body while pinging registry") - - return false - } - - httpClient.log.Error().Str("url", pingURL.String()).Str("body", string(body)).Int("statusCode", resp.StatusCode). + httpClient.log.Error().Str("url", pingURL.String()).Int("statusCode", resp.StatusCode). Str("component", "sync").Msg("failed to ping registry") return false @@ -119,17 +161,302 @@ func (httpClient *Client) MakeGetRequest(ctx context.Context, resultPtr interfac httpClient.lock.RLock() defer httpClient.lock.RUnlock() - url := *httpClient.url + var namespace string - for _, r := range route { - url = *url.JoinPath(r) + url := *httpClient.url + for idx, path := range route { + url = *url.JoinPath(path) + + // we know that the second route argument is always the repo name. + // need it for caching tokens, it's not used in requests made to authz server. + if idx == 1 { + namespace = path + } } url.RawQuery = url.Query().Encode() + //nolint: bodyclose,contextcheck + resp, body, err := httpClient.makeAndDoRequest(http.MethodGet, mediaType, namespace, url.String()) + if err != nil { + httpClient.log.Error().Err(err).Str("url", url.String()).Str("component", "sync"). + Str("errorType", common.TypeOf(err)). + Msg("failed to make request") - body, mediaType, statusCode, err := common.MakeHTTPGetRequest(ctx, httpClient.client, httpClient.config.Username, - httpClient.config.Password, resultPtr, - url.String(), mediaType, httpClient.log) + return nil, "", -1, err + } - return body, mediaType, statusCode, err + if resp.StatusCode != http.StatusOK { + return nil, "", resp.StatusCode, errors.New(string(body)) //nolint:goerr113 + } + + // read blob + if len(body) > 0 { + err = json.Unmarshal(body, &resultPtr) + } + + return body, resp.Header.Get("Content-Type"), resp.StatusCode, err +} + +func (httpClient *Client) getAuthType(resp *http.Response) { + authHeader := resp.Header.Get("www-authenticate") + + authHeaderLower := strings.ToLower(authHeader) + + //nolint: gocritic + if strings.Contains(authHeaderLower, "bearer") { + httpClient.authType = tokenAuth + } else if strings.Contains(authHeaderLower, "basic") { + httpClient.authType = basicAuth + } else { + httpClient.authType = noneAuth + } +} + +func (httpClient *Client) setupAuth(req *http.Request, namespace string) error { + if httpClient.authType == tokenAuth { + token, err := httpClient.getToken(req.URL.String(), namespace) + if err != nil { + httpClient.log.Error().Err(err).Str("url", req.URL.String()).Str("component", "sync"). + Str("errorType", common.TypeOf(err)). + Msg("failed to get token from authorization realm") + + return err + } + + req.Header.Set("Authorization", "Bearer "+token.Token) + } else if httpClient.authType == basicAuth { + req.SetBasicAuth(httpClient.config.Username, httpClient.config.Password) + } + + return nil +} + +func (httpClient *Client) get(ctx context.Context, url string, setAuth bool) (*http.Response, []byte, error) { + req, err := http.NewRequestWithContext(ctx, http.MethodGet, url, nil) //nolint + if err != nil { + return nil, nil, err + } + + if setAuth && httpClient.config.Username != "" && httpClient.config.Password != "" { + req.SetBasicAuth(httpClient.config.Username, httpClient.config.Password) + } + + return httpClient.doRequest(req) +} + +func (httpClient *Client) doRequest(req *http.Request) (*http.Response, []byte, error) { + resp, err := httpClient.client.Do(req) + if err != nil { + httpClient.log.Error().Err(err).Str("url", req.URL.String()).Str("component", "sync"). + Str("errorType", common.TypeOf(err)). + Msg("failed to make request") + + return nil, nil, err + } + + defer resp.Body.Close() + + body, err := io.ReadAll(resp.Body) + if err != nil { + httpClient.log.Error().Err(err).Str("url", req.URL.String()). + Str("errorType", common.TypeOf(err)). + Msg("failed to read body") + + return nil, nil, err + } + + return resp, body, nil +} + +func (httpClient *Client) makeAndDoRequest(method, mediaType, namespace, urlStr string, +) (*http.Response, []byte, error) { + req, err := http.NewRequest(method, urlStr, nil) //nolint + if err != nil { + return nil, nil, err + } + + if err := httpClient.setupAuth(req, namespace); err != nil { + return nil, nil, err + } + + if mediaType != "" { + req.Header.Set("Accept", mediaType) + } + + resp, body, err := httpClient.doRequest(req) + if err != nil { + return nil, nil, err + } + + // let's retry one time if we get an insufficient_scope error + if ok, challengeParams := needsRetryWithUpdatedScope(err, resp); ok { + var tokenURL *url.URL + + var token *bearerToken + + tokenURL, err = getTokenURLFromChallengeParams(challengeParams, httpClient.config.Username) + if err != nil { + return nil, nil, err + } + + token, err = httpClient.getTokenFromURL(tokenURL.String(), namespace) + if err != nil { + return nil, nil, err + } + + req.Header.Set("Authorization", "Bearer "+token.Token) + + resp, body, err = httpClient.doRequest(req) + } + + return resp, body, err +} + +func (httpClient *Client) getTokenFromURL(urlStr, namespace string) (*bearerToken, error) { + //nolint: bodyclose + resp, body, err := httpClient.get(context.Background(), urlStr, true) + if err != nil { + return nil, err + } + + if resp.StatusCode != http.StatusOK { + return nil, zerr.ErrUnauthorizedAccess + } + + token, err := newBearerToken(body) + if err != nil { + return nil, err + } + + // cache it + httpClient.cache.Set(namespace, token) + + return token, nil +} + +// Gets bearer token from Authorization realm. +func (httpClient *Client) getToken(urlStr, namespace string) (*bearerToken, error) { + // first check cache + token := httpClient.cache.Get(namespace) + if token != nil && !token.isExpired() { + return token, nil + } + + //nolint: bodyclose + resp, _, err := httpClient.get(context.Background(), urlStr, false) + if err != nil { + return nil, err + } + + challengeParams, err := parseAuthHeader(resp) + if err != nil { + return nil, err + } + + tokenURL, err := getTokenURLFromChallengeParams(challengeParams, httpClient.config.Username) + if err != nil { + return nil, err + } + + return httpClient.getTokenFromURL(tokenURL.String(), namespace) +} + +func newBearerToken(blob []byte) (*bearerToken, error) { + token := new(bearerToken) + if err := json.Unmarshal(blob, &token); err != nil { + return nil, err + } + + if token.Token == "" { + token.Token = token.AccessToken + } + + if token.ExpiresIn < minimumTokenLifetimeSeconds { + token.ExpiresIn = minimumTokenLifetimeSeconds + } + + if token.IssuedAt.IsZero() { + token.IssuedAt = time.Now().UTC() + } + + token.expirationTime = token.IssuedAt.Add(time.Duration(token.ExpiresIn) * time.Second) + + return token, nil +} + +func getTokenURLFromChallengeParams(params challengeParams, account string) (*url.URL, error) { + parsedRealm, err := url.Parse(params.realm) + if err != nil { + return nil, err + } + + query := parsedRealm.Query() + query.Set("service", params.service) + query.Set("scope", params.scope) + + if account != "" { + query.Set("account", account) + } + + parsedRealm.RawQuery = query.Encode() + + return parsedRealm, nil +} + +func parseAuthHeader(resp *http.Response) (challengeParams, error) { + authHeader := resp.Header.Get("www-authenticate") + + authHeaderSlice := strings.Split(authHeader, ",") + + params := challengeParams{} + + for _, elem := range authHeaderSlice { + if strings.Contains(strings.ToLower(elem), "bearer") { + elem = strings.Split(elem, " ")[1] + } + + elem := strings.ReplaceAll(elem, "\"", "") + + elemSplit := strings.Split(elem, "=") + if len(elemSplit) != 2 { //nolint: gomnd + return params, zerr.ErrParsingAuthHeader + } + + authKey := elemSplit[0] + + authValue := elemSplit[1] + + switch authKey { + case "realm": + params.realm = authValue + case "service": + params.service = authValue + case "scope": + params.scope = authValue + case "error": + params.err = authValue + } + } + + return params, nil +} + +// Checks if the auth headers in the response contain an indication of a failed +// authorization because of an "insufficient_scope" error. +func needsRetryWithUpdatedScope(err error, resp *http.Response) (bool, challengeParams) { + params := challengeParams{} + if err == nil && resp.StatusCode == http.StatusUnauthorized { + params, err = parseAuthHeader(resp) + if err != nil { + return false, params + } + + if params.err == "insufficient_scope" { + if params.scope != "" { + return true, params + } + } + } + + return false, params } diff --git a/pkg/extensions/sync/httpclient/client_internal_test.go b/pkg/extensions/sync/httpclient/client_internal_test.go new file mode 100644 index 00000000..48ec04da --- /dev/null +++ b/pkg/extensions/sync/httpclient/client_internal_test.go @@ -0,0 +1,167 @@ +package client + +import ( + "net/http" + "net/http/httptest" + "testing" + "time" + + . "github.com/smartystreets/goconvey/convey" + + "zotregistry.dev/zot/pkg/log" +) + +func TestTokenCache(t *testing.T) { + Convey("Get/Set tokens", t, func() { + tokenCache := NewTokenCache() + token := &bearerToken{ + Token: "tokenA", + ExpiresIn: 3, + IssuedAt: time.Now(), + } + + token.expirationTime = token.IssuedAt.Add(time.Duration(token.ExpiresIn) * time.Second).Add(tokenBuffer) + + tokenCache.Set("repo", token) + cachedToken := tokenCache.Get("repo") + So(cachedToken.Token, ShouldEqual, token.Token) + + // add token which expires soon + token2 := &bearerToken{ + Token: "tokenB", + ExpiresIn: 1, + IssuedAt: time.Now(), + } + + token2.expirationTime = token2.IssuedAt.Add(time.Duration(token2.ExpiresIn) * time.Second).Add(tokenBuffer) + + tokenCache.Set("repo2", token2) + cachedToken = tokenCache.Get("repo2") + So(cachedToken.Token, ShouldEqual, token2.Token) + + time.Sleep(1 * time.Second) + + // token3 should be expired when adding a new one + token3 := &bearerToken{ + Token: "tokenC", + ExpiresIn: 3, + IssuedAt: time.Now(), + } + + token3.expirationTime = token3.IssuedAt.Add(time.Duration(token3.ExpiresIn) * time.Second).Add(tokenBuffer) + + tokenCache.Set("repo3", token3) + cachedToken = tokenCache.Get("repo3") + So(cachedToken.Token, ShouldEqual, token3.Token) + + // token2 should be expired + token = tokenCache.Get("repo2") + So(token, ShouldBeNil) + + time.Sleep(2 * time.Second) + + // the rest of them should also be expired + tokenCache.Set("repo4", &bearerToken{ + Token: "tokenD", + }) + + // token1 should be expired + token = tokenCache.Get("repo1") + So(token, ShouldBeNil) + }) + + Convey("Error paths", t, func() { + tokenCache := NewTokenCache() + token := tokenCache.Get("repo") + So(token, ShouldBeNil) + + tokenCache = nil + token = tokenCache.Get("repo") + So(token, ShouldBeNil) + + tokenCache = NewTokenCache() + tokenCache.Set("repo", nil) + token = tokenCache.Get("repo") + So(token, ShouldBeNil) + }) +} + +func TestNeedsRetryOnInsuficientScope(t *testing.T) { + resp := http.Response{ + Status: "401 Unauthorized", + StatusCode: http.StatusUnauthorized, + Proto: "HTTP/1.1", + ProtoMajor: 1, + ProtoMinor: 1, + Header: map[string][]string{ + "Content-Length": {"145"}, + "Content-Type": {"application/json"}, + "Date": {"Fri, 26 Aug 2022 08:03:13 GMT"}, + "X-Content-Type-Options": {"nosniff"}, + }, + Request: nil, + } + + Convey("Test client retries on insufficient scope", t, func() { + resp.Header["Www-Authenticate"] = []string{ + `Bearer realm="https://registry.suse.com/auth",service="SUSE Linux Docker Registry"` + + `,scope="registry:catalog:*",error="insufficient_scope"`, + } + + expectedScope := "registry:catalog:*" + expectedRealm := "https://registry.suse.com/auth" + expectedService := "SUSE Linux Docker Registry" + + needsRetry, params := needsRetryWithUpdatedScope(nil, &resp) + + So(needsRetry, ShouldBeTrue) + So(params.scope, ShouldEqual, expectedScope) + So(params.realm, ShouldEqual, expectedRealm) + So(params.service, ShouldEqual, expectedService) + }) + + Convey("Test client fails on insufficient scope", t, func() { + resp.Header["Www-Authenticate"] = []string{ + `Bearer realm="https://registry.suse.com/auth=error"`, + } + + needsRetry, _ := needsRetryWithUpdatedScope(nil, &resp) + So(needsRetry, ShouldBeFalse) + }) +} + +func TestClient(t *testing.T) { + Convey("Test client", t, func() { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusInternalServerError) + })) + defer server.Close() + + client, err := New(Config{ + URL: server.URL, + TLSVerify: false, + }, log.NewLogger("", "")) + So(err, ShouldBeNil) + + Convey("Test Ping() fails", func() { + ok := client.Ping() + So(ok, ShouldBeFalse) + }) + + Convey("Test makeAndDoRequest() fails", func() { + client.authType = tokenAuth + //nolint: bodyclose + _, _, err := client.makeAndDoRequest(http.MethodGet, "application/json", "catalog", server.URL) + So(err, ShouldNotBeNil) + }) + + Convey("Test setupAuth() fails", func() { + request, err := http.NewRequest(http.MethodGet, server.URL, nil) //nolint: noctx + So(err, ShouldBeNil) + + client.authType = tokenAuth + err = client.setupAuth(request, "catalog") + So(err, ShouldNotBeNil) + }) + }) +} diff --git a/pkg/extensions/sync/on_demand.go b/pkg/extensions/sync/on_demand.go index 23445ec2..02ef21be 100644 --- a/pkg/extensions/sync/on_demand.go +++ b/pkg/extensions/sync/on_demand.go @@ -109,14 +109,20 @@ func (onDemand *BaseOnDemand) syncImage(ctx context.Context, repo, reference str var err error for serviceID, service := range onDemand.services { err = service.SetNextAvailableURL() - if err != nil { + + isPingErr := errors.Is(err, zerr.ErrSyncPingRegistry) + if err != nil && !isPingErr { syncResult <- err return } - err = service.SyncImage(ctx, repo, reference) - if err != nil { + // no need to try to sync inline if there is a ping error, we want to retry in background + if !isPingErr { + err = service.SyncImage(ctx, repo, reference) + } + + if err != nil || isPingErr { if errors.Is(err, zerr.ErrManifestNotFound) || errors.Is(err, zerr.ErrSyncImageFilteredOut) || errors.Is(err, zerr.ErrSyncImageNotSigned) { diff --git a/pkg/extensions/sync/remote.go b/pkg/extensions/sync/remote.go index a146eae6..fef205d4 100644 --- a/pkg/extensions/sync/remote.go +++ b/pkg/extensions/sync/remote.go @@ -6,6 +6,7 @@ package sync import ( "context" "fmt" + "strings" "github.com/containers/image/v5/docker" dockerReference "github.com/containers/image/v5/docker/reference" @@ -58,6 +59,26 @@ func (registry *RemoteRegistry) GetRepositories(ctx context.Context) ([]string, return catalog.Repositories, nil } +func (registry *RemoteRegistry) GetDockerRemoteRepo(repo string) string { + dockerNamespace := "library" + dockerRegistry := "docker.io" + + remoteHost := registry.client.GetHostname() + + repoRef, err := parseRepositoryReference(fmt.Sprintf("%s/%s", remoteHost, repo)) + if err != nil { + return repo + } + + if !strings.Contains(repo, dockerNamespace) && + strings.Contains(repoRef.String(), dockerNamespace) && + strings.Contains(repoRef.String(), dockerRegistry) { + return fmt.Sprintf("%s/%s", dockerNamespace, repo) + } + + return repo +} + func (registry *RemoteRegistry) GetImageReference(repo, reference string) (types.ImageReference, error) { remoteHost := registry.client.GetHostname() diff --git a/pkg/extensions/sync/service.go b/pkg/extensions/sync/service.go index 9ce22ae3..e0caedf6 100644 --- a/pkg/extensions/sync/service.go +++ b/pkg/extensions/sync/service.go @@ -93,9 +93,12 @@ func New( service.retryOptions = retryOptions service.storeController = storeController - err = service.SetNextAvailableClient() - if err != nil { - return nil, err + // try to set next client. + if err := service.SetNextAvailableClient(); err != nil { + // if it's a ping issue, it will be retried + if !errors.Is(err, zerr.ErrSyncPingRegistry) { + return service, err + } } service.references = references.NewReferences( @@ -118,7 +121,14 @@ func (service *BaseService) SetNextAvailableClient() error { return nil } + found := false + for _, url := range service.config.URLs { + // skip current client + if service.client != nil && service.client.GetBaseURL() == url { + continue + } + remoteAddress := StripRegistryTransport(url) credentials := service.credentials[remoteAddress] @@ -149,12 +159,14 @@ func (service *BaseService) SetNextAvailableClient() error { return err } - if !service.client.Ping() { - continue + if service.client.Ping() { + found = true + + break } } - if service.client == nil { + if service.client == nil || !found { return zerr.ErrSyncPingRegistry } @@ -241,6 +253,8 @@ func (service *BaseService) SyncReference(ctx context.Context, repo string, } } + remoteRepo = service.remote.GetDockerRemoteRepo(remoteRepo) + service.log.Info().Str("remote", remoteURL).Str("repository", repo).Str("subject", subjectDigestStr). Str("reference type", referenceType).Msg("syncing reference for image") @@ -263,6 +277,8 @@ func (service *BaseService) SyncImage(ctx context.Context, repo, reference strin } } + remoteRepo = service.remote.GetDockerRemoteRepo(remoteRepo) + service.log.Info().Str("remote", remoteURL).Str("repository", repo).Str("reference", reference). Msg("syncing image") diff --git a/pkg/extensions/sync/sync.go b/pkg/extensions/sync/sync.go index 986fbb24..4743ee34 100644 --- a/pkg/extensions/sync/sync.go +++ b/pkg/extensions/sync/sync.go @@ -63,6 +63,9 @@ type Remote interface { GetRepoTags(repo string) ([]string, error) // Get manifest content, mediaType, digest given an ImageReference GetManifestContent(imageReference types.ImageReference) ([]byte, string, digest.Digest, error) + // In the case of public dockerhub images 'library' namespace is added to the repo names of images + // eg: alpine -> library/alpine + GetDockerRemoteRepo(repo string) string } // Local registry. @@ -111,6 +114,11 @@ func (gen *TaskGenerator) Next() (scheduler.Task, error) { return nil, nil } + // a task with this repo is already running + if gen.lastRepo == repo { + return nil, nil + } + gen.lastRepo = repo return newSyncRepoTask(gen.lastRepo, gen.Service), nil diff --git a/pkg/extensions/sync/sync_test.go b/pkg/extensions/sync/sync_test.go index fdcf4910..f4f6f00f 100644 --- a/pkg/extensions/sync/sync_test.go +++ b/pkg/extensions/sync/sync_test.go @@ -11,6 +11,7 @@ import ( "errors" "fmt" "net/http" + "net/url" "os" "os/exec" "path" @@ -47,6 +48,7 @@ import ( "zotregistry.dev/zot/pkg/log" mTypes "zotregistry.dev/zot/pkg/meta/types" storageConstants "zotregistry.dev/zot/pkg/storage/constants" + authutils "zotregistry.dev/zot/pkg/test/auth" test "zotregistry.dev/zot/pkg/test/common" . "zotregistry.dev/zot/pkg/test/image-utils" "zotregistry.dev/zot/pkg/test/mocks" @@ -2364,6 +2366,284 @@ func TestTLS(t *testing.T) { }) } +func TestBearerAuth(t *testing.T) { + Convey("Verify periodically sync bearer auth", t, func() { + updateDuration, _ := time.ParseDuration("1h") + // a repo for which clients do not have access, sync shouldn't be able to sync it + unauthorizedNamespace := testCveImage + + authTestServer := authutils.MakeAuthTestServer(ServerKey, unauthorizedNamespace) + defer authTestServer.Close() + + sctlr, srcBaseURL, _, _, srcClient := makeUpstreamServer(t, false, false) + + aurl, err := url.Parse(authTestServer.URL) + So(err, ShouldBeNil) + + sctlr.Config.HTTP.Auth = &config.AuthConfig{ + Bearer: &config.BearerConfig{ + Cert: ServerCert, + Realm: authTestServer.URL + "/auth/token", + Service: aurl.Host, + }, + } + + scm := test.NewControllerManager(sctlr) + scm.StartAndWait(sctlr.Config.HTTP.Port) + defer scm.StopServer() + + registryName := sync.StripRegistryTransport(srcBaseURL) + credentialsFile := makeCredentialsFile(fmt.Sprintf(`{"%s":{"username": "%s", "password": "%s"}}`, + registryName, username, password)) + + var tlsVerify bool + + syncRegistryConfig := syncconf.RegistryConfig{ + Content: []syncconf.Content{ + { + Prefix: "**", // sync everything + }, + }, + URLs: []string{srcBaseURL}, + PollInterval: updateDuration, + TLSVerify: &tlsVerify, + CertDir: "", + } + + defaultVal := true + syncConfig := &syncconf.Config{ + Enable: &defaultVal, + CredentialsFile: credentialsFile, + Registries: []syncconf.RegistryConfig{syncRegistryConfig}, + } + + dctlr, destBaseURL, _, destClient := makeDownstreamServer(t, false, syncConfig) + + dcm := test.NewControllerManager(dctlr) + dcm.StartAndWait(dctlr.Config.HTTP.Port) + defer dcm.StopServer() + + var srcTagsList TagsList + var destTagsList TagsList + + resp, err := srcClient.R().Get(srcBaseURL + "/v2/") + So(err, ShouldBeNil) + So(resp, ShouldNotBeNil) + So(resp.StatusCode(), ShouldEqual, http.StatusUnauthorized) + + authorizationHeader := authutils.ParseBearerAuthHeader(resp.Header().Get("WWW-Authenticate")) + resp, err = resty.R(). + SetQueryParam("service", authorizationHeader.Service). + Get(authorizationHeader.Realm) + So(err, ShouldBeNil) + So(resp, ShouldNotBeNil) + So(resp.StatusCode(), ShouldEqual, http.StatusOK) + var goodToken authutils.AccessTokenResponse + err = json.Unmarshal(resp.Body(), &goodToken) + So(err, ShouldBeNil) + + resp, err = srcClient.R(). + SetHeader("Authorization", fmt.Sprintf("Bearer %s", goodToken.AccessToken)). + Get(srcBaseURL + "/v2/") + So(err, ShouldBeNil) + So(resp, ShouldNotBeNil) + So(resp.StatusCode(), ShouldEqual, http.StatusOK) + + resp, err = srcClient.R().Get(srcBaseURL + "/v2/" + testImage + "/tags/list") + So(err, ShouldBeNil) + So(resp, ShouldNotBeNil) + So(resp.StatusCode(), ShouldEqual, http.StatusUnauthorized) + + authorizationHeader = authutils.ParseBearerAuthHeader(resp.Header().Get("WWW-Authenticate")) + resp, err = resty.R(). + SetQueryParam("service", authorizationHeader.Service). + SetQueryParam("scope", authorizationHeader.Scope). + Get(authorizationHeader.Realm) + So(err, ShouldBeNil) + So(resp, ShouldNotBeNil) + So(resp.StatusCode(), ShouldEqual, http.StatusOK) + + goodToken = authutils.AccessTokenResponse{} + err = json.Unmarshal(resp.Body(), &goodToken) + So(err, ShouldBeNil) + + resp, err = srcClient.R().SetHeader("Authorization", fmt.Sprintf("Bearer %s", goodToken.AccessToken)). + Get(srcBaseURL + "/v2/" + testImage + "/tags/list") + So(err, ShouldBeNil) + So(resp, ShouldNotBeNil) + So(resp.StatusCode(), ShouldEqual, http.StatusOK) + + err = json.Unmarshal(resp.Body(), &srcTagsList) + if err != nil { + panic(err) + } + + for { + resp, err = destClient.R().Get(destBaseURL + "/v2/" + testImage + "/tags/list") + if err != nil { + panic(err) + } + + err = json.Unmarshal(resp.Body(), &destTagsList) + if err != nil { + panic(err) + } + + if len(destTagsList.Tags) > 0 { + break + } + + time.Sleep(500 * time.Millisecond) + } + + So(destTagsList, ShouldResemble, srcTagsList) + + waitSyncFinish(dctlr.Config.Log.Output) + + resp, err = destClient.R().Get(destBaseURL + "/v2/" + testImage + "/manifests/" + testImageTag) + So(err, ShouldBeNil) + So(resp.StatusCode(), ShouldEqual, http.StatusOK) + + // unauthorized namespace + resp, err = destClient.R().Get(destBaseURL + "/v2/" + testCveImage + "/manifests/" + testImageTag) + So(err, ShouldBeNil) + So(resp.StatusCode(), ShouldEqual, http.StatusNotFound) + }) + + Convey("Verify ondemand sync bearer auth", t, func() { + // a repo for which clients do not have access, sync shouldn't be able to sync it + unauthorizedNamespace := testCveImage + + authTestServer := authutils.MakeAuthTestServer(ServerKey, unauthorizedNamespace) + defer authTestServer.Close() + + sctlr, srcBaseURL, _, _, srcClient := makeUpstreamServer(t, false, false) + + aurl, err := url.Parse(authTestServer.URL) + So(err, ShouldBeNil) + + sctlr.Config.HTTP.Auth = &config.AuthConfig{ + Bearer: &config.BearerConfig{ + Cert: ServerCert, + Realm: authTestServer.URL + "/auth/token", + Service: aurl.Host, + }, + } + + scm := test.NewControllerManager(sctlr) + scm.StartAndWait(sctlr.Config.HTTP.Port) + defer scm.StopServer() + + registryName := sync.StripRegistryTransport(srcBaseURL) + credentialsFile := makeCredentialsFile(fmt.Sprintf(`{"%s":{"username": "%s", "password": "%s"}}`, + registryName, username, password)) + + var tlsVerify bool + + syncRegistryConfig := syncconf.RegistryConfig{ + Content: []syncconf.Content{ + { + Prefix: "**", // sync everything + }, + }, + URLs: []string{srcBaseURL}, + TLSVerify: &tlsVerify, + OnDemand: true, + CertDir: "", + } + + defaultVal := true + syncConfig := &syncconf.Config{ + Enable: &defaultVal, + CredentialsFile: credentialsFile, + Registries: []syncconf.RegistryConfig{syncRegistryConfig}, + } + + dctlr, destBaseURL, _, destClient := makeDownstreamServer(t, false, syncConfig) + + dcm := test.NewControllerManager(dctlr) + dcm.StartAndWait(dctlr.Config.HTTP.Port) + defer dcm.StopServer() + + var srcTagsList TagsList + var destTagsList TagsList + + resp, err := srcClient.R().Get(srcBaseURL + "/v2/") + So(err, ShouldBeNil) + So(resp, ShouldNotBeNil) + So(resp.StatusCode(), ShouldEqual, http.StatusUnauthorized) + + authorizationHeader := authutils.ParseBearerAuthHeader(resp.Header().Get("WWW-Authenticate")) + resp, err = resty.R(). + SetQueryParam("service", authorizationHeader.Service). + Get(authorizationHeader.Realm) + So(err, ShouldBeNil) + So(resp, ShouldNotBeNil) + So(resp.StatusCode(), ShouldEqual, http.StatusOK) + var goodToken authutils.AccessTokenResponse + err = json.Unmarshal(resp.Body(), &goodToken) + So(err, ShouldBeNil) + + resp, err = srcClient.R(). + SetHeader("Authorization", fmt.Sprintf("Bearer %s", goodToken.AccessToken)). + Get(srcBaseURL + "/v2/") + So(err, ShouldBeNil) + So(resp, ShouldNotBeNil) + So(resp.StatusCode(), ShouldEqual, http.StatusOK) + + resp, err = srcClient.R().Get(srcBaseURL + "/v2/" + testImage + "/tags/list") + So(err, ShouldBeNil) + So(resp, ShouldNotBeNil) + So(resp.StatusCode(), ShouldEqual, http.StatusUnauthorized) + + authorizationHeader = authutils.ParseBearerAuthHeader(resp.Header().Get("WWW-Authenticate")) + resp, err = resty.R(). + SetQueryParam("service", authorizationHeader.Service). + SetQueryParam("scope", authorizationHeader.Scope). + Get(authorizationHeader.Realm) + So(err, ShouldBeNil) + So(resp, ShouldNotBeNil) + So(resp.StatusCode(), ShouldEqual, http.StatusOK) + + goodToken = authutils.AccessTokenResponse{} + err = json.Unmarshal(resp.Body(), &goodToken) + So(err, ShouldBeNil) + + resp, err = srcClient.R().SetHeader("Authorization", fmt.Sprintf("Bearer %s", goodToken.AccessToken)). + Get(srcBaseURL + "/v2/" + testImage + "/tags/list") + So(err, ShouldBeNil) + So(resp, ShouldNotBeNil) + So(resp.StatusCode(), ShouldEqual, http.StatusOK) + + err = json.Unmarshal(resp.Body(), &srcTagsList) + if err != nil { + panic(err) + } + + // sync on demand + resp, err = destClient.R().Get(destBaseURL + "/v2/" + testImage + "/manifests/" + testImageTag) + So(err, ShouldBeNil) + So(resp.StatusCode(), ShouldEqual, http.StatusOK) + + resp, err = destClient.R().Get(destBaseURL + "/v2/" + testImage + "/tags/list") + if err != nil { + panic(err) + } + + err = json.Unmarshal(resp.Body(), &destTagsList) + if err != nil { + panic(err) + } + + So(destTagsList, ShouldResemble, srcTagsList) + + // unauthorized namespace + resp, err = destClient.R().Get(destBaseURL + "/v2/" + testCveImage + "/manifests/" + testImageTag) + So(err, ShouldBeNil) + So(resp.StatusCode(), ShouldEqual, http.StatusNotFound) + }) +} + func TestBasicAuth(t *testing.T) { Convey("Verify sync basic auth", t, func() { updateDuration, _ := time.ParseDuration("1h") diff --git a/pkg/test/mocks/sync_remote_mock.go b/pkg/test/mocks/sync_remote_mock.go index d7500029..9c7e418b 100644 --- a/pkg/test/mocks/sync_remote_mock.go +++ b/pkg/test/mocks/sync_remote_mock.go @@ -20,10 +20,20 @@ type SyncRemote struct { // Get a list of tags given a repo GetRepoTagsFn func(repo string) ([]string, error) + GetDockerRemoteRepoFn func(repo string) string + // Get manifest content, mediaType, digest given an ImageReference GetManifestContentFn func(imageReference types.ImageReference) ([]byte, string, digest.Digest, error) } +func (remote SyncRemote) GetDockerRemoteRepo(repo string) string { + if remote.GetDockerRemoteRepoFn != nil { + return remote.GetDockerRemoteRepoFn(repo) + } + + return "" +} + func (remote SyncRemote) GetImageReference(repo string, tag string) (types.ImageReference, error) { if remote.GetImageReferenceFn != nil { return remote.GetImageReferenceFn(repo, tag)