diff --git a/pkg/api/controller_test.go b/pkg/api/controller_test.go index 2fb0ee21..d5298795 100644 --- a/pkg/api/controller_test.go +++ b/pkg/api/controller_test.go @@ -12,6 +12,8 @@ import ( goerrors "errors" "fmt" "io" + "mime" + "mime/multipart" "net" "net/http" "net/http/httptest" @@ -11285,6 +11287,54 @@ func TestPullRange(t *testing.T) { So(resp.Body(), ShouldResemble, content[2:4]) }) + Convey("Get a suffix range of bytes", func() { + resp, err = resty.R().SetHeader("Range", "bytes=-3").Get(blobLoc) + So(err, ShouldBeNil) + So(resp.StatusCode(), ShouldEqual, http.StatusPartialContent) + So(resp.Header().Get("Content-Length"), ShouldEqual, "3") + So(resp.Header().Get("Content-Range"), ShouldEqual, "bytes 7-9/10") + So(resp.Body(), ShouldResemble, content[7:10]) + + resp, err = resty.R().SetHeader("Range", "bytes=-100").Get(blobLoc) + So(err, ShouldBeNil) + So(resp.StatusCode(), ShouldEqual, http.StatusPartialContent) + So(resp.Header().Get("Content-Length"), ShouldEqual, strconv.Itoa(len(content))) + So(resp.Header().Get("Content-Range"), ShouldEqual, "bytes 0-9/10") + So(resp.Body(), ShouldResemble, content) + }) + + Convey("Get multiple ranges of bytes", func() { + resp, err = resty.R().SetHeader("Range", "bytes=0-1,4-6").Get(blobLoc) + So(err, ShouldBeNil) + So(resp.StatusCode(), ShouldEqual, http.StatusPartialContent) + + contentType, params, err := mime.ParseMediaType(resp.Header().Get("Content-Type")) + So(err, ShouldBeNil) + So(contentType, ShouldEqual, "multipart/byteranges") + So(params["boundary"], ShouldNotBeEmpty) + + multipartReader := multipart.NewReader(bytes.NewReader(resp.Body()), params["boundary"]) + + part, err := multipartReader.NextPart() + So(err, ShouldBeNil) + So(part.Header.Get("Content-Range"), ShouldEqual, "bytes 0-1/10") + + partBody, err := io.ReadAll(part) + So(err, ShouldBeNil) + So(partBody, ShouldResemble, content[0:2]) + + part, err = multipartReader.NextPart() + So(err, ShouldBeNil) + So(part.Header.Get("Content-Range"), ShouldEqual, "bytes 4-6/10") + + partBody, err = io.ReadAll(part) + So(err, ShouldBeNil) + So(partBody, ShouldResemble, content[4:7]) + + _, err = multipartReader.NextPart() + So(err, ShouldEqual, io.EOF) + }) + Convey("Negative cases", func() { resp, err = resty.R().SetHeader("Range", "=0").Get(blobLoc) So(err, ShouldBeNil) @@ -11353,6 +11403,11 @@ func TestPullRange(t *testing.T) { resp, err = resty.R().SetHeader("Range", "bytes=a-b").Get(blobLoc) So(err, ShouldBeNil) So(resp.StatusCode(), ShouldEqual, http.StatusRequestedRangeNotSatisfiable) + + resp, err = resty.R().SetHeader("Range", "bytes=100-100").Get(blobLoc) + So(err, ShouldBeNil) + So(resp.StatusCode(), ShouldEqual, http.StatusRequestedRangeNotSatisfiable) + So(resp.Header().Get("Content-Range"), ShouldEqual, "bytes */10") }) }) } diff --git a/pkg/api/routes.go b/pkg/api/routes.go index 822aaa0f..eceb61ce 100644 --- a/pkg/api/routes.go +++ b/pkg/api/routes.go @@ -13,10 +13,11 @@ import ( "errors" "fmt" "io" + "mime/multipart" "net/http" + "net/textproto" "net/url" "path" - "regexp" "slices" "sort" "strconv" @@ -1121,51 +1122,164 @@ func (rh *RouteHandler) CheckBlob(response http.ResponseWriter, request *http.Re response.WriteHeader(http.StatusOK) } -/* parseRangeHeader validates the "Range" HTTP header and returns the range. */ -func parseRangeHeader(contentRange string) (int64, int64, error) { - /* bytes=- and bytes=- formats are supported */ - pattern := `bytes=(?P\d+)-(?P\d*$)` +type httpRange struct { + start int64 + end int64 +} - regex, err := regexp.Compile(pattern) - if err != nil { - return -1, -1, zerr.ErrParsingHTTPHeader +const maxRangeSpecCount = 16 + +func (r httpRange) length() int64 { + return r.end - r.start + 1 +} + +type blobRangeReader struct { + httpRange + + reader io.ReadCloser +} + +func closeRangeReaders(rangeReaders []blobRangeReader) { + for _, rangeReader := range rangeReaders { + _ = rangeReader.reader.Close() + } +} + +/* parseRangeHeader validates the "Range" HTTP header and returns normalized byte ranges. */ +func parseRangeHeader(contentRange string, size int64) ([]httpRange, error) { + if size <= 0 || !strings.HasPrefix(contentRange, "bytes=") { + return nil, zerr.ErrParsingHTTPHeader } - match := regex.FindStringSubmatch(contentRange) - - paramsMap := make(map[string]string) - - for i, name := range regex.SubexpNames() { - if i > 0 && i <= len(match) { - paramsMap[name] = match[i] - } + rangeSet := strings.TrimPrefix(contentRange, "bytes=") + if rangeSet == "" || strings.Count(rangeSet, ",")+1 > maxRangeSpecCount { + return nil, zerr.ErrParsingHTTPHeader } - var from int64 + rangeSpecs := strings.Split(rangeSet, ",") + ranges := make([]httpRange, 0, len(rangeSpecs)) - to := int64(-1) - - rangeFrom := paramsMap["rangeFrom"] - if rangeFrom == "" { - return -1, -1, zerr.ErrParsingHTTPHeader - } - - if from, err = strconv.ParseInt(rangeFrom, 10, 64); err != nil { - return -1, -1, zerr.ErrParsingHTTPHeader - } - - rangeTo := paramsMap["rangeTo"] - if rangeTo != "" { - if to, err = strconv.ParseInt(rangeTo, 10, 64); err != nil { - return -1, -1, zerr.ErrParsingHTTPHeader + for _, rangeSpec := range rangeSpecs { + rangeSpec = strings.TrimSpace(rangeSpec) + if rangeSpec == "" { + return nil, zerr.ErrParsingHTTPHeader } - if to < from { - return -1, -1, zerr.ErrParsingHTTPHeader + startStr, endStr, ok := strings.Cut(rangeSpec, "-") + if !ok { + return nil, zerr.ErrParsingHTTPHeader } + + var start, end int64 + + if startStr == "" { + suffixLen, err := strconv.ParseInt(endStr, 10, 64) + if err != nil || suffixLen <= 0 { + return nil, zerr.ErrParsingHTTPHeader + } + + if suffixLen > size { + start = 0 + } else { + start = size - suffixLen + } + + end = size - 1 + } else { + parsedStart, err := strconv.ParseInt(startStr, 10, 64) + if err != nil || parsedStart < 0 { + return nil, zerr.ErrParsingHTTPHeader + } + + start = parsedStart + + if endStr == "" { + end = size - 1 + } else { + parsedEnd, err := strconv.ParseInt(endStr, 10, 64) + if err != nil || parsedEnd < start { + return nil, zerr.ErrParsingHTTPHeader + } + + end = min(parsedEnd, size-1) + } + } + + if start >= size || start > end { + return nil, zerr.ErrParsingHTTPHeader + } + + ranges = append(ranges, httpRange{start: start, end: end}) } - return from, to, nil + if len(ranges) == 0 { + return nil, zerr.ErrParsingHTTPHeader + } + + return coalesceRanges(ranges), nil +} + +func coalesceRanges(ranges []httpRange) []httpRange { + sort.Slice(ranges, func(i, j int) bool { + if ranges[i].start == ranges[j].start { + return ranges[i].end < ranges[j].end + } + + return ranges[i].start < ranges[j].start + }) + + coalesced := ranges[:0] + + for _, httpRange := range ranges { + if len(coalesced) == 0 { + coalesced = append(coalesced, httpRange) + + continue + } + + lastRange := &coalesced[len(coalesced)-1] + if httpRange.start <= lastRange.end+1 { + lastRange.end = max(lastRange.end, httpRange.end) + + continue + } + + coalesced = append(coalesced, httpRange) + } + + return coalesced +} + +func writeMultipartRanges(response http.ResponseWriter, ranges []blobRangeReader, bsize int64, + logger log.Logger, +) { + writer := multipart.NewWriter(response) + defer func() { + if err := writer.Close(); err != nil { + logger.Error().Err(err).Msg("failed to close multipart range response") + } + }() + + response.Header().Set("Content-Type", "multipart/byteranges; boundary="+writer.Boundary()) + response.WriteHeader(http.StatusPartialContent) + + for _, rangeReader := range ranges { + partHeader := textproto.MIMEHeader{} + partHeader.Set("Content-Range", fmt.Sprintf("bytes %d-%d/%d", rangeReader.start, rangeReader.end, bsize)) + + part, err := writer.CreatePart(partHeader) + if err != nil { + logger.Error().Err(err).Msg("failed to create multipart range response") + + return + } + + if _, err := io.Copy(part, rangeReader.reader); err != nil { + logger.Error().Err(err).Msg("failed to copy range into multipart response") + + return + } + } } // GetBlob godoc @@ -1202,44 +1316,10 @@ func (rh *RouteHandler) GetBlob(response http.ResponseWriter, request *http.Requ mediaType := request.Header.Get("Accept") - /* content range is supported for resumbale pulls */ - partial := false - - var from, to int64 - - var err error - contentRange := request.Header.Get("Range") + _, rangeHeaderPresent := request.Header["Range"] - _, ok = request.Header["Range"] - if ok && contentRange == "" { - response.WriteHeader(http.StatusRequestedRangeNotSatisfiable) - - return - } - - if contentRange != "" { - from, to, err = parseRangeHeader(contentRange) - if err != nil { - response.WriteHeader(http.StatusRequestedRangeNotSatisfiable) - - return - } - - partial = true - } - - var repo io.ReadCloser - - var blen, bsize int64 - - if partial { - repo, blen, bsize, err = imgStore.GetBlobPartial(name, digest, mediaType, from, to) - } else { - repo, blen, err = imgStore.GetBlob(name, digest, mediaType) - } - - if err != nil { + writeBlobError := func(err error) { details := zerr.GetDetails(err) if errors.Is(err, zerr.ErrBadBlobDigest) { //nolint:gocritic // errorslint conflicts with gocritic:IfElseChain details["digest"] = digest.String() @@ -1257,6 +1337,81 @@ func (rh *RouteHandler) GetBlob(response http.ResponseWriter, request *http.Requ rh.c.Log.Error().Err(err).Msg("unexpected error") response.WriteHeader(http.StatusInternalServerError) } + } + + if rangeHeaderPresent { + ok, bsize, err := imgStore.CheckBlob(name, digest) + if err != nil { + writeBlobError(err) + + return + } + + if !ok { + e := apiErr.NewError(apiErr.BLOB_UNKNOWN).AddDetail(map[string]string{"digest": digest.String()}) + zcommon.WriteJSON(response, http.StatusNotFound, apiErr.NewErrorList(e)) + + return + } + + ranges, err := parseRangeHeader(contentRange, bsize) + if err != nil { + response.Header().Set("Content-Range", fmt.Sprintf("bytes */%d", bsize)) + response.WriteHeader(http.StatusRequestedRangeNotSatisfiable) + + return + } + + rangeReaders := make([]blobRangeReader, 0, len(ranges)) + defer func() { closeRangeReaders(rangeReaders) }() + + for _, httpRange := range ranges { + repo, blen, _, err := imgStore.GetBlobPartial(name, digest, mediaType, httpRange.start, httpRange.end) + if err != nil { + writeBlobError(err) + + return + } + + if blen != httpRange.length() { + _ = repo.Close() + + rh.c.Log.Error(). + Int64("expected", httpRange.length()). + Int64("actual", blen). + Msg("unexpected partial blob length") + response.WriteHeader(http.StatusInternalServerError) + + return + } + + rangeReaders = append(rangeReaders, blobRangeReader{httpRange: httpRange, reader: repo}) + } + + response.Header().Set(constants.DistContentDigestKey, digest.String()) + + if len(rangeReaders) > 1 { + writeMultipartRanges(response, rangeReaders, bsize, rh.c.Log) + + return + } + + rangeReader := rangeReaders[0] + response.Header().Set("Content-Range", fmt.Sprintf("bytes %d-%d/%d", rangeReader.start, rangeReader.end, bsize)) + WriteDataFromReader( + response, http.StatusPartialContent, rangeReader.length(), mediaType, rangeReader.reader, rh.c.Log, + ) + + return + } + + var repo io.ReadCloser + + var blen int64 + + repo, blen, err := imgStore.GetBlob(name, digest, mediaType) + if err != nil { + writeBlobError(err) return } @@ -1264,19 +1419,10 @@ func (rh *RouteHandler) GetBlob(response http.ResponseWriter, request *http.Requ defer repo.Close() response.Header().Set("Content-Length", strconv.FormatInt(blen, 10)) - - status := http.StatusOK - - if partial { - status = http.StatusPartialContent - - response.Header().Set("Content-Range", fmt.Sprintf("bytes %d-%d/%d", from, from+blen-1, bsize)) - } else { - response.Header().Set(constants.DistContentDigestKey, digest.String()) - } + response.Header().Set(constants.DistContentDigestKey, digest.String()) // return the blob data - WriteDataFromReader(response, status, blen, mediaType, repo, rh.c.Log) + WriteDataFromReader(response, http.StatusOK, blen, mediaType, repo, rh.c.Log) } // DeleteBlob godoc diff --git a/pkg/api/routes_internal_test.go b/pkg/api/routes_internal_test.go new file mode 100644 index 00000000..05400548 --- /dev/null +++ b/pkg/api/routes_internal_test.go @@ -0,0 +1,104 @@ +package api + +import ( + "reflect" + "strings" + "testing" +) + +func TestParseRangeHeader(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + header string + size int64 + want []httpRange + wantErr bool + }{ + { + name: "open ended range", + header: "bytes=0-", + size: 10, + want: []httpRange{{start: 0, end: 9}}, + }, + { + name: "range end is capped to size", + header: "bytes=0-100", + size: 10, + want: []httpRange{{start: 0, end: 9}}, + }, + { + name: "suffix range", + header: "bytes=-3", + size: 10, + want: []httpRange{{start: 7, end: 9}}, + }, + { + name: "oversized suffix range returns whole blob", + header: "bytes=-100", + size: 10, + want: []httpRange{{start: 0, end: 9}}, + }, + { + name: "ranges are sorted", + header: "bytes=7-8, 0-1", + size: 10, + want: []httpRange{ + {start: 0, end: 1}, + {start: 7, end: 8}, + }, + }, + { + name: "overlapping and adjacent ranges are coalesced", + header: "bytes=0-2,3-4,6-8,7-9", + size: 10, + want: []httpRange{ + {start: 0, end: 4}, + {start: 6, end: 9}, + }, + }, + {name: "zero size", header: "bytes=0-", wantErr: true}, + {name: "wrong unit", header: "byte=0-1", size: 10, wantErr: true}, + {name: "empty range set", header: "bytes=", size: 10, wantErr: true}, + {name: "empty range spec", header: "bytes=0-1,", size: 10, wantErr: true}, + {name: "zero suffix", header: "bytes=-0", size: 10, wantErr: true}, + {name: "bad suffix", header: "bytes=-x", size: 10, wantErr: true}, + {name: "bad start", header: "bytes=x-1", size: 10, wantErr: true}, + {name: "bad end", header: "bytes=1-x", size: 10, wantErr: true}, + {name: "inverted range", header: "bytes=2-1", size: 10, wantErr: true}, + {name: "range starts at size", header: "bytes=10-", size: 10, wantErr: true}, + {name: "range without dash", header: "bytes=0", size: 10, wantErr: true}, + { + name: "too many ranges", + header: "bytes=" + strings.TrimSuffix(strings.Repeat("0-0,", maxRangeSpecCount+1), ","), + size: 10, + wantErr: true, + }, + } + + for _, test := range tests { + test := test + + t.Run(test.name, func(t *testing.T) { + t.Parallel() + + got, err := parseRangeHeader(test.header, test.size) + if test.wantErr { + if err == nil { + t.Fatal("expected parse error") + } + + return + } + + if err != nil { + t.Fatalf("unexpected parse error: %v", err) + } + + if !reflect.DeepEqual(got, test.want) { + t.Fatalf("expected ranges %v, got %v", test.want, got) + } + }) + } +}