test: refactor countingReader into partialReaderOpenTracker and partialReaderReadCloser (#4027)

countingReader was not respecting the single responsibility principle
and the implementation was hard to understand

Signed-off-by: Andrei Aaron <andreifdaaron@gmail.com>
This commit is contained in:
Andrei Aaron
2026-05-01 13:51:24 +03:00
committed by GitHub
parent cb9d682a69
commit 8f27949dcb
+55 -36
View File
@@ -2465,39 +2465,63 @@ func TestGetBlobMultipartPartHasDescriptorContentType(t *testing.T) {
// - A reader-error mid-stream truncates the body (since the 206
// headers have already been flushed) and is logged.
// countingReader wraps a strings.Reader so a test can observe whether
// the wrapper has been closed yet. It tracks open/max-open counters
// shared with the test; the storage mock invokes its constructor on
// every GetBlobPartial call, so any concurrent opens immediately
// surface as a maxOpen > 1.
type countingReader struct {
// partialReaderOpenTracker records how many partial-blob readers are open at once and
// the peak concurrent count. The multipart test's GetBlobPartial mock calls NewReadCloser
// per range; overlapping opens show up as PeakOpens() > 1.
type partialReaderOpenTracker struct {
live atomic.Int32
peak atomic.Int32
}
// NewReadCloser returns a reader that registers in the tracker until Close.
func (t *partialReaderOpenTracker) NewReadCloser(body string) io.ReadCloser {
t.beginOpen()
return &partialReaderReadCloser{
Reader: strings.NewReader(body),
tracker: t,
}
}
func (t *partialReaderOpenTracker) LiveOpens() int32 { return t.live.Load() }
func (t *partialReaderOpenTracker) PeakOpens() int32 { return t.peak.Load() }
func (t *partialReaderOpenTracker) endClose() { t.live.Add(-1) }
// beginOpen increments the live-open count and sets peak := max(peak, newLiveCount).
//
// The for loop retries when CompareAndSwap fails: another goroutine can change peak
// after Load but before CompareAndSwap, so one attempt is not enough under contention.
func (t *partialReaderOpenTracker) beginOpen() {
cur := t.live.Add(1)
for {
observedPeak := t.peak.Load()
if cur <= observedPeak {
return
}
if t.peak.CompareAndSwap(observedPeak, cur) {
return
}
}
}
// partialReaderReadCloser wraps a strings.Reader and only notifies the tracker on Close.
type partialReaderReadCloser struct {
*strings.Reader
open *atomic.Int32
maxOpen *atomic.Int32
tracker *partialReaderOpenTracker
closed bool
}
func newCountingReader(body string, open, maxOpen *atomic.Int32) *countingReader {
cur := open.Add(1)
for {
prev := maxOpen.Load()
if cur <= prev || maxOpen.CompareAndSwap(prev, cur) {
break
}
}
return &countingReader{Reader: strings.NewReader(body), open: open, maxOpen: maxOpen}
}
func (cr *countingReader) Close() error {
if cr.closed {
func (r *partialReaderReadCloser) Close() error {
if r.closed {
return nil
}
cr.closed = true
cr.open.Add(-1)
r.closed = true
r.tracker.endClose()
return nil
}
@@ -2583,10 +2607,7 @@ func TestGetBlobMultipartContentLengthMatchesBody(t *testing.T) {
func TestGetBlobMultipartOpensOneReaderAtATime(t *testing.T) {
const blobBody = "0123456789abcdef0123456789abcdef" // 32 bytes
var (
open atomic.Int32
maxOpen atomic.Int32
)
var opens partialReaderOpenTracker
store := descriptorStore(t)
store.CheckBlobFn = func(repo string, digest godigest.Digest) (bool, int64, error) {
@@ -2599,11 +2620,9 @@ func TestGetBlobMultipartOpensOneReaderAtATime(t *testing.T) {
from,
to int64,
) (io.ReadCloser, int64, int64, error) {
// Wrap a strings.Reader in a counter that increments on open
// and decrements on close. The producer goroutine in
// writeMultipartRanges should open and fully consume each
// reader before opening the next.
reader := newCountingReader(blobBody[from:to+1], &open, &maxOpen)
// opens tracks live readers; Close decrements. writeMultipartRanges should fully
// consume each reader before opening the next.
reader := opens.NewReadCloser(blobBody[from : to+1])
return reader, to - from + 1, int64(len(blobBody)), nil
}
@@ -2633,8 +2652,8 @@ func TestGetBlobMultipartOpensOneReaderAtATime(t *testing.T) {
// the open counter on every reader.
_ = drainResponseBody(t, resp)
assert.Equal(t, int32(0), open.Load(), "all readers must be closed by the time the body is drained")
assert.Equal(t, int32(1), maxOpen.Load(),
assert.Equal(t, int32(0), opens.LiveOpens(), "all readers must be closed by the time the body is drained")
assert.Equal(t, int32(1), opens.PeakOpens(),
"writeMultipartRanges must open at most one range reader at a time")
}