From 7a3271cc4604b38dbee95e5d64401cea71f71d7d Mon Sep 17 00:00:00 2001 From: yetone Date: Wed, 4 Dec 2024 14:28:31 +0800 Subject: [PATCH] feat: streaming decompression for presigned download (#146) --- bento-image-snapshotter/fs/s3.go | 309 +++++++++++++++++++++++-------- 1 file changed, 229 insertions(+), 80 deletions(-) diff --git a/bento-image-snapshotter/fs/s3.go b/bento-image-snapshotter/fs/s3.go index be86523..c6e4157 100644 --- a/bento-image-snapshotter/fs/s3.go +++ b/bento-image-snapshotter/fs/s3.go @@ -10,8 +10,10 @@ import ( "os/exec" "path/filepath" "runtime" + "slices" "strconv" "strings" + "sync" "time" "github.com/aws/aws-sdk-go-v2/service/s3" @@ -75,6 +77,13 @@ func (o S3FileSystem) getObjectSize(ctx context.Context, bucketName, objectKey s } func (o S3FileSystem) Mount(ctx context.Context, mountpoint string, labels map[string]string) error { + if mounted, err := o.isMounted(mountpoint); err != nil { + return errors.Wrap(err, "failed to check if mountpoint is mounted") + } else if mounted { + log.G(ctx).WithField("mountpoint", mountpoint).Info("already mounted") + return nil + } + bucketName := labels[common.DescriptorAnnotationBucket] objectKey := labels[common.DescriptorAnnotationObjectKey] logger := log.G(ctx).WithField("bucket", bucketName).WithField("key", objectKey).WithField("mountpoint", mountpoint) @@ -126,7 +135,7 @@ func (o S3FileSystem) Mount(ctx context.Context, mountpoint string, labels map[s return errors.Wrap(err, "failed to symlink destinationPath to mountpoint") } } - return nil + return createMountedFlag(mountpoint) } func (o S3FileSystem) Check(ctx context.Context, mountpoint string, labels map[string]string) error { @@ -150,6 +159,45 @@ func getRamfsPath(path string) string { return path + ".ramfs" } +func getMountedFlagPath(mountpoint string) string { + baseDir := filepath.Dir(mountpoint) + return filepath.Join(baseDir, ".mounted") +} + +func createMountedFlag(mountpoint string) error { + mountedFlagPath := getMountedFlagPath(mountpoint) + file, err := os.Create(mountedFlagPath) + if err != nil { + return errors.Wrap(err, "failed to create mounted flag") + } + defer file.Close() + return nil +} + +func removeMountedFlag(mountpoint string) error { + mountedFlagPath := getMountedFlagPath(mountpoint) + if err := os.Remove(mountedFlagPath); err != nil { + if !os.IsNotExist(err) { + return errors.Wrap(err, "failed to remove mounted flag") + } + } + return nil +} + +func (o S3FileSystem) isMounted(mountpoint string) (bool, error) { + if o.EnableRamfs { + return isRamfsMounted(mountpoint) + } + mountedFlagPath := getMountedFlagPath(mountpoint) + if _, err := os.Stat(mountedFlagPath); err != nil { + if os.IsNotExist(err) { + return false, nil + } + return false, errors.Wrap(err, "failed to stat mounted flag") + } + return true, nil +} + func isRamfsMounted(path string) (bool, error) { ramfsPath := getRamfsPath(path) file, err := os.Open("/proc/mounts") @@ -208,7 +256,7 @@ func (o S3FileSystem) Unmount(ctx context.Context, mountpoint string) error { if err := os.RemoveAll(mountpoint); err != nil { return errors.Wrap(err, "failed to remove mountpoint") } - return nil + return removeMountedFlag(mountpoint) } func stringifyCmd(cmd *exec.Cmd) string { @@ -253,71 +301,13 @@ func downloadRange(ctx context.Context, presignedURL string, start, end int64, w return nil } -func parrallelDownload(ctx context.Context, presignedURL string, destPath string, parallelism int) error { - // use curl to get the size of the file - logger := log.G(ctx).WithField("url", presignedURL) - logger.Debug("getting the size of the file...") - var stderr bytes.Buffer - cmd := exec.CommandContext(ctx, "curl", "--silent", "--show-error", "--fail", "--head", "--output", "-", presignedURL) // nolint:gosec - cmd.Stderr = &stderr - output, err := cmd.Output() - if err != nil { - return errors.Wrapf(err, "failed to run command: %s, stderr: %s", stringifyCmd(cmd), stderr.String()) - } - var size int64 - lines := strings.Split(string(output), "\n") - for _, line := range lines { - k, v, found := strings.Cut(line, ":") - if !found { - continue - } - if strings.ToLower(strings.TrimSpace(k)) == "content-length" { - var err error - size, err = strconv.ParseInt(strings.TrimSpace(v), 10, 64) - if err != nil { - return errors.Wrapf(err, "failed to parse size: %s", string(output)) - } - logger.Debugf("size: %d", size) - break - } - } - - if size == 0 { - return errors.New("cannot get the file size from presigned url") - } - +func parrallelDownload(ctx context.Context, presignedURL string, writer io.WriterAt, size int64, parallelism int) error { partSize := size / int64(parallelism) if size%int64(parallelism) > 0 { partSize += 1 } - baseDir := filepath.Dir(destPath) - if err := os.MkdirAll(baseDir, 0755); err != nil { - return errors.Wrap(err, "failed to create base dir") - } - - // create the target file and falloc the target file - file, err := os.Create(destPath) - if err != nil { - return errors.Wrap(err, "failed to create target file") - } - defer file.Close() - - // use truncate to fallocate the target file - // seek to the first byte of the file - if _, err := file.Seek(0, 0); err != nil { - return errors.Wrap(err, "failed to seek the target file") - } - // truncate the file to the size - if err := file.Truncate(size); err != nil { - return errors.Wrap(err, "failed to truncate the target file") - } - // - // if err := syscall.Fallocate(int(file.Fd()), 0, 0, size); err != nil { - // return errors.Wrap(err, "failed to fallocate the target file") - // } - - var eg errgroup.Group + eg, ctx := errgroup.WithContext(ctx) for i := 0; i < parallelism; i++ { start := int64(i) * partSize end := start + partSize - 1 @@ -332,10 +322,127 @@ func parrallelDownload(ctx context.Context, presignedURL string, destPath string } } }() - return downloadRange(ctx, presignedURL, start, end, file) + err := downloadRange(ctx, presignedURL, start, end, writer) + if err != nil { + return errors.Wrap(err, "failed to download range") + } + return nil }) } - return errors.Wrap(eg.Wait(), "downloading") + + return errors.Wrap(eg.Wait(), "failed to download layer from S3") +} + +type DataRange struct { + Start int64 // inclusive + End int64 // inclusive +} + +type WriterAtReader struct { + file *os.File + readyDataRange []DataRange + eofReached bool + lock sync.RWMutex + cond *sync.Cond + curOffset int64 + size int64 +} + +func NewWriterAtReader(file *os.File, size int64) *WriterAtReader { + w := &WriterAtReader{ + file: file, + readyDataRange: make([]DataRange, 0), + size: size, + } + w.cond = sync.NewCond(&w.lock) + return w +} + +func (w *WriterAtReader) WriteAt(p []byte, off int64) (n int, err error) { + n, err = w.file.WriteAt(p, off) + if err != nil { + return n, errors.Wrap(err, "failed to write data to file") + } + + w.lock.Lock() + readyDataRange := append(w.readyDataRange, DataRange{Start: off, End: off + int64(n) - 1}) //nolint:gocritic + slices.SortFunc(readyDataRange, func(a, b DataRange) int { + return int(a.Start - b.Start) + }) + + // Merge overlapping or adjacent ranges + merged := make([]DataRange, 0) + if len(readyDataRange) > 0 { + current := readyDataRange[0] + for i := 1; i < len(readyDataRange); i++ { + if current.End+1 >= readyDataRange[i].Start { + if readyDataRange[i].End > current.End { + current.End = readyDataRange[i].End + } + } else { + merged = append(merged, current) + current = readyDataRange[i] + } + } + merged = append(merged, current) + } + + w.readyDataRange = merged + w.eofReached = w.isEOF() + w.cond.Broadcast() + w.lock.Unlock() + + return n, nil +} + +func (w *WriterAtReader) isEOF() bool { + return len(w.readyDataRange) == 1 && + w.readyDataRange[0].Start == 0 && + w.readyDataRange[0].End == w.size-1 +} + +func (w *WriterAtReader) dataReady(offset, size int64) bool { + for _, r := range w.readyDataRange { + if r.Start <= offset && r.End >= offset+size-1 { + return true + } + } + return false +} + +func (w *WriterAtReader) Read(p []byte) (int, error) { + w.lock.Lock() + for !w.eofReached && !w.dataReady(w.curOffset, int64(len(p))) { + w.cond.Wait() + } + + if w.eofReached && w.curOffset >= w.size { + w.lock.Unlock() + return 0, io.EOF + } + + readSize := int64(len(p)) + if w.curOffset+readSize > w.size { + readSize = w.size - w.curOffset + } + + w.lock.Unlock() + + n, err := w.file.ReadAt(p[:readSize], w.curOffset) + w.curOffset += int64(n) + + if err != nil && !errors.Is(err, io.EOF) { + return n, errors.Wrap(err, "failed to read data from file") + } + return n, nil +} + +func (w *WriterAtReader) Close() error { + w.lock.Lock() + w.eofReached = true + w.cond.Broadcast() + w.lock.Unlock() + return nil } func (o *S3FileSystem) downloadLayerFromS3(ctx context.Context, bucketName, layerKey string, destinationPath string) error { @@ -367,33 +474,75 @@ func (o *S3FileSystem) downloadLayerFromS3(ctx context.Context, bucketName, laye return errors.Wrap(err, "failed to get s3 presigned url") } + size, err := o.getObjectSize(ctx, bucketName, layerKey) + if err != nil { + return errors.Wrap(err, "failed to get object size") + } + coreNumber := runtime.NumCPU() - startTime := time.Now() - logger.Info("downloading layer from S3...") - err = parrallelDownload(ctx, presignedURL, downloadedFilePath, coreNumber) + baseDir := filepath.Dir(downloadedFilePath) + if err := os.MkdirAll(baseDir, 0755); err != nil { + return errors.Wrap(err, "failed to create base dir for downloaded file") + } + + // create the target file and falloc the target file + file, err := os.Create(downloadedFilePath) if err != nil { - return errors.Wrap(err, "failed to download layer from S3") + return errors.Wrap(err, "failed to create downloaded file") } + defer file.Close() - downloadDuration := time.Since(startTime) - logger.WithField("duration", downloadDuration).Info("the layer has been downloaded from S3") + // use truncate to fallocate the target file + // seek to the first byte of the file + if _, err := file.Seek(0, 0); err != nil { + return errors.Wrap(err, "failed to seek the target file") + } + // truncate the file to the size + if err := file.Truncate(size); err != nil { + return errors.Wrap(err, "failed to truncate the target file") + } - startTime = time.Now() - logger.Info("decompressing layer...") - var stderr bytes.Buffer - cmd := exec.CommandContext(ctx, "sh", "-c", fmt.Sprintf("pzstd -d -c %s | tar -xf -", downloadedFilePath)) // nolint:gosec - cmd.Stderr = &stderr - cmd.Dir = tempName + w := NewWriterAtReader(file, size) - err = cmd.Run() + startTime := time.Now() + logger.Info("downloading and streaming decompression layer from S3...") + errCh := make(chan error, 1) + + go func() { + defer func() { + _ = w.Close() + }() + err := parrallelDownload(ctx, presignedURL, w, size, coreNumber) + if err != nil { + errCh <- errors.Wrap(err, "failed to download layer from S3") + } + }() + + go func() { + var stderr bytes.Buffer + cmd := exec.CommandContext(ctx, "sh", "-c", "pzstd -d | tar -xf -") + cmd.Stderr = &stderr + cmd.Stdin = w + cmd.Dir = tempName + + err = cmd.Run() + if err != nil { + err = errors.Wrapf(err, "failed to run command: %s, stderr: %s", stringifyCmd(cmd), stderr.String()) + errCh <- err + } else { + errCh <- nil + } + }() + + err = <-errCh if err != nil { - return errors.Wrapf(err, "failed to run command: %s, stderr: %s", stringifyCmd(cmd), stderr.String()) + return errors.Wrap(err, "failed to download layer from S3") } - decompressionDuration := time.Since(startTime) + duration := time.Since(startTime) - logger.WithField("duration", decompressionDuration).Info("the layer has been decompressed") + logger.WithField("duration", duration).Info("the layer has been downloaded and decompressed") if !o.EnableRamfs { err = os.RemoveAll(destinationPath)