Files
zot/pkg/test/common/fs.go
T
Andrei Aaron c6289ec5ba fix: address code review comments (#3942)
* fix: address code review comments in https://github.com/project-zot/zot/pull/3885#pullrequestreview-4045836197

Signed-off-by: Andrei Aaron <andreifdaaron@gmail.com>

* fix: data race in GetPort()

See https://github.com/project-zot/zot/actions/runs/24045271222/job/70126983674?pr=3942

Signed-off-by: Andrei Aaron <andreifdaaron@gmail.com>

* fix(test): reuse ReadLogFileAndSearchString for auto-port log; throttle poll loop

Signed-off-by: Andrei Aaron <andreifdaaron@gmail.com>

---------

Signed-off-by: Andrei Aaron <andreifdaaron@gmail.com>
2026-04-08 00:10:54 +03:00

378 lines
7.7 KiB
Go

package common
import (
"context"
"crypto/rand"
"encoding/base64"
"errors"
"fmt"
"io"
"io/fs"
"os"
"path"
"path/filepath"
"strings"
"testing"
"time"
"github.com/GehirnInc/crypt"
_ "github.com/GehirnInc/crypt/sha256_crypt"
_ "github.com/GehirnInc/crypt/sha512_crypt"
"golang.org/x/crypto/bcrypt"
)
// ErrNoGoModFileFound is returned by GetProjectRootDir when neither the
// working directory nor any of its ancestors contains a go.mod file.
var ErrNoGoModFileFound = errors.New("no go.mod file found in parent directories")

// GetProjectRootDir returns the nearest directory, starting at the current
// working directory and walking upwards, that contains a go.mod file.
//
// It returns the error from os.Getwd if the working directory cannot be
// determined, and ErrNoGoModFileFound once the filesystem root has been
// reached without finding a go.mod.
func GetProjectRootDir() (string, error) {
	dir, err := os.Getwd()
	if err != nil {
		return "", err
	}

	for {
		if _, statErr := os.Stat(filepath.Join(dir, "go.mod")); statErr == nil {
			return dir, nil
		}

		// filepath.Dir is a fixed point at the filesystem root, which is how
		// the end of the upward walk is detected.
		parent := filepath.Dir(dir)
		if parent == dir {
			return "", ErrNoGoModFileFound
		}

		dir = parent
	}
}
// CopyFile copies the contents of sourceFilePath into destFilePath. The
// destination is created if it does not exist and truncated if it does.
//
// The source is opened before the destination is touched, so a missing or
// unreadable source never leaves behind an empty destination file or wipes
// an existing one. The error from closing the destination is returned as
// well, because a failed close can mean the data was not fully written.
func CopyFile(sourceFilePath, destFilePath string) error {
	sourceFile, err := os.Open(sourceFilePath)
	if err != nil {
		return err
	}
	defer sourceFile.Close()

	destFile, err := os.Create(destFilePath)
	if err != nil {
		return err
	}

	if _, err = io.Copy(destFile, sourceFile); err != nil {
		// The copy error is the one worth reporting; close is best effort.
		destFile.Close()

		return err
	}

	return destFile.Close()
}
// CopyFiles recursively copies the contents of sourceDir into destDir.
//
// destDir (and any missing parents) is created with the mode of sourceDir.
// Sub-directories whose name starts with "_" are skipped; regular entries
// with such a name are still copied. Every error is wrapped with the name of
// the failing operation.
func CopyFiles(sourceDir, destDir string) error {
	sourceMeta, err := os.Stat(sourceDir)
	if err != nil {
		return fmt.Errorf("CopyFiles os.Stat failed: %w", err)
	}

	if err := os.MkdirAll(destDir, sourceMeta.Mode()); err != nil {
		return fmt.Errorf("CopyFiles os.MkdirAll failed: %w", err)
	}

	files, err := os.ReadDir(sourceDir)
	if err != nil {
		return fmt.Errorf("CopyFiles os.ReadDir failed: %w", err)
	}

	for _, file := range files {
		sourceFilePath := path.Join(sourceDir, file.Name())
		destFilePath := path.Join(destDir, file.Name())

		if !file.IsDir() {
			if err := copyFilesEntry(sourceFilePath, destFilePath); err != nil {
				return err
			}

			continue
		}

		if strings.HasPrefix(file.Name(), "_") {
			// Some tests create the trivy related folders under test/_trivy
			continue
		}

		if err := CopyFiles(sourceFilePath, destFilePath); err != nil {
			return err
		}
	}

	return nil
}

// copyFilesEntry copies a single non-directory entry for CopyFiles. It lives
// in its own function so both file handles are released as soon as the entry
// is done, instead of piling up in deferred calls until the whole directory
// (and, through recursion, the whole tree) has been processed.
func copyFilesEntry(sourceFilePath, destFilePath string) error {
	sourceFile, err := os.Open(sourceFilePath)
	if err != nil {
		return fmt.Errorf("CopyFiles os.Open failed: %w", err)
	}
	defer sourceFile.Close()

	destFile, err := os.Create(destFilePath)
	if err != nil {
		return fmt.Errorf("CopyFiles os.Create failed: %w", err)
	}

	if _, err = io.Copy(destFile, sourceFile); err != nil {
		// The copy error is the one worth reporting; close is best effort.
		destFile.Close()

		return fmt.Errorf("io.Copy failed: %w", err)
	}

	if err := destFile.Close(); err != nil {
		return fmt.Errorf("CopyFiles destination close failed: %w", err)
	}

	return nil
}
// CopyTestKeysAndCerts copies the fixed set of test CA, client and server
// key/certificate files from <project root>/test/data into destDir.
//
// destDir is created (with the mode of the source directory) if needed. The
// project root is located with GetProjectRootDir, and each file is copied
// with CopyFile; the first error encountered is returned.
func CopyTestKeysAndCerts(destDir string) error {
	rootPath, err := GetProjectRootDir()
	if err != nil {
		return err
	}

	sourceDir := filepath.Join(rootPath, "test/data")

	sourceMeta, err := os.Stat(sourceDir)
	if err != nil {
		// Name this function in the message; it previously claimed to come
		// from CopyFiles, which is not involved here.
		return fmt.Errorf("CopyTestKeysAndCerts os.Stat failed: %w", err)
	}

	if err := os.MkdirAll(destDir, sourceMeta.Mode()); err != nil {
		return err
	}

	files := []string{
		"ca.crt", "ca.key", "client.cert", "client.csr",
		"client.key", "server.cert", "server.csr", "server.key",
	}

	for _, file := range files {
		if err := CopyFile(filepath.Join(sourceDir, file), filepath.Join(destDir, file)); err != nil {
			return err
		}
	}

	return nil
}
// WriteFileWithPermission writes data to the file at path, creating it with
// permission perm and creating any missing parent directories first.
//
// With overwrite set, existing content is truncated and replaced. Without
// it, the file is opened with O_EXCL, so the call fails if the file already
// exists. Errors from creating directories, opening, writing and closing are
// all returned to the caller.
func WriteFileWithPermission(path string, data []byte, perm fs.FileMode, overwrite bool) error {
	if err := os.MkdirAll(filepath.Dir(path), os.ModePerm); err != nil {
		return err
	}

	openFlags := os.O_WRONLY | os.O_CREATE | os.O_EXCL
	if overwrite {
		openFlags = os.O_WRONLY | os.O_CREATE | os.O_TRUNC
	}

	out, err := os.OpenFile(path, openFlags, perm)
	if err != nil {
		return err
	}

	if _, writeErr := out.Write(data); writeErr != nil {
		// The write error takes precedence over whatever Close reports.
		out.Close()

		return writeErr
	}

	return out.Close()
}
// ReadLogFileAndSearchString re-reads the file at logPath every 10ms until
// its content contains stringToMatch or timeout has elapsed.
//
// It returns (true, nil) as soon as the string is found, (false, nil) when
// the timeout expires first, and (false, err) if the file cannot be read.
// The timeout is checked before each read, so a timeout that has already
// expired yields (false, nil) without touching the file.
func ReadLogFileAndSearchString(logPath string, stringToMatch string, timeout time.Duration) (bool, error) {
	ctx, cancel := context.WithTimeout(context.Background(), timeout)
	defer cancel()

	// The ticker throttles the loop so it does not spin on the file.
	poll := time.NewTicker(10 * time.Millisecond)
	defer poll.Stop()

	for ctx.Err() == nil {
		data, err := os.ReadFile(logPath)
		if err != nil {
			return false, err
		}

		if strings.Contains(string(data), stringToMatch) {
			return true, nil
		}

		select {
		case <-poll.C:
		case <-ctx.Done():
			return false, nil
		}
	}

	return false, nil
}
// ReadLogFileAndCountStringOccurence re-reads the file at logPath every 10ms
// until stringToMatch occurs at least count times in it (non-overlapping
// occurrences, as counted by strings.Count) or timeout has elapsed.
//
// It returns (true, nil) once the threshold is reached, (false, nil) when
// the timeout expires first, and (false, err) if the file cannot be read.
// The timeout is checked before each read, so a timeout that has already
// expired yields (false, nil) without touching the file.
func ReadLogFileAndCountStringOccurence(logPath string, stringToMatch string,
	timeout time.Duration, count int,
) (bool, error) {
	ctx, cancel := context.WithTimeout(context.Background(), timeout)
	defer cancel()

	// The ticker throttles the loop so it does not spin on the file.
	poll := time.NewTicker(10 * time.Millisecond)
	defer poll.Stop()

	for ctx.Err() == nil {
		data, err := os.ReadFile(logPath)
		if err != nil {
			return false, err
		}

		if occurrences := strings.Count(string(data), stringToMatch); occurrences >= count {
			return true, nil
		}

		select {
		case <-poll.C:
		case <-ctx.Done():
			return false, nil
		}
	}

	return false, nil
}
// GetBcryptCredString returns a single "username:hash\n" line in which hash
// is the bcrypt hash of password, generated with cost 10.
//
// It panics if the hash cannot be generated.
func GetBcryptCredString(username, password string) string {
	const cost = 10

	hashed, err := bcrypt.GenerateFromPassword([]byte(password), cost)
	if err != nil {
		panic(err)
	}

	return fmt.Sprintf("%s:%s\n", username, string(hashed))
}
// Building blocks of the salt specification that shaCrypt assembles as
// prefix + [rounds + Separator] + salt.
const (
// PrefixCryptSha256 is the prefix that makes shaCrypt use the SHA-256 crypter.
PrefixCryptSha256 = "$5$"
// PrefixCryptSha512 is the prefix that makes shaCrypt use the SHA-512 crypter.
PrefixCryptSha512 = "$6$"
// Separator is written after the rounds parameter, when one is given,
// and before the salt.
Separator = "$"
)
// generateSecureRandomString returns a string of exactly length characters
// taken from the URL-safe base64 encoding of length bytes read from
// crypto/rand. Because the encoding is truncated, the result carries less
// entropy than the bytes that were read.
//
// It should only be used in tests with length = 16.
func generateSecureRandomString(length int) (string, error) {
	raw := make([]byte, length)
	if _, err := rand.Read(raw); err != nil {
		return "", err
	}

	// The encoded form is always at least length characters long, so the
	// truncation below cannot go out of range.
	encoded := make([]byte, base64.URLEncoding.EncodedLen(length))
	base64.URLEncoding.Encode(encoded, raw)

	return string(encoded[:length]), nil
}
func shaCrypt(password string, rounds string, salt string, prefix string) string {
var ret string
var strb strings.Builder
strb.WriteString(prefix)
if len(rounds) > 0 {
strb.WriteString(rounds)
strb.WriteString(Separator)
}
strb.WriteString(salt)
totalSalt := strb.String()
var err error
switch prefix {
case PrefixCryptSha256:
crypter := crypt.SHA256.New()
ret, err = crypter.Generate([]byte(password), []byte(totalSalt))
case PrefixCryptSha512:
crypter := crypt.SHA512.New()
ret, err = crypter.Generate([]byte(password), []byte(totalSalt))
default:
panic("unsupported password hash")
}
if err != nil {
panic(err)
}
return ret
}
// GetSHA256CredString returns a single "username:hash\n" line in which hash
// is produced by shaCrypt with the SHA-256 prefix, "rounds=5000" and a fresh
// 16-character random salt.
//
// It panics if the salt cannot be generated.
func GetSHA256CredString(username, password string) string {
	salt, err := generateSecureRandomString(16)
	if err != nil {
		panic(err)
	}

	return fmt.Sprintf("%s:%s\n", username, shaCrypt(password, "rounds=5000", salt, PrefixCryptSha256))
}
// GetSHA512CredString returns a single "username:hash\n" line in which hash
// is produced by shaCrypt with the SHA-512 prefix, "rounds=5000" and a fresh
// 16-character random salt.
//
// It panics if the salt cannot be generated.
func GetSHA512CredString(username, password string) string {
	salt, err := generateSecureRandomString(16)
	if err != nil {
		panic(err)
	}

	return fmt.Sprintf("%s:%s\n", username, shaCrypt(password, "rounds=5000", salt, PrefixCryptSha512))
}
// MakeTempFile creates (or truncates) a file named filename inside a fresh
// tb.TempDir() directory and returns the open handle.
//
// The handle is closed automatically when the test or benchmark finishes, so
// a caller that forgets to close it no longer leaks a file descriptor.
// Callers may still close it earlier; the second close is harmless and its
// error is ignored.
//
// It panics if the file cannot be created.
func MakeTempFile(tb testing.TB, filename string) *os.File {
	tb.Helper()

	file, err := os.Create(filepath.Join(tb.TempDir(), filename))
	if err != nil {
		panic(err)
	}

	// Cleanup functions run last-registered-first, so this close happens
	// before the temporary directory registered by TempDir is removed.
	tb.Cleanup(func() {
		_ = file.Close()
	})

	return file
}
// MakeTempFilePath returns the path of filename inside a fresh tb.TempDir()
// directory. Only the directory is created; the file itself is not.
func MakeTempFilePath(tb testing.TB, filename string) string {
	tb.Helper()

	dir := tb.TempDir()

	return filepath.Join(dir, filename)
}
// MakeTempFileWithContent creates a file named filename via MakeTempFile,
// fills it with content, and returns its path.
//
// It panics if the content cannot be written.
func MakeTempFileWithContent(tb testing.TB, filename, content string) string {
	tb.Helper()

	created := MakeTempFile(tb, filename)
	filePath := created.Name()

	// The handle is only needed to learn the path; the content is written
	// by path with os.WriteFile below.
	created.Close()

	if err := os.WriteFile(filePath, []byte(content), 0o0600); err != nil {
		panic(err)
	}

	return filePath
}
// MakeHtpasswdFileFromString creates a file named "htpasswd" inside a fresh
// tb.TempDir() directory, writes fileContent to it, and returns its path.
//
// The content is written through the handle returned by os.Create and the
// handle is closed before returning, so no file descriptor is left open
// behind the returned path.
//
// It panics if the file cannot be created, written or closed.
func MakeHtpasswdFileFromString(tb testing.TB, fileContent string) string {
	tb.Helper()

	htpasswdPath := filepath.Join(tb.TempDir(), "htpasswd")

	htpasswdFile, err := os.Create(htpasswdPath)
	if err != nil {
		panic(err)
	}

	if _, err := htpasswdFile.WriteString(fileContent); err != nil {
		// The write error is the one worth reporting; close is best effort.
		htpasswdFile.Close()

		panic(err)
	}

	if err := htpasswdFile.Close(); err != nil {
		panic(err)
	}

	return htpasswdPath
}