diff --git a/pkg/api/routes_test.go b/pkg/api/routes_test.go index 9ac6a765..0e490a02 100644 --- a/pkg/api/routes_test.go +++ b/pkg/api/routes_test.go @@ -2465,39 +2465,63 @@ func TestGetBlobMultipartPartHasDescriptorContentType(t *testing.T) { // - A reader-error mid-stream truncates the body (since the 206 // headers have already been flushed) and is logged. -// countingReader wraps a strings.Reader so a test can observe whether -// the wrapper has been closed yet. It tracks open/max-open counters -// shared with the test; the storage mock invokes its constructor on -// every GetBlobPartial call, so any concurrent opens immediately -// surface as a maxOpen > 1. -type countingReader struct { +// partialReaderOpenTracker records how many partial-blob readers are open at once and +// the peak concurrent count. The multipart test's GetBlobPartial mock calls NewReadCloser +// per range; overlapping opens show up as PeakOpens() > 1. +type partialReaderOpenTracker struct { + live atomic.Int32 + peak atomic.Int32 +} + +// NewReadCloser returns a reader that registers in the tracker until Close. +func (t *partialReaderOpenTracker) NewReadCloser(body string) io.ReadCloser { + t.beginOpen() + + return &partialReaderReadCloser{ + Reader: strings.NewReader(body), + tracker: t, + } +} + +func (t *partialReaderOpenTracker) LiveOpens() int32 { return t.live.Load() } + +func (t *partialReaderOpenTracker) PeakOpens() int32 { return t.peak.Load() } + +func (t *partialReaderOpenTracker) endClose() { t.live.Add(-1) } + +// beginOpen increments the live-open count and sets peak := max(peak, newLiveCount). +// +// The for loop retries when CompareAndSwap fails: another goroutine can change peak +// after Load but before CompareAndSwap, so one attempt is not enough under contention. +func (t *partialReaderOpenTracker) beginOpen() { + cur := t.live.Add(1) + + for { + observedPeak := t.peak.Load() + if cur <= observedPeak { + return + } + if t.peak.CompareAndSwap(observedPeak, cur) { + return + } + } +} + +// partialReaderReadCloser wraps a strings.Reader and only notifies the tracker on Close. +type partialReaderReadCloser struct { *strings.Reader - open *atomic.Int32 - maxOpen *atomic.Int32 + tracker *partialReaderOpenTracker closed bool } -func newCountingReader(body string, open, maxOpen *atomic.Int32) *countingReader { - cur := open.Add(1) - - for { - prev := maxOpen.Load() - if cur <= prev || maxOpen.CompareAndSwap(prev, cur) { - break - } - } - - return &countingReader{Reader: strings.NewReader(body), open: open, maxOpen: maxOpen} -} - -func (cr *countingReader) Close() error { - if cr.closed { +func (r *partialReaderReadCloser) Close() error { + if r.closed { return nil } - cr.closed = true - cr.open.Add(-1) + r.closed = true + r.tracker.endClose() return nil } @@ -2583,10 +2607,7 @@ func TestGetBlobMultipartContentLengthMatchesBody(t *testing.T) { func TestGetBlobMultipartOpensOneReaderAtATime(t *testing.T) { const blobBody = "0123456789abcdef0123456789abcdef" // 32 bytes - var ( - open atomic.Int32 - maxOpen atomic.Int32 - ) + var opens partialReaderOpenTracker store := descriptorStore(t) store.CheckBlobFn = func(repo string, digest godigest.Digest) (bool, int64, error) { @@ -2599,11 +2620,9 @@ func TestGetBlobMultipartOpensOneReaderAtATime(t *testing.T) { from, to int64, ) (io.ReadCloser, int64, int64, error) { - // Wrap a strings.Reader in a counter that increments on open - // and decrements on close. The producer goroutine in - // writeMultipartRanges should open and fully consume each - // reader before opening the next. - reader := newCountingReader(blobBody[from:to+1], &open, &maxOpen) + // opens tracks live readers; Close decrements. writeMultipartRanges should fully + // consume each reader before opening the next. + reader := opens.NewReadCloser(blobBody[from : to+1]) return reader, to - from + 1, int64(len(blobBody)), nil } @@ -2633,8 +2652,8 @@ func TestGetBlobMultipartOpensOneReaderAtATime(t *testing.T) { // the open counter on every reader. _ = drainResponseBody(t, resp) - assert.Equal(t, int32(0), open.Load(), "all readers must be closed by the time the body is drained") - assert.Equal(t, int32(1), maxOpen.Load(), + assert.Equal(t, int32(0), opens.LiveOpens(), "all readers must be closed by the time the body is drained") + assert.Equal(t, int32(1), opens.PeakOpens(), "writeMultipartRanges must open at most one range reader at a time") }