Skip to content

Commit

Permalink
feat: streaming decompression for presigned download (#146)
Browse files Browse the repository at this point in the history
  • Loading branch information
yetone authored Dec 4, 2024
1 parent d0708d9 commit 7a3271c
Showing 1 changed file with 229 additions and 80 deletions.
309 changes: 229 additions & 80 deletions bento-image-snapshotter/fs/s3.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,10 @@ import (
"os/exec"
"path/filepath"
"runtime"
"slices"
"strconv"
"strings"
"sync"
"time"

"github.com/aws/aws-sdk-go-v2/service/s3"
Expand Down Expand Up @@ -75,6 +77,13 @@ func (o S3FileSystem) getObjectSize(ctx context.Context, bucketName, objectKey s
}

func (o S3FileSystem) Mount(ctx context.Context, mountpoint string, labels map[string]string) error {
if mounted, err := o.isMounted(mountpoint); err != nil {
return errors.Wrap(err, "failed to check if mountpoint is mounted")
} else if mounted {
log.G(ctx).WithField("mountpoint", mountpoint).Info("already mounted")
return nil
}

bucketName := labels[common.DescriptorAnnotationBucket]
objectKey := labels[common.DescriptorAnnotationObjectKey]
logger := log.G(ctx).WithField("bucket", bucketName).WithField("key", objectKey).WithField("mountpoint", mountpoint)
Expand Down Expand Up @@ -126,7 +135,7 @@ func (o S3FileSystem) Mount(ctx context.Context, mountpoint string, labels map[s
return errors.Wrap(err, "failed to symlink destinationPath to mountpoint")
}
}
return nil
return createMountedFlag(mountpoint)
}

func (o S3FileSystem) Check(ctx context.Context, mountpoint string, labels map[string]string) error {
Expand All @@ -150,6 +159,45 @@ func getRamfsPath(path string) string {
return path + ".ramfs"
}

func getMountedFlagPath(mountpoint string) string {
baseDir := filepath.Dir(mountpoint)
return filepath.Join(baseDir, ".mounted")
}

func createMountedFlag(mountpoint string) error {
mountedFlagPath := getMountedFlagPath(mountpoint)
file, err := os.Create(mountedFlagPath)
if err != nil {
return errors.Wrap(err, "failed to create mounted flag")
}
defer file.Close()
return nil
}

func removeMountedFlag(mountpoint string) error {
mountedFlagPath := getMountedFlagPath(mountpoint)
if err := os.Remove(mountedFlagPath); err != nil {
if !os.IsNotExist(err) {
return errors.Wrap(err, "failed to remove mounted flag")
}
}
return nil
}

func (o S3FileSystem) isMounted(mountpoint string) (bool, error) {
if o.EnableRamfs {
return isRamfsMounted(mountpoint)
}
mountedFlagPath := getMountedFlagPath(mountpoint)
if _, err := os.Stat(mountedFlagPath); err != nil {
if os.IsNotExist(err) {
return false, nil
}
return false, errors.Wrap(err, "failed to stat mounted flag")
}
return true, nil
}

func isRamfsMounted(path string) (bool, error) {
ramfsPath := getRamfsPath(path)
file, err := os.Open("/proc/mounts")
Expand Down Expand Up @@ -208,7 +256,7 @@ func (o S3FileSystem) Unmount(ctx context.Context, mountpoint string) error {
if err := os.RemoveAll(mountpoint); err != nil {
return errors.Wrap(err, "failed to remove mountpoint")
}
return nil
return removeMountedFlag(mountpoint)
}

func stringifyCmd(cmd *exec.Cmd) string {
Expand Down Expand Up @@ -253,71 +301,13 @@ func downloadRange(ctx context.Context, presignedURL string, start, end int64, w
return nil
}

func parrallelDownload(ctx context.Context, presignedURL string, destPath string, parallelism int) error {
// use curl to get the size of the file
logger := log.G(ctx).WithField("url", presignedURL)
logger.Debug("getting the size of the file...")
var stderr bytes.Buffer
cmd := exec.CommandContext(ctx, "curl", "--silent", "--show-error", "--fail", "--head", "--output", "-", presignedURL) // nolint:gosec
cmd.Stderr = &stderr
output, err := cmd.Output()
if err != nil {
return errors.Wrapf(err, "failed to run command: %s, stderr: %s", stringifyCmd(cmd), stderr.String())
}
var size int64
lines := strings.Split(string(output), "\n")
for _, line := range lines {
k, v, found := strings.Cut(line, ":")
if !found {
continue
}
if strings.ToLower(strings.TrimSpace(k)) == "content-length" {
var err error
size, err = strconv.ParseInt(strings.TrimSpace(v), 10, 64)
if err != nil {
return errors.Wrapf(err, "failed to parse size: %s", string(output))
}
logger.Debugf("size: %d", size)
break
}
}

if size == 0 {
return errors.New("cannot get the file size from presigned url")
}

func parrallelDownload(ctx context.Context, presignedURL string, writer io.WriterAt, size int64, parallelism int) error {
partSize := size / int64(parallelism)
if size%int64(parallelism) > 0 {
partSize += 1
}

baseDir := filepath.Dir(destPath)
if err := os.MkdirAll(baseDir, 0755); err != nil {
return errors.Wrap(err, "failed to create base dir")
}

// create the target file and falloc the target file
file, err := os.Create(destPath)
if err != nil {
return errors.Wrap(err, "failed to create target file")
}
defer file.Close()

// use truncate to fallocate the target file
// seek to the first byte of the file
if _, err := file.Seek(0, 0); err != nil {
return errors.Wrap(err, "failed to seek the target file")
}
// truncate the file to the size
if err := file.Truncate(size); err != nil {
return errors.Wrap(err, "failed to truncate the target file")
}
//
// if err := syscall.Fallocate(int(file.Fd()), 0, 0, size); err != nil {
// return errors.Wrap(err, "failed to fallocate the target file")
// }

var eg errgroup.Group
eg, ctx := errgroup.WithContext(ctx)
for i := 0; i < parallelism; i++ {
start := int64(i) * partSize
end := start + partSize - 1
Expand All @@ -332,10 +322,127 @@ func parrallelDownload(ctx context.Context, presignedURL string, destPath string
}
}
}()
return downloadRange(ctx, presignedURL, start, end, file)
err := downloadRange(ctx, presignedURL, start, end, writer)
if err != nil {
return errors.Wrap(err, "failed to download range")
}
return nil
})
}
return errors.Wrap(eg.Wait(), "downloading")

return errors.Wrap(eg.Wait(), "failed to download layer from S3")
}

type DataRange struct {
Start int64 // inclusive
End int64 // inclusive
}

type WriterAtReader struct {
file *os.File
readyDataRange []DataRange
eofReached bool
lock sync.RWMutex
cond *sync.Cond
curOffset int64
size int64
}

func NewWriterAtReader(file *os.File, size int64) *WriterAtReader {
w := &WriterAtReader{
file: file,
readyDataRange: make([]DataRange, 0),
size: size,
}
w.cond = sync.NewCond(&w.lock)
return w
}

func (w *WriterAtReader) WriteAt(p []byte, off int64) (n int, err error) {
n, err = w.file.WriteAt(p, off)
if err != nil {
return n, errors.Wrap(err, "failed to write data to file")
}

w.lock.Lock()
readyDataRange := append(w.readyDataRange, DataRange{Start: off, End: off + int64(n) - 1}) //nolint:gocritic
slices.SortFunc(readyDataRange, func(a, b DataRange) int {
return int(a.Start - b.Start)
})

// Merge overlapping or adjacent ranges
merged := make([]DataRange, 0)
if len(readyDataRange) > 0 {
current := readyDataRange[0]
for i := 1; i < len(readyDataRange); i++ {
if current.End+1 >= readyDataRange[i].Start {
if readyDataRange[i].End > current.End {
current.End = readyDataRange[i].End
}
} else {
merged = append(merged, current)
current = readyDataRange[i]
}
}
merged = append(merged, current)
}

w.readyDataRange = merged
w.eofReached = w.isEOF()
w.cond.Broadcast()
w.lock.Unlock()

return n, nil
}

func (w *WriterAtReader) isEOF() bool {
return len(w.readyDataRange) == 1 &&
w.readyDataRange[0].Start == 0 &&
w.readyDataRange[0].End == w.size-1
}

func (w *WriterAtReader) dataReady(offset, size int64) bool {
for _, r := range w.readyDataRange {
if r.Start <= offset && r.End >= offset+size-1 {
return true
}
}
return false
}

func (w *WriterAtReader) Read(p []byte) (int, error) {
w.lock.Lock()
for !w.eofReached && !w.dataReady(w.curOffset, int64(len(p))) {
w.cond.Wait()
}

if w.eofReached && w.curOffset >= w.size {
w.lock.Unlock()
return 0, io.EOF
}

readSize := int64(len(p))
if w.curOffset+readSize > w.size {
readSize = w.size - w.curOffset
}

w.lock.Unlock()

n, err := w.file.ReadAt(p[:readSize], w.curOffset)
w.curOffset += int64(n)

if err != nil && !errors.Is(err, io.EOF) {
return n, errors.Wrap(err, "failed to read data from file")
}
return n, nil
}

func (w *WriterAtReader) Close() error {
w.lock.Lock()
w.eofReached = true
w.cond.Broadcast()
w.lock.Unlock()
return nil
}

func (o *S3FileSystem) downloadLayerFromS3(ctx context.Context, bucketName, layerKey string, destinationPath string) error {
Expand Down Expand Up @@ -367,33 +474,75 @@ func (o *S3FileSystem) downloadLayerFromS3(ctx context.Context, bucketName, laye
return errors.Wrap(err, "failed to get s3 presigned url")
}

size, err := o.getObjectSize(ctx, bucketName, layerKey)
if err != nil {
return errors.Wrap(err, "failed to get object size")
}

coreNumber := runtime.NumCPU()

startTime := time.Now()
logger.Info("downloading layer from S3...")
err = parrallelDownload(ctx, presignedURL, downloadedFilePath, coreNumber)
baseDir := filepath.Dir(downloadedFilePath)
if err := os.MkdirAll(baseDir, 0755); err != nil {
return errors.Wrap(err, "failed to create base dir for downloaded file")
}

// create the target file and falloc the target file
file, err := os.Create(downloadedFilePath)
if err != nil {
return errors.Wrap(err, "failed to download layer from S3")
return errors.Wrap(err, "failed to create downloaded file")
}
defer file.Close()

downloadDuration := time.Since(startTime)
logger.WithField("duration", downloadDuration).Info("the layer has been downloaded from S3")
// use truncate to fallocate the target file
// seek to the first byte of the file
if _, err := file.Seek(0, 0); err != nil {
return errors.Wrap(err, "failed to seek the target file")
}
// truncate the file to the size
if err := file.Truncate(size); err != nil {
return errors.Wrap(err, "failed to truncate the target file")
}

startTime = time.Now()
logger.Info("decompressing layer...")
var stderr bytes.Buffer
cmd := exec.CommandContext(ctx, "sh", "-c", fmt.Sprintf("pzstd -d -c %s | tar -xf -", downloadedFilePath)) // nolint:gosec
cmd.Stderr = &stderr
cmd.Dir = tempName
w := NewWriterAtReader(file, size)

err = cmd.Run()
startTime := time.Now()
logger.Info("downloading and streaming decompression layer from S3...")
errCh := make(chan error, 1)

go func() {
defer func() {
_ = w.Close()
}()
err := parrallelDownload(ctx, presignedURL, w, size, coreNumber)
if err != nil {
errCh <- errors.Wrap(err, "failed to download layer from S3")
}
}()

go func() {
var stderr bytes.Buffer
cmd := exec.CommandContext(ctx, "sh", "-c", "pzstd -d | tar -xf -")
cmd.Stderr = &stderr
cmd.Stdin = w
cmd.Dir = tempName

err = cmd.Run()
if err != nil {
err = errors.Wrapf(err, "failed to run command: %s, stderr: %s", stringifyCmd(cmd), stderr.String())
errCh <- err
} else {
errCh <- nil
}
}()

err = <-errCh
if err != nil {
return errors.Wrapf(err, "failed to run command: %s, stderr: %s", stringifyCmd(cmd), stderr.String())
return errors.Wrap(err, "failed to download layer from S3")
}

decompressionDuration := time.Since(startTime)
duration := time.Since(startTime)

logger.WithField("duration", decompressionDuration).Info("the layer has been decompressed")
logger.WithField("duration", duration).Info("the layer has been downloaded and decompressed")

if !o.EnableRamfs {
err = os.RemoveAll(destinationPath)
Expand Down

0 comments on commit 7a3271c

Please sign in to comment.