diff --git a/pkg/extensions/config/config_test.go b/pkg/extensions/config/config_test.go index 93dbd402..0688a892 100644 --- a/pkg/extensions/config/config_test.go +++ b/pkg/extensions/config/config_test.go @@ -67,6 +67,20 @@ func buildSyncConfig(enabled bool) *config.ExtensionConfig { return ext } +func buildSyncConfigWithStreaming(streamEnabled bool) *config.ExtensionConfig { + ext := &config.ExtensionConfig{} + ext.Sync = &sync.Config{ + Registries: []sync.RegistryConfig{ + { + URLs: []string{"localhost"}, + Stream: &streamEnabled, + }, + }, + } + + return ext +} + func buildScrubConfig(enabled bool) *config.ExtensionConfig { ext := &config.ExtensionConfig{} ext.Scrub = &config.ScrubConfig{ @@ -328,6 +342,60 @@ func TestExtensionConfig(t *testing.T) { testMethodWithEnabledEnable("Sync", (*config.ExtensionConfig).IsSyncEnabled, buildSyncConfig) }) + Convey("Test IsStreamingEnabled()", func() { + Convey("returns false with nil ExtensionConfig", func() { + var extensionConfig *config.ExtensionConfig + So(extensionConfig.IsStreamingEnabled(), ShouldBeFalse) + }) + + Convey("returns false with nil Sync", func() { + extensionConfig := &config.ExtensionConfig{} + So(extensionConfig.IsStreamingEnabled(), ShouldBeFalse) + }) + + Convey("returns false when Sync has no registries", func() { + extensionConfig := &config.ExtensionConfig{ + Sync: &sync.Config{}, + } + So(extensionConfig.IsStreamingEnabled(), ShouldBeFalse) + }) + + Convey("returns false when no registry has streaming enabled", func() { + So(buildSyncConfigWithStreaming(false).IsStreamingEnabled(), ShouldBeFalse) + }) + + Convey("returns true when a registry has streaming enabled", func() { + So(buildSyncConfigWithStreaming(true).IsStreamingEnabled(), ShouldBeTrue) + }) + + Convey("returns true when only one of multiple registries has streaming enabled", func() { + streamEnabled := true + streamDisabled := false + extensionConfig := &config.ExtensionConfig{ + Sync: &sync.Config{ + Registries: []sync.RegistryConfig{ + {URLs: []string{"localhost:5000"}, Stream: &streamDisabled}, + {URLs: []string{"localhost:5001"}, Stream: &streamEnabled}, + }, + }, + } + So(extensionConfig.IsStreamingEnabled(), ShouldBeTrue) + }) + + Convey("returns false when all registries have streaming disabled", func() { + streamDisabled := false + extensionConfig := &config.ExtensionConfig{ + Sync: &sync.Config{ + Registries: []sync.RegistryConfig{ + {URLs: []string{"localhost:5000"}, Stream: &streamDisabled}, + {URLs: []string{"localhost:5001"}, Stream: &streamDisabled}, + }, + }, + } + So(extensionConfig.IsStreamingEnabled(), ShouldBeFalse) + }) + }) + Convey("Test IsScrubEnabled()", func() { testMethodWithNilConfig((*config.ExtensionConfig).IsScrubEnabled) testMethodWithNilSubConfig("Scrub", (*config.ExtensionConfig).IsScrubEnabled) diff --git a/pkg/extensions/config/sync/config_test.go b/pkg/extensions/config/sync/config_test.go index 0d44875e..34c53994 100644 --- a/pkg/extensions/config/sync/config_test.go +++ b/pkg/extensions/config/sync/config_test.go @@ -8,6 +8,28 @@ import ( syncconf "zotregistry.dev/zot/v2/pkg/extensions/config/sync" ) +func TestRegistryConfig_IsStreamEnabled(t *testing.T) { + Convey("IsStreamEnabled", t, func() { + Convey("returns false when Stream is nil (default)", func() { + cfg := syncconf.RegistryConfig{} + So(cfg.Stream, ShouldBeNil) + So(cfg.IsStreamEnabled(), ShouldBeFalse) + }) + + Convey("returns true when Stream is true", func() { + v := true + cfg := syncconf.RegistryConfig{Stream: &v} + So(cfg.IsStreamEnabled(), ShouldBeTrue) + }) + + Convey("returns false when Stream is false", func() { + v := false + cfg := syncconf.RegistryConfig{Stream: &v} + So(cfg.IsStreamEnabled(), ShouldBeFalse) + }) + }) +} + func TestRegistryConfig_ShouldSyncLegacyCosignTags(t *testing.T) { Convey("ShouldSyncLegacyCosignTags", t, func() { Convey("returns true when SyncLegacyCosignTags is nil (default)", func() { diff --git a/pkg/extensions/extension_sync.go b/pkg/extensions/extension_sync.go index 289ca318..14b8e802 100644 --- a/pkg/extensions/extension_sync.go +++ b/pkg/extensions/extension_sync.go @@ -67,7 +67,7 @@ func EnableSyncExtension(config *config.Config, metaDB mTypes.MetaDB, // Only pass the stream manager to services that have streaming enabled on their registry config. var svcStreamManager sync.StreamManager - if registryConfig.Stream != nil && *registryConfig.Stream { + if registryConfig.IsStreamEnabled() { svcStreamManager = streamManager } diff --git a/pkg/extensions/extension_sync_test.go b/pkg/extensions/extension_sync_test.go new file mode 100644 index 00000000..686799b4 --- /dev/null +++ b/pkg/extensions/extension_sync_test.go @@ -0,0 +1,164 @@ +//go:build sync + +package extensions_test + +import ( + "testing" + + . "github.com/smartystreets/goconvey/convey" + + "zotregistry.dev/zot/v2/pkg/api/config" + "zotregistry.dev/zot/v2/pkg/extensions" + extconf "zotregistry.dev/zot/v2/pkg/extensions/config" + syncconf "zotregistry.dev/zot/v2/pkg/extensions/config/sync" + "zotregistry.dev/zot/v2/pkg/extensions/monitoring" + "zotregistry.dev/zot/v2/pkg/log" + "zotregistry.dev/zot/v2/pkg/scheduler" + "zotregistry.dev/zot/v2/pkg/storage" + "zotregistry.dev/zot/v2/pkg/test/mocks" +) + +func TestEnableSyncExtension_StreamManager(t *testing.T) { + Convey("EnableSyncExtension stream manager setup", t, func() { + logger := log.NewTestLogger() + cfg := config.New() + cfg.Storage.RootDirectory = t.TempDir() + + metaDB := mocks.MetaDBMock{} + storeController := storage.StoreController{} + metrics := monitoring.NewMetricsServer(false, logger) + sch := scheduler.NewScheduler(cfg, metrics, logger) + + Convey("stream manager is nil when Stream is not set on any registry", func() { + cfg.Extensions = &extconf.ExtensionConfig{ + Sync: &syncconf.Config{ + Registries: []syncconf.RegistryConfig{ + { + URLs: []string{"http://localhost:5000"}, + OnDemand: true, + }, + }, + }, + } + + onDemand, err := extensions.EnableSyncExtension(cfg, metaDB, storeController, sch, logger) + So(err, ShouldBeNil) + So(onDemand, ShouldNotBeNil) + So(onDemand.StreamManager(), ShouldBeNil) + }) + + Convey("stream manager is nil when streaming is explicitly disabled on all registries", func() { + streamDisabled := false + + cfg.Extensions = &extconf.ExtensionConfig{ + Sync: &syncconf.Config{ + Registries: []syncconf.RegistryConfig{ + { + URLs: []string{"http://localhost:5000"}, + OnDemand: true, + Stream: &streamDisabled, + }, + }, + }, + } + + onDemand, err := extensions.EnableSyncExtension(cfg, metaDB, storeController, sch, logger) + So(err, ShouldBeNil) + So(onDemand, ShouldNotBeNil) + So(onDemand.StreamManager(), ShouldBeNil) + }) + + Convey("stream manager is set when a registry has streaming enabled", func() { + streamEnabled := true + + cfg.Extensions = &extconf.ExtensionConfig{ + Sync: &syncconf.Config{ + Registries: []syncconf.RegistryConfig{ + { + URLs: []string{"http://localhost:5000"}, + OnDemand: true, + Stream: &streamEnabled, + }, + }, + }, + } + + onDemand, err := extensions.EnableSyncExtension(cfg, metaDB, storeController, sch, logger) + So(err, ShouldBeNil) + So(onDemand, ShouldNotBeNil) + So(onDemand.StreamManager(), ShouldNotBeNil) + }) + + Convey("stream manager is set when only one of multiple registries has streaming enabled", func() { + streamEnabled := true + streamDisabled := false + + cfg.Extensions = &extconf.ExtensionConfig{ + Sync: &syncconf.Config{ + Registries: []syncconf.RegistryConfig{ + { + URLs: []string{"http://localhost:5000"}, + OnDemand: true, + Stream: &streamDisabled, + }, + { + URLs: []string{"http://localhost:5001"}, + OnDemand: true, + Stream: &streamEnabled, + }, + }, + }, + } + + onDemand, err := extensions.EnableSyncExtension(cfg, metaDB, storeController, sch, logger) + So(err, ShouldBeNil) + So(onDemand, ShouldNotBeNil) + So(onDemand.StreamManager(), ShouldNotBeNil) + }) + + Convey("stream manager is set with mix of polling and on-demand with streaming enabled", func() { + streamEnabled := true + + cfg.Extensions = &extconf.ExtensionConfig{ + Sync: &syncconf.Config{ + Registries: []syncconf.RegistryConfig{ + { + URLs: []string{"http://localhost:5000"}, + PollInterval: 60, + }, + { + URLs: []string{"http://localhost:5001"}, + OnDemand: true, + Stream: &streamEnabled, + }, + }, + }, + } + + onDemand, err := extensions.EnableSyncExtension(cfg, metaDB, storeController, sch, logger) + So(err, ShouldBeNil) + So(onDemand, ShouldNotBeNil) + So(onDemand.StreamManager(), ShouldNotBeNil) + }) + + Convey("returns nil onDemand when sync is disabled", func() { + syncDisabled := false + + cfg.Extensions = &extconf.ExtensionConfig{ + Sync: &syncconf.Config{ + Enable: &syncDisabled, + Registries: []syncconf.RegistryConfig{ + { + URLs: []string{"http://localhost:5000"}, + OnDemand: true, + }, + }, + }, + } + + onDemand, err := extensions.EnableSyncExtension(cfg, metaDB, storeController, sch, logger) + So(err, ShouldBeNil) + So(onDemand, ShouldBeNil) + }) + }) +} diff --git a/pkg/extensions/sync/chunked_blob_reader.go b/pkg/extensions/sync/chunked_blob_reader.go index 991c9fa8..c40a88c0 100644 --- a/pkg/extensions/sync/chunked_blob_reader.go +++ b/pkg/extensions/sync/chunked_blob_reader.go @@ -103,6 +103,11 @@ func (cbr *ChunkedBlobReader) Read(buff []byte) (int, error) { if clsErr != nil { cbr.logger.Error().Err(clsErr).Msg("failed to close on disk file") } + // All bytes have been written to disk; treat as EOF regardless of + // what io.ReadFull returned. This handles the case where the caller's + // buffer is exactly the remaining data size and io.ReadFull returns + // (n, nil) instead of (n, io.ErrUnexpectedEOF). + err = io.EOF } numBytesRead := cbr.numBytesReadToDisk @@ -135,7 +140,7 @@ func (cbr *ChunkedBlobReader) Read(buff []byte) (int, error) { } // Subscribe to the reader each time a new client is interested in the current blob, -// the client would create a subscription here with a channel where latest chunk info is sent. +// the client would create a subscription here with a channel where latest bytes info is sent. func (cbr *ChunkedBlobReader) Subscribe() (chan int64, int) { cbr.clientMu.Lock() defer func() { @@ -151,11 +156,12 @@ func (cbr *ChunkedBlobReader) Subscribe() (chan int64, int) { cbr.bytesMu.RLock() defer cbr.bytesMu.RUnlock() - // Announce the current number of available chunks to the new client only if the reader is initialized + // Announce the current number of available bytes to the new client only if + // the reader is initialized. Send synchronously while clientMu is held so + // that Unsubscribe cannot close the channel between the map insertion above + // and this send. if cbr.InFlightReader != nil { - go func() { - channel <- cbr.numBytesReadToDisk - }() + channel <- cbr.numBytesReadToDisk } return channel, chanId diff --git a/pkg/extensions/sync/chunked_blob_reader_internal_test.go b/pkg/extensions/sync/chunked_blob_reader_internal_test.go new file mode 100644 index 00000000..aba487a1 --- /dev/null +++ b/pkg/extensions/sync/chunked_blob_reader_internal_test.go @@ -0,0 +1,400 @@ +//go:build sync + +package sync + +import ( + "bytes" + "io" + "os" + "path/filepath" + "sync" + "testing" + + godigest "github.com/opencontainers/go-digest" + "github.com/regclient/regclient/types/blob" + "github.com/regclient/regclient/types/descriptor" + . "github.com/smartystreets/goconvey/convey" + + zerr "zotregistry.dev/zot/v2/errors" + "zotregistry.dev/zot/v2/pkg/log" +) + +func newTestBReader(data []byte) *blob.BReader { + dig := godigest.FromBytes(data) + + return blob.NewReader( + blob.WithDesc(descriptor.Descriptor{ + Digest: dig, + Size: int64(len(data)), + MediaType: "application/octet-stream", + }), + blob.WithReader(bytes.NewReader(data)), + ) +} + +func TestNewChunkedBlobReader(t *testing.T) { + Convey("NewChunkedBlobReader", t, func() { + Convey("creates file and returns reader on valid path", func() { + dir := t.TempDir() + path := filepath.Join(dir, "blob.bin") + + cbr, err := NewChunkedBlobReader(path, log.NewTestLogger()) + So(err, ShouldBeNil) + So(cbr, ShouldNotBeNil) + So(cbr.onDiskPath, ShouldEqual, path) + So(cbr.onDiskFile, ShouldNotBeNil) + + // File should exist on disk + _, statErr := os.Stat(path) + So(statErr, ShouldBeNil) + + cbr.onDiskFile.Close() + }) + + Convey("returns error on invalid path", func() { + cbr, err := NewChunkedBlobReader("/nonexistent/dir/blob.bin", log.NewTestLogger()) + So(err, ShouldNotBeNil) + So(cbr, ShouldBeNil) + }) + }) +} + +func TestInitReader(t *testing.T) { + Convey("InitReader", t, func() { + dir := t.TempDir() + cbr, err := NewChunkedBlobReader(filepath.Join(dir, "blob.bin"), log.NewTestLogger()) + So(err, ShouldBeNil) + + data := []byte("hello world") + reader := newTestBReader(data) + + Convey("sets the in-flight reader and total bytes", func() { + So(cbr.InFlightReader, ShouldBeNil) + + cbr.InitReader(reader, int64(len(data))) + + So(cbr.InFlightReader, ShouldEqual, reader) + So(cbr.numBytesTotal, ShouldEqual, int64(len(data))) + }) + + Convey("is idempotent — second call does not overwrite first reader", func() { + cbr.InitReader(reader, int64(len(data))) + + secondReader := newTestBReader([]byte("other data")) + cbr.InitReader(secondReader, 99) + + So(cbr.InFlightReader, ShouldEqual, reader) + So(cbr.numBytesTotal, ShouldEqual, int64(len(data))) + }) + }) +} + +func TestRead(t *testing.T) { + Convey("Read", t, func() { + dir := t.TempDir() + blobPath := filepath.Join(dir, "blob.bin") + cbr, err := NewChunkedBlobReader(blobPath, log.NewTestLogger()) + So(err, ShouldBeNil) + + data := []byte("hello world") + cbr.InitReader(newTestBReader(data), int64(len(data))) + + Convey("reads all data and writes it to disk", func() { + buf := make([]byte, len(data)) + n, err := cbr.Read(buf) + // When the buffer is exactly the data size, all bytes are consumed in + // one call; Read detects numBytesReadToDisk == numBytesTotal and + // returns io.EOF to signal completion. + So(err, ShouldEqual, io.EOF) + So(n, ShouldEqual, len(data)) + So(buf[:n], ShouldResemble, data) + + // File should contain the data written so far + onDisk, readErr := os.ReadFile(blobPath) + So(readErr, ShouldBeNil) + So(onDisk, ShouldResemble, data) + }) + + Convey("partial read at end of stream preserves all bytes", func() { + // Read the first 5 bytes with an exact-fit buffer → (5, nil). + firstBuf := make([]byte, 5) + numBytesRead1, err1 := cbr.Read(firstBuf) + So(err1, ShouldBeNil) + So(numBytesRead1, ShouldEqual, 5) + + // Read the remaining 6 bytes with a buffer of 10: io.ReadFull can + // only fill 6 bytes before hitting EOF and returns (6, ErrUnexpectedEOF). + // Read normalises that to (6, io.EOF) at line 87. + secondBuf := make([]byte, 10) + numBytesRead2, err2 := cbr.Read(secondBuf) + So(err2, ShouldEqual, io.EOF) + So(numBytesRead2, ShouldEqual, 6) + + // Reconstruct what was read in memory and compare to source. + So(append(firstBuf[:numBytesRead1], secondBuf[:numBytesRead2]...), ShouldResemble, data) + + // On-disk file must contain every byte — none dropped. + onDisk, readErr := os.ReadFile(blobPath) + So(readErr, ShouldBeNil) + So(onDisk, ShouldResemble, data) + }) + + Convey("increments numBytesReadToDisk correctly", func() { + chunk := make([]byte, 5) + n, readErr := cbr.Read(chunk) + So(readErr, ShouldBeNil) + So(n, ShouldEqual, 5) + + cbr.bytesMu.RLock() + bytesRead := cbr.numBytesReadToDisk + cbr.bytesMu.RUnlock() + + So(bytesRead, ShouldEqual, 5) + }) + + Convey("notifies subscribed clients with latest byte offset", func() { + ch, id := cbr.Subscribe() + defer cbr.Unsubscribe(id) + + buf := make([]byte, len(data)) + done := make(chan struct{}) + + go func() { + _, _ = cbr.Read(buf) + close(done) + }() + + // Consume client channel; closed automatically on EOF. + var lastOffset int64 + for offset := range ch { + lastOffset = offset + } + + <-done + + So(lastOffset, ShouldEqual, int64(len(data))) + }) + + Convey("closes all clients when EOF is reached", func() { + bytesUpdateChan, _ := cbr.Subscribe() + + buf := make([]byte, len(data)) + + var wg sync.WaitGroup + wg.Go(func() { + _, _ = cbr.Read(buf) + }) + + // Drain the channel - it should be closed after the full read. + for range bytesUpdateChan { + } + + wg.Wait() + + cbr.clientMu.RLock() + numClients := len(cbr.clients) + cbr.clientMu.RUnlock() + + So(numClients, ShouldEqual, 0) + }) + + Convey("returns error and closes clients on upstream read error", func() { + errDir := t.TempDir() + errPath := filepath.Join(errDir, "blob.bin") + errCBR, cerr := NewChunkedBlobReader(errPath, log.NewTestLogger()) + So(cerr, ShouldBeNil) + + // Subscribe before InitReader: InFlightReader is nil so no initial + // value is placed in the channel. Subscribing after InitReader would + // buffer a 0 in the channel (the current byte offset), causing the + // first receive below to return (0, true) instead of (0, false). + bytesUpdateChan, _ := errCBR.Subscribe() + + errReader := blob.NewReader( + blob.WithDesc(descriptor.Descriptor{ + Digest: godigest.FromBytes([]byte("x")), + Size: 100, // larger than actual data to force a non-EOF error + MediaType: "application/octet-stream", + }), + blob.WithReader(errReaderFunc(func(p []byte) (int, error) { + return 0, zerr.ErrSyncUpstreamDownloadFailed + })), + ) + errCBR.InitReader(errReader, 100) + + buf := make([]byte, 50) + n, readErr := errCBR.Read(buf) + So(readErr, ShouldNotBeNil) + So(n, ShouldEqual, -1) + + // Channel should have been closed. + _, open := <-bytesUpdateChan + So(open, ShouldBeFalse) + }) + }) +} + +func TestSubscribeUnsubscribe(t *testing.T) { + Convey("Subscribe and Unsubscribe", t, func() { + dir := t.TempDir() + cbr, err := NewChunkedBlobReader(filepath.Join(dir, "blob.bin"), log.NewTestLogger()) + So(err, ShouldBeNil) + defer cbr.onDiskFile.Close() + + Convey("Subscribe returns a channel and a unique client ID", func() { + ch1, id1 := cbr.Subscribe() + ch2, id2 := cbr.Subscribe() + + So(ch1, ShouldNotBeNil) + So(ch2, ShouldNotBeNil) + So(id1, ShouldNotEqual, id2) + + cbr.Unsubscribe(id1) + cbr.Unsubscribe(id2) + }) + + Convey("Subscribe sends current byte offset when reader is already initialized", func() { + data := []byte("preloaded") + cbr.InitReader(newTestBReader(data), int64(len(data))) + + // Manually advance numBytesReadToDisk to simulate partial read. + cbr.bytesMu.Lock() + cbr.numBytesReadToDisk = 5 + cbr.bytesMu.Unlock() + + ch, id := cbr.Subscribe() + defer cbr.Unsubscribe(id) + + offset := <-ch + So(offset, ShouldEqual, int64(5)) + }) + + Convey("Subscribe does not send initial offset when reader is not yet initialized", func() { + ch, id := cbr.Subscribe() + defer cbr.Unsubscribe(id) + + // Channel should be empty since reader is not initialized. + So(len(ch), ShouldEqual, 0) + }) + + Convey("Unsubscribe closes the channel and removes the client", func() { + ch, clientId := cbr.Subscribe() + cbr.Unsubscribe(clientId) + + _, open := <-ch + So(open, ShouldBeFalse) + + cbr.clientMu.RLock() + _, exists := cbr.clients[clientId] + cbr.clientMu.RUnlock() + + So(exists, ShouldBeFalse) + }) + + Convey("Unsubscribe is a no-op for unknown client ID", func() { + So(func() { cbr.Unsubscribe(9999) }, ShouldNotPanic) + }) + }) +} + +func TestWaitForClientEmpty(t *testing.T) { + Convey("WaitForClientEmpty", t, func() { + dir := t.TempDir() + cbr, err := NewChunkedBlobReader(filepath.Join(dir, "blob.bin"), log.NewTestLogger()) + So(err, ShouldBeNil) + defer cbr.onDiskFile.Close() + + Convey("returns immediately when there are no clients", func() { + done := make(chan struct{}) + + go func() { + cbr.WaitForClientEmpty() + close(done) + }() + + <-done // should not block + }) + + Convey("blocks until all clients unsubscribe", func() { + _, id := cbr.Subscribe() + + done := make(chan struct{}) + + go func() { + cbr.WaitForClientEmpty() + close(done) + }() + + // Verify it's blocking. + select { + case <-done: + So("WaitForClientEmpty returned before client unsubscribed", ShouldBeEmpty) + default: + // expected: still waiting + } + + cbr.Unsubscribe(id) + <-done + }) + + Convey("blocks while multiple clients are subscribed and wakes on each unsubscribe", func() { + _, id1 := cbr.Subscribe() + _, id2 := cbr.Subscribe() + _, id3 := cbr.Subscribe() + + done := make(chan struct{}) + + go func() { + cbr.WaitForClientEmpty() + close(done) + }() + + // Still blocking with three clients present. + select { + case <-done: + So("WaitForClientEmpty returned before all clients unsubscribed", ShouldBeEmpty) + default: + } + + // Unsubscribe one at a time. WaitForClientEmpty must not return + // until the last client is gone. + cbr.Unsubscribe(id1) + cbr.Unsubscribe(id2) + + select { + case <-done: + So("WaitForClientEmpty returned with one client still subscribed", ShouldBeEmpty) + default: + } + + cbr.Unsubscribe(id3) + <-done + }) + }) +} + +func TestToBReader(t *testing.T) { + Convey("ToBReader", t, func() { + dir := t.TempDir() + cbr, err := NewChunkedBlobReader(filepath.Join(dir, "blob.bin"), log.NewTestLogger()) + So(err, ShouldBeNil) + defer cbr.onDiskFile.Close() + + data := []byte("to-breader test data") + original := newTestBReader(data) + cbr.InitReader(original, int64(len(data))) + + br := cbr.ToBReader() + So(br, ShouldNotBeNil) + + // The returned BReader should have the same descriptor as the original. + So(br.GetDescriptor().Digest, ShouldEqual, original.GetDescriptor().Digest) + So(br.GetDescriptor().Size, ShouldEqual, original.GetDescriptor().Size) + }) +} + +type errReaderFunc func(p []byte) (int, error) + +func (f errReaderFunc) Read(p []byte) (int, error) { + return f(p) +} diff --git a/pkg/extensions/sync/inflight_blob_copier_internal_test.go b/pkg/extensions/sync/inflight_blob_copier_internal_test.go new file mode 100644 index 00000000..9c9e605c --- /dev/null +++ b/pkg/extensions/sync/inflight_blob_copier_internal_test.go @@ -0,0 +1,148 @@ +//go:build sync + +package sync + +import ( + "bytes" + "io" + "path/filepath" + "testing" + + godigest "github.com/opencontainers/go-digest" + "github.com/regclient/regclient/types/blob" + "github.com/regclient/regclient/types/descriptor" + . "github.com/smartystreets/goconvey/convey" + + zerr "zotregistry.dev/zot/v2/errors" + "zotregistry.dev/zot/v2/pkg/log" +) + +func TestInFlightBlobCopierCopy(t *testing.T) { + Convey("InFlightBlobCopier.Copy", t, func() { + Convey("copies entire blob to destination", func() { + dir := t.TempDir() + blobPath := filepath.Join(dir, "blob.bin") + data := []byte("hello inflight world") + + cbr, err := NewChunkedBlobReader(blobPath, log.NewTestLogger()) + So(err, ShouldBeNil) + cbr.InitReader(newTestBReader(data), int64(len(data))) + + var dest bytes.Buffer + ifbc := NewInFlightBlobCopier(cbr, blobPath, &dest, log.NewTestLogger()) + + // Run the read concurrently. Copy() blocks until it receives the + // final byte-offset notification or sees the file data via a late subscribe. + done := make(chan struct{}) + go func() { + buf := make([]byte, len(data)) + _, _ = cbr.Read(buf) + close(done) + }() + + copyErr := ifbc.Copy() + So(copyErr, ShouldBeNil) + So(dest.Bytes(), ShouldResemble, data) + <-done + }) + + Convey("copies blob delivered in multiple chunks", func() { + dir := t.TempDir() + blobPath := filepath.Join(dir, "blob.bin") + data := []byte("hello inflight world") + const firstChunk = 8 + + cbr, err := NewChunkedBlobReader(blobPath, log.NewTestLogger()) + So(err, ShouldBeNil) + cbr.InitReader(newTestBReader(data), int64(len(data))) + + var dest bytes.Buffer + ifbc := NewInFlightBlobCopier(cbr, blobPath, &dest, log.NewTestLogger()) + + copyResult := make(chan error, 1) + go func() { + copyResult <- ifbc.Copy() + }() + + // Wait until Copy() has subscribed so it sees each chunk notification + // individually rather than only the final byte count. + cbr.clientMu.Lock() + for len(cbr.clients) == 0 { + cbr.clientCond.Wait() + } + cbr.clientMu.Unlock() + + // First chunk: exactly firstChunk bytes — returns (firstChunk, nil). + buf1 := make([]byte, firstChunk) + n1, readErr1 := cbr.Read(buf1) + So(readErr1, ShouldBeNil) + So(n1, ShouldEqual, firstChunk) + + // Second chunk: remainder — exact-size buffer triggers the + // numBytesReadToDisk >= numBytesTotal check which returns io.EOF. + buf2 := make([]byte, len(data)-firstChunk) + n2, readErr2 := cbr.Read(buf2) + So(readErr2, ShouldEqual, io.EOF) + So(n2, ShouldEqual, len(data)-firstChunk) + + So(<-copyResult, ShouldBeNil) + So(dest.Bytes(), ShouldResemble, data) + }) + + Convey("returns error when on-disk file cannot be opened", func() { + dir := t.TempDir() + cbr, err := NewChunkedBlobReader(filepath.Join(dir, "blob.bin"), log.NewTestLogger()) + So(err, ShouldBeNil) + defer cbr.onDiskFile.Close() + + var dest bytes.Buffer + ifbc := NewInFlightBlobCopier(cbr, "/nonexistent/path/blob.bin", &dest, log.NewTestLogger()) + + copyErr := ifbc.Copy() + So(copyErr, ShouldNotBeNil) + }) + + Convey("returns ErrSyncUpstreamDownloadFailed when upstream download fails", func() { + errDir := t.TempDir() + errPath := filepath.Join(errDir, "blob.bin") + errCBR, cerr := NewChunkedBlobReader(errPath, log.NewTestLogger()) + So(cerr, ShouldBeNil) + + errCBR.InitReader(blob.NewReader( + blob.WithDesc(descriptor.Descriptor{ + Digest: godigest.FromBytes([]byte("x")), + Size: 100, + MediaType: "application/octet-stream", + }), + blob.WithReader(errReaderFunc(func(p []byte) (int, error) { + return 0, zerr.ErrSyncUpstreamDownloadFailed + })), + ), 100) + + var dest bytes.Buffer + ifbc := NewInFlightBlobCopier(errCBR, errPath, &dest, log.NewTestLogger()) + + copyResult := make(chan error, 1) + go func() { + copyResult <- ifbc.Copy() + }() + + // Wait until Copy() has subscribed so that the Read() error below is + // guaranteed to close Copy's channel. + // Whether Copy() has already consumed the initial 0 from Subscribe or + // it is still buffered, the channel close returns (0, false) which + // causes Copy() to return ErrSyncUpstreamDownloadFailed. + errCBR.clientMu.Lock() + for len(errCBR.clients) == 0 { + errCBR.clientCond.Wait() + } + errCBR.clientMu.Unlock() + + // Trigger the upstream error; Read() closes all subscriber channels. + buf := make([]byte, 50) + _, _ = errCBR.Read(buf) + + So(<-copyResult, ShouldEqual, zerr.ErrSyncUpstreamDownloadFailed) + }) + }) +}