From 934b22d124151c1e9ef3e8793af8f4937d7e3bd9 Mon Sep 17 00:00:00 2001 From: Ramkumar Chinchani <45800463+rchincha@users.noreply.github.com> Date: Sun, 26 Apr 2026 12:23:48 -0700 Subject: [PATCH] =?UTF-8?q?fix(security):=20enhance=20timeout=20configurat?= =?UTF-8?q?ions=20and=20body=20size=20limits=20fo=E2=80=A6=20(#3984)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * fix(security): enhance timeout configurations and body size limits for HTTP requests Signed-off-by: Ramkumar Chinchani * fix(tests): refactor backend result handling in proxyHTTPRequest test Signed-off-by: Ramkumar Chinchani * fix(security): preserve ContentLength in proxied requests to prevent server hang Signed-off-by: Ramkumar Chinchani * fix(security): preserve explicit zero-length request bodies in proxyHTTPRequest fix(tests): add test for normalizedTimeout function to ensure default fallback Signed-off-by: Ramkumar Chinchani * fix(security): prevent default HTTP timeout values from being set unless explicitly configured Signed-off-by: Ramkumar Chinchani * fix(security): refactor timeout handling to use explicit checks for nil and non-positive values Signed-off-by: Ramkumar Chinchani * fix(tests): add wait_for_event_count function to ensure expected event generation Signed-off-by: Ramkumar Chinchani * fix(security): improve timeout handling and update error responses for large requests Signed-off-by: Ramkumar Chinchani * fix(security): enhance HTTP timeout handling with explicit accessors and default values Signed-off-by: Ramkumar Chinchani * fix(security): increase default API key body size and timeout values for improved performance Signed-off-by: Ramkumar Chinchani * fix(security): unify timeout handling by replacing specific read/write timeouts with a single default timeout Signed-off-by: Ramkumar Chinchani * fix(security): consolidate HTTP timeout accessors and enhance timeout handling Signed-off-by: Ramkumar Chinchani * fix(security): simplify HTTP timeout accessors and set default values for read/write timeouts Co-authored-by: Copilot Signed-off-by: Ramkumar Chinchani --------- Signed-off-by: Ramkumar Chinchani Co-authored-by: Copilot --- pkg/api/config/config.go | 56 ++++++++- pkg/api/config/config_test.go | 38 +++++++ pkg/api/constants/consts.go | 4 +- pkg/api/controller.go | 3 + pkg/api/proxy.go | 51 +++------ pkg/api/proxy_internal_test.go | 113 +++++++++++++++++++ pkg/cli/server/root.go | 15 +++ pkg/cli/server/root_test.go | 41 +++++++ pkg/exporter/api/config.go | 23 +++- pkg/exporter/api/exporter.go | 14 +++ pkg/exporter/api/exporter_internal_test.go | 23 ++++ pkg/extensions/extension_image_trust.go | 24 +++- pkg/extensions/extension_image_trust_test.go | 35 +++++- swagger/docs.go | 12 ++ swagger/swagger.json | 12 ++ swagger/swagger.yaml | 8 ++ test/blackbox/events_http_lint_failure.bats | 24 ++++ 17 files changed, 440 insertions(+), 56 deletions(-) create mode 100644 pkg/api/proxy_internal_test.go create mode 100644 pkg/exporter/api/exporter_internal_test.go diff --git a/pkg/api/config/config.go b/pkg/api/config/config.go index f74cfab5..ae4112ac 100644 --- a/pkg/api/config/config.go +++ b/pkg/api/config/config.go @@ -376,10 +376,18 @@ type RatelimitConfig struct { //nolint:maligned type HTTPConfig struct { - Address string - ExternalURL string `mapstructure:",omitempty"` - Port string - AllowOrigin string // comma separated + Address string + ExternalURL string `mapstructure:",omitempty"` + Port string + AllowOrigin string // comma separated + // ReadTimeout controls maximum duration for reading the entire request (including body). + // When unset (nil), server-level defaults may apply. When explicitly set to <= 0, + // the HTTP server treats it as no timeout. + ReadTimeout *time.Duration `mapstructure:"readTimeout,omitempty"` + // WriteTimeout controls maximum duration before timing out response writes. + // When unset (nil), server-level defaults may apply. When explicitly set to <= 0, + // the HTTP server treats it as no timeout. + WriteTimeout *time.Duration `mapstructure:"writeTimeout,omitempty"` TLS *TLSConfig Auth *AuthConfig AccessControl *AccessControlConfig `mapstructure:"accessControl,omitempty"` @@ -661,8 +669,12 @@ func New() *Config { Retention: ImageRetention{}, }, }, - HTTP: HTTPConfig{Address: "127.0.0.1", Port: "8080", Auth: &AuthConfig{FailDelay: 0}}, - Log: &LogConfig{Level: "debug"}, + HTTP: HTTPConfig{ + Address: "127.0.0.1", + Port: "8080", + Auth: &AuthConfig{FailDelay: 0}, + }, + Log: &LogConfig{Level: "debug"}, } } @@ -1117,6 +1129,38 @@ func (c *Config) GetHTTPPort() string { return c.HTTP.Port } +// GetHTTPReadTimeout returns the configured HTTP server read timeout. +func (c *Config) GetHTTPReadTimeout() time.Duration { + if c == nil { + return 0 + } + + c.mu.RLock() + defer c.mu.RUnlock() + + if c.HTTP.ReadTimeout == nil { + return 0 + } + + return *c.HTTP.ReadTimeout +} + +// GetHTTPWriteTimeout returns the configured HTTP server write timeout. +func (c *Config) GetHTTPWriteTimeout() time.Duration { + if c == nil { + return 0 + } + + c.mu.RLock() + defer c.mu.RUnlock() + + if c.HTTP.WriteTimeout == nil { + return 0 + } + + return *c.HTTP.WriteTimeout +} + // GetAllowOrigin returns the CORS allow origin configuration. func (c *Config) GetAllowOrigin() string { if c == nil { diff --git a/pkg/api/config/config_test.go b/pkg/api/config/config_test.go index 0634013d..8debfbd6 100644 --- a/pkg/api/config/config_test.go +++ b/pkg/api/config/config_test.go @@ -3364,3 +3364,41 @@ func TestConfig(t *testing.T) { }) }) } + +func TestHTTPTimeoutAccessors(t *testing.T) { + Convey("GetHTTPReadTimeout returns configured values", t, func() { + cfg := config.New() + + So(cfg.GetHTTPReadTimeout(), ShouldEqual, 0) + + zero := time.Duration(0) + cfg.HTTP.ReadTimeout = &zero + So(cfg.GetHTTPReadTimeout(), ShouldEqual, 0) + + negative := -5 * time.Second + cfg.HTTP.ReadTimeout = &negative + So(cfg.GetHTTPReadTimeout(), ShouldEqual, negative) + + positive := 45 * time.Second + cfg.HTTP.ReadTimeout = &positive + So(cfg.GetHTTPReadTimeout(), ShouldEqual, positive) + }) + + Convey("GetHTTPWriteTimeout returns configured values", t, func() { + cfg := config.New() + + So(cfg.GetHTTPWriteTimeout(), ShouldEqual, 0) + + zero := time.Duration(0) + cfg.HTTP.WriteTimeout = &zero + So(cfg.GetHTTPWriteTimeout(), ShouldEqual, 0) + + negative := -5 * time.Second + cfg.HTTP.WriteTimeout = &negative + So(cfg.GetHTTPWriteTimeout(), ShouldEqual, negative) + + positive := 1 * time.Minute + cfg.HTTP.WriteTimeout = &positive + So(cfg.GetHTTPWriteTimeout(), ShouldEqual, positive) + }) +} diff --git a/pkg/api/constants/consts.go b/pkg/api/constants/consts.go index 37660942..e6db2e74 100644 --- a/pkg/api/constants/consts.go +++ b/pkg/api/constants/consts.go @@ -24,7 +24,9 @@ const ( // OCI manifest JSON is always small metadata; 4 MiB is well above any realistic manifest. MaxManifestBodySize = 4 * 1024 * 1024 // MaxAPIKeyBodySize is the maximum number of bytes accepted for an API-key creation request body. - MaxAPIKeyBodySize = 4 * 1024 + MaxAPIKeyBodySize = 8 * 1024 + // MaxImageTrustBodySize is the maximum number of bytes accepted for image-trust key/certificate uploads. + MaxImageTrustBodySize = 8 * 1024 * 1024 BlobUploadUUID = "Blob-Upload-UUID" DefaultMediaType = "application/json" BinaryMediaType = "application/octet-stream" diff --git a/pkg/api/controller.go b/pkg/api/controller.go index 8e875ece..ed776ebb 100644 --- a/pkg/api/controller.go +++ b/pkg/api/controller.go @@ -172,9 +172,12 @@ func (c *Controller) Run() error { port := c.Config.GetHTTPPort() addr := fmt.Sprintf("%s:%s", c.Config.GetHTTPAddress(), port) + server := &http.Server{ Addr: addr, Handler: c.Router, + ReadTimeout: c.Config.GetHTTPReadTimeout(), + WriteTimeout: c.Config.GetHTTPWriteTimeout(), IdleTimeout: idleTimeout, ReadHeaderTimeout: readHeaderTimeout, } diff --git a/pkg/api/proxy.go b/pkg/api/proxy.go index cac55943..8b4de8a9 100644 --- a/pkg/api/proxy.go +++ b/pkg/api/proxy.go @@ -1,7 +1,6 @@ package api import ( - "bytes" "context" "fmt" "io" @@ -131,15 +130,28 @@ func proxyHTTPRequest(ctx context.Context, req *http.Request, cloneURL.Scheme = proxyQueryScheme cloneURL.Host = targetMember - clonedBody := cloneRequestBody(req) + requestBody := io.Reader(http.NoBody) + if req.Body != nil { + requestBody = req.Body + } - fwdRequest, err := http.NewRequestWithContext(ctx, req.Method, cloneURL.String(), clonedBody) + fwdRequest, err := http.NewRequestWithContext(ctx, req.Method, cloneURL.String(), requestBody) if err != nil { return nil, err } copyHeader(fwdRequest.Header, req.Header) + // Preserve ContentLength from original request, including explicit zero-length + // bodies, so empty requests are not forwarded as unknown-length chunked bodies. + if req.ContentLength >= 0 { + fwdRequest.ContentLength = req.ContentLength + + if req.ContentLength == 0 { + fwdRequest.Body = http.NoBody + } + } + // always set hop count to 1 for now. // the handler wrapper above will terminate the process if it sees a request that // already has a hop count but is due for proxying. @@ -171,42 +183,9 @@ func proxyHTTPRequest(ctx context.Context, req *http.Request, return nil, err } - var clonedRespBody bytes.Buffer - - // copy out the contents into a new buffer as the response body - // stream should be closed to get all the data out. - _, _ = io.Copy(&clonedRespBody, resp.Body) - resp.Body.Close() - - // after closing the original body, substitute it with a new reader - // using the buffer that was just created. - // this buffer should be closed later by the consumer of the response. - resp.Body = io.NopCloser(bytes.NewReader(clonedRespBody.Bytes())) - return resp, nil } -func cloneRequestBody(src *http.Request) io.Reader { - var bCloneForOriginal, bCloneForCopy bytes.Buffer - multiWriter := io.MultiWriter(&bCloneForOriginal, &bCloneForCopy) - numBytesCopied, _ := io.Copy(multiWriter, src.Body) - - // if the body is a type of io.NopCloser and length is 0, - // the Content-Length header is not sent in the proxied request. - // explicitly returning http.NoBody allows the implementation - // to set the header. - // ref: https://github.com/golang/go/issues/34295 - if numBytesCopied == 0 { - src.Body = http.NoBody - - return http.NoBody - } - - src.Body = io.NopCloser(&bCloneForOriginal) - - return bytes.NewReader(bCloneForCopy.Bytes()) -} - func copyHeader(dst, src http.Header) { for k, vv := range src { for _, v := range vv { diff --git a/pkg/api/proxy_internal_test.go b/pkg/api/proxy_internal_test.go new file mode 100644 index 00000000..88898e1a --- /dev/null +++ b/pkg/api/proxy_internal_test.go @@ -0,0 +1,113 @@ +package api + +import ( + "context" + "io" + "net/http" + "net/http/httptest" + "net/url" + "strings" + "testing" + + . "github.com/smartystreets/goconvey/convey" + + "zotregistry.dev/zot/v2/pkg/api/config" + "zotregistry.dev/zot/v2/pkg/api/constants" +) + +func TestProxyHTTPRequestStreamsBodyAndResponse(t *testing.T) { + Convey("proxyHTTPRequest forwards request body/headers and returns streamed response", t, func() { + requestPayload := strings.Repeat("payload-", 1024) + responsePayload := strings.Repeat("response-", 2048) + + type backendResult struct { + body string + hopCount string + err error + } + + resultCh := make(chan backendResult, 1) + + backend := httptest.NewServer(http.HandlerFunc(func(response http.ResponseWriter, request *http.Request) { + body, err := io.ReadAll(request.Body) + resultCh <- backendResult{ + body: string(body), + hopCount: request.Header.Get(constants.ScaleOutHopCountHeader), + err: err, + } + + response.WriteHeader(http.StatusCreated) + _, _ = io.WriteString(response, responsePayload) + })) + defer backend.Close() + + backendURL, err := url.Parse(backend.URL) + So(err, ShouldBeNil) + + conf := config.New() + conf.Cluster = &config.ClusterConfig{Members: []string{backendURL.Host}, HashKey: "loremipsumdolors"} + + ctrlr := &Controller{Config: conf} + + req, err := http.NewRequestWithContext(context.Background(), http.MethodPut, + "http://example.com/v2/repo/manifests/latest", strings.NewReader(requestPayload)) + So(err, ShouldBeNil) + + resp, err := proxyHTTPRequest(context.Background(), req, backendURL.Host, ctrlr) + So(err, ShouldBeNil) + So(resp, ShouldNotBeNil) + defer resp.Body.Close() + + respBody, err := io.ReadAll(resp.Body) + So(err, ShouldBeNil) + + result := <-resultCh + So(result.err, ShouldBeNil) + + remainingReqBody, err := io.ReadAll(req.Body) + So(err, ShouldBeNil) + + So(resp.StatusCode, ShouldEqual, http.StatusCreated) + So(string(respBody), ShouldEqual, responsePayload) + So(result.body, ShouldEqual, requestPayload) + So(result.hopCount, ShouldEqual, "1") + So(len(remainingReqBody), ShouldEqual, 0) + }) +} + +func TestProxyHTTPRequestPreservesExplicitEmptyBody(t *testing.T) { + Convey("proxyHTTPRequest preserves explicit zero-length request bodies", t, func() { + resultCh := make(chan *http.Request, 1) + + backend := httptest.NewServer(http.HandlerFunc(func(response http.ResponseWriter, request *http.Request) { + resultCh <- request + response.WriteHeader(http.StatusNoContent) + })) + defer backend.Close() + + backendURL, err := url.Parse(backend.URL) + So(err, ShouldBeNil) + + conf := config.New() + conf.Cluster = &config.ClusterConfig{Members: []string{backendURL.Host}, HashKey: "loremipsumdolors"} + + ctrlr := &Controller{Config: conf} + + req, err := http.NewRequestWithContext(context.Background(), http.MethodPost, + "http://example.com/v2/repo/manifests/latest", http.NoBody) + So(err, ShouldBeNil) + So(req.ContentLength, ShouldEqual, 0) + + resp, err := proxyHTTPRequest(context.Background(), req, backendURL.Host, ctrlr) + So(err, ShouldBeNil) + So(resp, ShouldNotBeNil) + defer resp.Body.Close() + + backendReq := <-resultCh + + So(resp.StatusCode, ShouldEqual, http.StatusNoContent) + So(backendReq.ContentLength, ShouldEqual, 0) + So(backendReq.Body, ShouldEqual, http.NoBody) + So(backendReq.TransferEncoding, ShouldBeEmpty) + }) +} diff --git a/pkg/cli/server/root.go b/pkg/cli/server/root.go index 815bc9db..d16f9bcb 100644 --- a/pkg/cli/server/root.go +++ b/pkg/cli/server/root.go @@ -33,6 +33,11 @@ import ( storageConstants "zotregistry.dev/zot/v2/pkg/storage/constants" ) +const ( + defaultReadTimeout = 60 * time.Second + defaultWriteTimeout = 60 * time.Second +) + // metadataConfig reports metadata after parsing, which we use to track // errors. func metadataConfig(md *mapstructure.Metadata) viper.DecoderConfigOption { @@ -1063,6 +1068,16 @@ func applyDefaultValues(config *config.Config, viperInstance *viper.Viper, logge config.Storage.SubPaths[name] = storageConfig } + if config.HTTP.ReadTimeout == nil { + readTimeout := defaultReadTimeout + config.HTTP.ReadTimeout = &readTimeout + } + + if config.HTTP.WriteTimeout == nil { + writeTimeout := defaultWriteTimeout + config.HTTP.WriteTimeout = &writeTimeout + } + // if OpenID authentication is enabled, // API Keys are also enabled in order to provide data path authentication if config.HTTP.Auth != nil && config.HTTP.Auth.OpenID != nil { diff --git a/pkg/cli/server/root_test.go b/pkg/cli/server/root_test.go index 3037ff99..98e511f7 100644 --- a/pkg/cli/server/root_test.go +++ b/pkg/cli/server/root_test.go @@ -87,6 +87,47 @@ func TestServerUsage(t *testing.T) { }) } +func TestLoadConfigurationInjectsHTTPTimeoutDefaults(t *testing.T) { + Convey("load config sets HTTP read/write timeout defaults when not explicitly configured", t, func() { + content := `{ + "storage": {"rootDirectory": "/tmp/zot"}, + "http": {"address": "127.0.0.1", "port": "8080"} + }` + + tmpfile := MakeTempFileWithContent(t, "zot-http-timeouts-unset.json", content) + cfg := config.New() + + err := cli.LoadConfiguration(cfg, tmpfile) + So(err, ShouldBeNil) + So(cfg.HTTP.ReadTimeout, ShouldNotBeNil) + So(cfg.HTTP.WriteTimeout, ShouldNotBeNil) + So(cfg.GetHTTPReadTimeout(), ShouldEqual, 60*time.Second) + So(cfg.GetHTTPWriteTimeout(), ShouldEqual, 60*time.Second) + }) + + Convey("load config preserves explicit HTTP read/write timeout values", t, func() { + content := `{ + "storage": {"rootDirectory": "/tmp/zot"}, + "http": { + "address": "127.0.0.1", + "port": "8080", + "readTimeout": "45s", + "writeTimeout": "1m" + } + }` + + tmpfile := MakeTempFileWithContent(t, "zot-http-timeouts-explicit.json", content) + cfg := config.New() + + err := cli.LoadConfiguration(cfg, tmpfile) + So(err, ShouldBeNil) + So(cfg.HTTP.ReadTimeout, ShouldNotBeNil) + So(cfg.HTTP.WriteTimeout, ShouldNotBeNil) + So(cfg.GetHTTPReadTimeout(), ShouldEqual, 45*time.Second) + So(cfg.GetHTTPWriteTimeout(), ShouldEqual, time.Minute) + }) +} + func TestSchema(t *testing.T) { Convey("Test schema command", t, func(c C) { cmd := cli.NewServerRootCmd() diff --git a/pkg/exporter/api/config.go b/pkg/exporter/api/config.go index 8744ffaa..856965ea 100644 --- a/pkg/exporter/api/config.go +++ b/pkg/exporter/api/config.go @@ -2,6 +2,8 @@ package api +import "time" + // LogConfig and the other types below are exported so the cli package can read them from configuration file. type LogConfig struct { Level string @@ -23,9 +25,11 @@ type ServerConfig struct { } type ExporterConfig struct { - Port string - Log *LogConfig - Metrics *MetricsConfig + Port string + ReadTimeout *time.Duration `mapstructure:"readTimeout,omitempty"` + WriteTimeout *time.Duration `mapstructure:"writeTimeout,omitempty"` + Log *LogConfig + Metrics *MetricsConfig } type Config struct { @@ -34,8 +38,17 @@ type Config struct { } func DefaultConfig() *Config { + readTimeout := defaultTimeout + writeTimeout := defaultTimeout + return &Config{ - Server: ServerConfig{Protocol: "http", Host: "localhost", Port: "8080"}, - Exporter: ExporterConfig{Port: "8081", Log: &LogConfig{Level: "debug"}, Metrics: &MetricsConfig{Path: "/metrics"}}, + Server: ServerConfig{Protocol: "http", Host: "localhost", Port: "8080"}, + Exporter: ExporterConfig{ + Port: "8081", + ReadTimeout: &readTimeout, + WriteTimeout: &writeTimeout, + Log: &LogConfig{Level: "debug"}, + Metrics: &MetricsConfig{Path: "/metrics"}, + }, } } diff --git a/pkg/exporter/api/exporter.go b/pkg/exporter/api/exporter.go index a3a8d367..e3a7b129 100644 --- a/pkg/exporter/api/exporter.go +++ b/pkg/exporter/api/exporter.go @@ -21,6 +21,7 @@ import ( const ( idleTimeout = 120 * time.Second readHeaderTimeout = 5 * time.Second + defaultTimeout = 30 * time.Second ) type Collector struct { @@ -169,10 +170,23 @@ func GetCollector(c *Controller) *Collector { } } +func selectedTimeout(configured *time.Duration) time.Duration { + if configured != nil && *configured > 0 { + return *configured + } + + return defaultTimeout +} + func runExporter(c *Controller) { exporterAddr := ":" + c.Config.Exporter.Port + readTimeout := selectedTimeout(c.Config.Exporter.ReadTimeout) + writeTimeout := selectedTimeout(c.Config.Exporter.WriteTimeout) + server := &http.Server{ Addr: exporterAddr, + ReadTimeout: readTimeout, + WriteTimeout: writeTimeout, IdleTimeout: idleTimeout, ReadHeaderTimeout: readHeaderTimeout, } diff --git a/pkg/exporter/api/exporter_internal_test.go b/pkg/exporter/api/exporter_internal_test.go new file mode 100644 index 00000000..d3e0fed9 --- /dev/null +++ b/pkg/exporter/api/exporter_internal_test.go @@ -0,0 +1,23 @@ +//go:build !metrics + +package api + +import ( + "testing" + "time" + + . "github.com/smartystreets/goconvey/convey" +) + +func TestExporterTimeoutSelection(t *testing.T) { + Convey("exporter uses the provided default timeout when configured value is nil or non-positive", t, func() { + positive := 10 * time.Second + zero := time.Duration(0) + negative := -5 * time.Second + + So(selectedTimeout(nil), ShouldEqual, defaultTimeout) + So(selectedTimeout(&zero), ShouldEqual, defaultTimeout) + So(selectedTimeout(&negative), ShouldEqual, defaultTimeout) + So(selectedTimeout(&positive), ShouldEqual, positive) + }) +} diff --git a/pkg/extensions/extension_image_trust.go b/pkg/extensions/extension_image_trust.go index 8bf33945..2d248419 100644 --- a/pkg/extensions/extension_image_trust.go +++ b/pkg/extensions/extension_image_trust.go @@ -82,12 +82,18 @@ type ImageTrust struct { // @Param requestBody body string true "Public key content" // @Success 200 {string} string "ok" // @Failure 400 {string} string "bad request" +// @Failure 413 {string} string "request entity too large" // @Failure 500 {string} string "internal server error" func (trust *ImageTrust) HandleCosignPublicKeyUpload(response http.ResponseWriter, request *http.Request) { - body, err := io.ReadAll(request.Body) + body, err := io.ReadAll(http.MaxBytesReader(response, request.Body, constants.MaxImageTrustBodySize)) if err != nil { - trust.Log.Error().Err(err).Str("component", "image-trust").Msg("failed to read cosign key body") - response.WriteHeader(http.StatusInternalServerError) + var mbe *http.MaxBytesError + if errors.As(err, &mbe) { + response.WriteHeader(http.StatusRequestEntityTooLarge) + } else { + trust.Log.Error().Err(err).Str("component", "image-trust").Msg("failed to read cosign key body") + response.WriteHeader(http.StatusInternalServerError) + } return } @@ -117,6 +123,7 @@ func (trust *ImageTrust) HandleCosignPublicKeyUpload(response http.ResponseWrite // @Param requestBody body string true "Certificate content" // @Success 200 {string} string "ok" // @Failure 400 {string} string "bad request" +// @Failure 413 {string} string "request entity too large" // @Failure 500 {string} string "internal server error" func (trust *ImageTrust) HandleNotationCertificateUpload(response http.ResponseWriter, request *http.Request) { var truststoreType string @@ -127,10 +134,15 @@ func (trust *ImageTrust) HandleNotationCertificateUpload(response http.ResponseW truststoreType = "ca" // default value of "truststoreType" query param } - body, err := io.ReadAll(request.Body) + body, err := io.ReadAll(http.MaxBytesReader(response, request.Body, constants.MaxImageTrustBodySize)) if err != nil { - trust.Log.Error().Err(err).Str("component", "image-trust").Msg("failed to read notation certificate body") - response.WriteHeader(http.StatusInternalServerError) + var mbe *http.MaxBytesError + if errors.As(err, &mbe) { + response.WriteHeader(http.StatusRequestEntityTooLarge) + } else { + trust.Log.Error().Err(err).Str("component", "image-trust").Msg("failed to read notation certificate body") + response.WriteHeader(http.StatusInternalServerError) + } return } diff --git a/pkg/extensions/extension_image_trust_test.go b/pkg/extensions/extension_image_trust_test.go index e7506937..e6f01cf9 100644 --- a/pkg/extensions/extension_image_trust_test.go +++ b/pkg/extensions/extension_image_trust_test.go @@ -3,6 +3,7 @@ package extensions_test import ( + "bytes" "context" "encoding/json" "errors" @@ -56,7 +57,8 @@ func TestSignatureHandlers(t *testing.T) { } Convey("Test error handling when Cosign handler reads the request body", t, func() { - request, _ := http.NewRequestWithContext(context.TODO(), http.MethodPost, "baseURL", errReader(0)) + request, err := http.NewRequestWithContext(context.TODO(), http.MethodPost, "http://example.com", errReader(0)) + So(err, ShouldBeNil) response := httptest.NewRecorder() trust.HandleCosignPublicKeyUpload(response, request) @@ -67,7 +69,8 @@ func TestSignatureHandlers(t *testing.T) { }) Convey("Test error handling when Notation handler reads the request body", t, func() { - request, _ := http.NewRequestWithContext(context.TODO(), http.MethodPost, "baseURL", errReader(0)) + request, err := http.NewRequestWithContext(context.TODO(), http.MethodPost, "http://example.com", errReader(0)) + So(err, ShouldBeNil) query := request.URL.Query() request.URL.RawQuery = query.Encode() @@ -78,6 +81,34 @@ func TestSignatureHandlers(t *testing.T) { defer resp.Body.Close() So(resp.StatusCode, ShouldEqual, http.StatusInternalServerError) }) + + Convey("Test cosign upload body over max size returns 413", t, func() { + overSizedBody := make([]byte, constants.MaxImageTrustBodySize+1) + request, err := http.NewRequestWithContext(context.TODO(), http.MethodPost, + "http://example.com", bytes.NewReader(overSizedBody)) + So(err, ShouldBeNil) + response := httptest.NewRecorder() + + trust.HandleCosignPublicKeyUpload(response, request) + + resp := response.Result() + defer resp.Body.Close() + So(resp.StatusCode, ShouldEqual, http.StatusRequestEntityTooLarge) + }) + + Convey("Test notation upload body over max size returns 413", t, func() { + overSizedBody := make([]byte, constants.MaxImageTrustBodySize+1) + request, err := http.NewRequestWithContext(context.TODO(), http.MethodPost, + "http://example.com", bytes.NewReader(overSizedBody)) + So(err, ShouldBeNil) + response := httptest.NewRecorder() + + trust.HandleNotationCertificateUpload(response, request) + + resp := response.Result() + defer resp.Body.Close() + So(resp.StatusCode, ShouldEqual, http.StatusRequestEntityTooLarge) + }) } func TestSignaturesAllowedMethodsHeader(t *testing.T) { diff --git a/swagger/docs.go b/swagger/docs.go index 871d25bd..05b8c126 100644 --- a/swagger/docs.go +++ b/swagger/docs.go @@ -119,6 +119,12 @@ const docTemplate = `{ "type": "string" } }, + "413": { + "description": "request entity too large", + "schema": { + "type": "string" + } + }, "500": { "description": "internal server error", "schema": { @@ -205,6 +211,12 @@ const docTemplate = `{ "type": "string" } }, + "413": { + "description": "request entity too large", + "schema": { + "type": "string" + } + }, "500": { "description": "internal server error", "schema": { diff --git a/swagger/swagger.json b/swagger/swagger.json index 2fee657f..699782ef 100644 --- a/swagger/swagger.json +++ b/swagger/swagger.json @@ -111,6 +111,12 @@ "type": "string" } }, + "413": { + "description": "request entity too large", + "schema": { + "type": "string" + } + }, "500": { "description": "internal server error", "schema": { @@ -197,6 +203,12 @@ "type": "string" } }, + "413": { + "description": "request entity too large", + "schema": { + "type": "string" + } + }, "500": { "description": "internal server error", "schema": { diff --git a/swagger/swagger.yaml b/swagger/swagger.yaml index 09f78efb..cd236e3e 100644 --- a/swagger/swagger.yaml +++ b/swagger/swagger.yaml @@ -318,6 +318,10 @@ paths: description: bad request schema: type: string + "413": + description: request entity too large + schema: + type: string "500": description: internal server error schema: @@ -374,6 +378,10 @@ paths: description: bad request schema: type: string + "413": + description: request entity too large + schema: + type: string "500": description: internal server error schema: diff --git a/test/blackbox/events_http_lint_failure.bats b/test/blackbox/events_http_lint_failure.bats index 084fff28..5d13ad76 100644 --- a/test/blackbox/events_http_lint_failure.bats +++ b/test/blackbox/events_http_lint_failure.bats @@ -94,6 +94,28 @@ function teardown_file() { http_server_stop http_receiver_lint } +function wait_for_event_count() { + local output_path="$1" + local expected_count="$2" + local timeout_seconds="${3:-10}" + local elapsed=0 + local count=0 + + while [ "$elapsed" -lt "$timeout_seconds" ]; do + count=$(find "${output_path}" -type f | wc -l) + if [ "$count" -eq "$expected_count" ]; then + return 0 + fi + + sleep 1 + elapsed=$((elapsed + 1)) + done + + echo "timed out waiting for ${expected_count} events, found ${count}" >&3 + + return 1 +} + @test "http/publish image lint failure event" { http_server_port=$(cat ${BATS_FILE_TMPDIR}/http_server.port) zot_port=$(cat ${BATS_FILE_TMPDIR}/zot.port) @@ -117,6 +139,7 @@ function teardown_file() { rm -f artifact.txt config.json # Check the correct number of events were generated + wait_for_event_count "${output_path}" 2 count=$(find "${output_path}" -type f | wc -l) [ "$count" -eq 2 ] @@ -152,6 +175,7 @@ function teardown_file() { rm -f artifact.txt config.json # Check the correct number of events were generated + wait_for_event_count "${output_path}" 1 count=$(find "${output_path}" -type f | wc -l) [ "$count" -eq 1 ]