From b8010e1ee4ae283993574190779434e1e069cb10 Mon Sep 17 00:00:00 2001 From: Shivam Mishra Date: Wed, 16 Feb 2022 01:15:13 +0000 Subject: [PATCH] routes: changes required to do browser authentication whenever we make a request that contains header apart from CORS allowed header, browser sends a preflight request and in response accept *Access-Control-Allow-Headers*. preflight request is in form of OPTIONS method, added new http handler func to set headers and returns HTTP status ok in case of OPTIONS method. in case of authorization, request contains authorization header added authorization header in Access-Control-Allow-Headers list added AllowOrigin field in HTTPConfig this field value is set to Access-Control-Allow-Origin header and will give zot adminstrator to limit incoming request. Signed-off-by: Shivam Mishra --- pkg/api/authn.go | 65 +++++++++++++++++---------- pkg/api/config/config.go | 1 + pkg/api/controller.go | 20 ++++++--- pkg/api/controller_test.go | 17 +++++++ pkg/api/routes.go | 14 +++--- pkg/extensions/extensions.go | 2 +- pkg/extensions/search/cve/cve_test.go | 62 +++++++++++++++++++++++++ 7 files changed, 147 insertions(+), 34 deletions(-) diff --git a/pkg/api/authn.go b/pkg/api/authn.go index 4662677c..1d5ee424 100644 --- a/pkg/api/authn.go +++ b/pkg/api/authn.go @@ -45,6 +45,11 @@ func bearerAuthHandler(ctlr *Controller) mux.MiddlewareFunc { return func(next http.Handler) http.Handler { return http.HandlerFunc(func(response http.ResponseWriter, request *http.Request) { + if request.Method == http.MethodOptions { + response.WriteHeader(http.StatusNoContent) + + return + } vars := mux.Vars(request) name := vars["name"] header := request.Header.Get("Authorization") @@ -72,6 +77,37 @@ func bearerAuthHandler(ctlr *Controller) mux.MiddlewareFunc { } } +func noPasswdAuth(realm string, config *config.Config) mux.MiddlewareFunc { + return func(next http.Handler) http.Handler { + return http.HandlerFunc(func(response http.ResponseWriter, request *http.Request) { + if request.Method == http.MethodOptions { + response.WriteHeader(http.StatusNoContent) + + return + } + + if config.HTTP.AllowReadAccess && + config.HTTP.TLS.CACert != "" && + request.TLS.VerifiedChains == nil && + request.Method != http.MethodGet && request.Method != http.MethodHead { + authFail(response, realm, 5) //nolint:gomnd + + return + } + + if (request.Method != http.MethodGet && request.Method != http.MethodHead) && config.HTTP.ReadOnly { + // Reject modification requests in read-only mode + response.WriteHeader(http.StatusMethodNotAllowed) + + return + } + + // Process request + next.ServeHTTP(response, request) + }) + } +} + // nolint:gocyclo // we use closure making this a complex subroutine func basicAuthHandler(ctlr *Controller) mux.MiddlewareFunc { realm := ctlr.Config.HTTP.Realm @@ -84,28 +120,7 @@ func basicAuthHandler(ctlr *Controller) mux.MiddlewareFunc { // no password based authN, if neither LDAP nor HTTP BASIC is enabled if ctlr.Config.HTTP.Auth == nil || (ctlr.Config.HTTP.Auth.HTPasswd.Path == "" && ctlr.Config.HTTP.Auth.LDAP == nil) { - return func(next http.Handler) http.Handler { - return http.HandlerFunc(func(response http.ResponseWriter, request *http.Request) { - if ctlr.Config.HTTP.AllowReadAccess && - ctlr.Config.HTTP.TLS.CACert != "" && - request.TLS.VerifiedChains == nil && - request.Method != http.MethodGet && request.Method != http.MethodHead { - authFail(response, realm, 5) //nolint:gomnd - - return - } - - if (request.Method != http.MethodGet && request.Method != http.MethodHead) && ctlr.Config.HTTP.ReadOnly { - // Reject modification requests in read-only mode - response.WriteHeader(http.StatusMethodNotAllowed) - - return - } - - // Process request - next.ServeHTTP(response, request) - }) - } + return noPasswdAuth(realm, ctlr.Config) } credMap := make(map[string]string) @@ -177,6 +192,11 @@ func basicAuthHandler(ctlr *Controller) mux.MiddlewareFunc { return func(next http.Handler) http.Handler { return http.HandlerFunc(func(response http.ResponseWriter, request *http.Request) { + if request.Method == http.MethodOptions { + response.WriteHeader(http.StatusNoContent) + + return + } if (request.Method == http.MethodGet || request.Method == http.MethodHead) && ctlr.Config.HTTP.AllowReadAccess { // Process request next.ServeHTTP(response, request) @@ -185,7 +205,6 @@ func basicAuthHandler(ctlr *Controller) mux.MiddlewareFunc { } if (request.Method != http.MethodGet && request.Method != http.MethodHead) && ctlr.Config.HTTP.ReadOnly { - // Reject modification requests in read-only mode response.WriteHeader(http.StatusMethodNotAllowed) return diff --git a/pkg/api/config/config.go b/pkg/api/config/config.go index e234cb8c..1ad3347a 100644 --- a/pkg/api/config/config.go +++ b/pkg/api/config/config.go @@ -64,6 +64,7 @@ type RatelimitConfig struct { type HTTPConfig struct { Address string Port string + AllowOrigin string // comma separated TLS *TLSConfig Auth *AuthConfig RawAccessControl map[string]interface{} `mapstructure:"accessControl,omitempty"` diff --git a/pkg/api/controller.go b/pkg/api/controller.go index ad76e031..946d6c8e 100644 --- a/pkg/api/controller.go +++ b/pkg/api/controller.go @@ -57,19 +57,29 @@ func NewController(config *config.Config) *Controller { return &controller } -func DefaultHeaders() mux.MiddlewareFunc { +func (c *Controller) CORSHeaders() mux.MiddlewareFunc { return func(next http.Handler) http.Handler { return http.HandlerFunc(func(response http.ResponseWriter, request *http.Request) { // CORS - response.Header().Set("Access-Control-Allow-Origin", "*") - response.Header().Set("Access-Control-Allow-Methods", "POST, GET, OPTIONS") + c.CORSHandler(response, request) - // handle the request next.ServeHTTP(response, request) }) } } +func (c *Controller) CORSHandler(response http.ResponseWriter, request *http.Request) { + // allow origin as specified in config if not accept request from anywhere. + if c.Config.HTTP.AllowOrigin == "" { + response.Header().Set("Access-Control-Allow-Origin", "*") + } else { + response.Header().Set("Access-Control-Allow-Origin", c.Config.HTTP.AllowOrigin) + } + + response.Header().Set("Access-Control-Allow-Methods", "HEAD,GET,POST,OPTIONS") + response.Header().Set("Access-Control-Allow-Headers", "Authorization") +} + func DumpRuntimeParams(log log.Logger) { var rLimit syscall.Rlimit @@ -120,7 +130,7 @@ func (c *Controller) Run() error { } engine.Use( - DefaultHeaders(), + c.CORSHeaders(), SessionLogger(c), handlers.RecoveryHandler(handlers.RecoveryLogger(c.Log), handlers.PrintRecoveryStack(false))) diff --git a/pkg/api/controller_test.go b/pkg/api/controller_test.go index 037d6ea0..f43bd489 100644 --- a/pkg/api/controller_test.go +++ b/pkg/api/controller_test.go @@ -248,6 +248,9 @@ func TestHtpasswdSingleCred(t *testing.T) { Path: htpasswdPath, }, } + + conf.HTTP.AllowOrigin = conf.HTTP.Address + ctlr := api.NewController(conf) ctlr.Config.Storage.RootDirectory = t.TempDir() @@ -260,6 +263,14 @@ func TestHtpasswdSingleCred(t *testing.T) { So(resp, ShouldNotBeNil) So(resp.StatusCode(), ShouldEqual, http.StatusOK) + header := []string{"Authorization"} + + resp, _ = resty.R().SetBasicAuth(user, password).Options(baseURL + "/v2/") + So(resp, ShouldNotBeNil) + So(resp.StatusCode(), ShouldEqual, http.StatusNoContent) + So(len(resp.Header()), ShouldEqual, 4) + So(resp.Header()["Access-Control-Allow-Headers"], ShouldResemble, header) + // with invalid creds, it should fail resp, _ = resty.R().SetBasicAuth("chuck", "chuck").Get(baseURL + "/v2/") So(resp, ShouldNotBeNil) @@ -1467,6 +1478,12 @@ func TestBearerAuth(t *testing.T) { So(resp, ShouldNotBeNil) So(resp.StatusCode(), ShouldEqual, http.StatusOK) + resp, err = resty.R().SetHeader("Authorization", + fmt.Sprintf("Bearer %s", goodToken.AccessToken)).Options(baseURL + "/v2/") + So(err, ShouldBeNil) + So(resp, ShouldNotBeNil) + So(resp.StatusCode(), ShouldEqual, http.StatusNoContent) + resp, err = resty.R().Post(baseURL + "/v2/" + AuthorizedNamespace + "/blobs/uploads/") So(err, ShouldBeNil) So(resp, ShouldNotBeNil) diff --git a/pkg/api/routes.go b/pkg/api/routes.go index cc2ee998..647474a3 100644 --- a/pkg/api/routes.go +++ b/pkg/api/routes.go @@ -57,6 +57,10 @@ func NewRouteHandler(c *Controller) *RouteHandler { return rh } +func allowedMethods(method string) []string { + return []string{http.MethodOptions, method} +} + func (rh *RouteHandler) SetupRoutes() { rh.c.Router.Use(AuthHandler(rh.c)) // authz is being enabled because authn is found @@ -68,11 +72,11 @@ func (rh *RouteHandler) SetupRoutes() { prefixedRouter := rh.c.Router.PathPrefix(RoutePrefix).Subrouter() { prefixedRouter.HandleFunc(fmt.Sprintf("/{name:%s}/tags/list", NameRegexp.String()), - rh.ListTags).Methods("GET") + rh.ListTags).Methods(allowedMethods("GET")...) prefixedRouter.HandleFunc(fmt.Sprintf("/{name:%s}/manifests/{reference}", NameRegexp.String()), - rh.CheckManifest).Methods("HEAD") + rh.CheckManifest).Methods(allowedMethods("HEAD")...) prefixedRouter.HandleFunc(fmt.Sprintf("/{name:%s}/manifests/{reference}", NameRegexp.String()), - rh.GetManifest).Methods("GET") + rh.GetManifest).Methods(allowedMethods("GET")...) prefixedRouter.HandleFunc(fmt.Sprintf("/{name:%s}/manifests/{reference}", NameRegexp.String()), rh.UpdateManifest).Methods("PUT") prefixedRouter.HandleFunc(fmt.Sprintf("/{name:%s}/manifests/{reference}", NameRegexp.String()), @@ -94,9 +98,9 @@ func (rh *RouteHandler) SetupRoutes() { prefixedRouter.HandleFunc(fmt.Sprintf("/{name:%s}/blobs/uploads/{session_id}", NameRegexp.String()), rh.DeleteBlobUpload).Methods("DELETE") prefixedRouter.HandleFunc("/_catalog", - rh.ListRepositories).Methods("GET") + rh.ListRepositories).Methods(allowedMethods("GET")...) prefixedRouter.HandleFunc("/", - rh.CheckVersionSupport).Methods("GET") + rh.CheckVersionSupport).Methods(allowedMethods("GET")...) } // support for oras artifact reference types (alpha 1) - image signature use case diff --git a/pkg/extensions/extensions.go b/pkg/extensions/extensions.go index 1c61ea7c..9ee5a5cb 100644 --- a/pkg/extensions/extensions.go +++ b/pkg/extensions/extensions.go @@ -96,7 +96,7 @@ func SetupRoutes(config *config.Config, router *mux.Router, storeController stor resConfig = search.GetResolverConfig(log, storeController, false) } - router.PathPrefix("/query").Methods("GET", "POST"). + router.PathPrefix("/query").Methods("GET", "POST", "OPTIONS"). Handler(gqlHandler.NewDefaultServer(search.NewExecutableSchema(resConfig))) } diff --git a/pkg/extensions/search/cve/cve_test.go b/pkg/extensions/search/cve/cve_test.go index 061f7821..8430cae4 100644 --- a/pkg/extensions/search/cve/cve_test.go +++ b/pkg/extensions/search/cve/cve_test.go @@ -9,6 +9,7 @@ import ( "encoding/json" "fmt" "io/ioutil" + "net/http" "os" "path" "testing" @@ -669,3 +670,64 @@ func TestCVEConfig(t *testing.T) { }() }) } + +func TestHTTPOptionsResponse(t *testing.T) { + Convey("Test http options response", t, func() { + conf := config.New() + port := GetFreePort() + conf.HTTP.Port = port + baseURL := GetBaseURL(port) + + ctlr := api.NewController(conf) + + firstDir, err := ioutil.TempDir("", "oci-repo-test") + if err != nil { + panic(err) + } + + secondDir, err := ioutil.TempDir("", "oci-repo-test") + if err != nil { + panic(err) + } + defer os.RemoveAll(firstDir) + defer os.RemoveAll(secondDir) + + err = CopyFiles("../../../../test/data", path.Join(secondDir, "a")) + if err != nil { + panic(err) + } + + ctlr.Config.Storage.RootDirectory = firstDir + subPaths := make(map[string]config.StorageConfig) + subPaths["/a"] = config.StorageConfig{ + RootDirectory: secondDir, + } + + ctlr.Config.Storage.SubPaths = subPaths + + go func() { + // this blocks + if err := ctlr.Run(); err != nil { + return + } + }() + + // wait till ready + for { + _, err := resty.R().Get(baseURL) + if err == nil { + break + } + time.Sleep(100 * time.Millisecond) + } + + resp, _ := resty.R().Options(baseURL + "/v2/_catalog") + So(resp, ShouldNotBeNil) + So(resp.StatusCode(), ShouldEqual, http.StatusNoContent) + + defer func() { + ctx := context.Background() + _ = ctlr.Server.Shutdown(ctx) + }() + }) +}