From 881adbd7c6688791838077bf4e50ce22e6cb8532 Mon Sep 17 00:00:00 2001 From: jason yang Date: Mon, 1 Apr 2024 15:53:08 +0900 Subject: [PATCH] fix the issue that oras push and pull will not show progressbar Signed-off-by: jason yang --- CHANGELOG.md | 1 + internal/pkg/client/oras/oras.go | 72 +++++++++++++--------- internal/pkg/client/oras/progressbar.go | 69 --------------------- internal/pkg/client/oras/pull.go | 10 ++- internal/pkg/client/progress_roundtrip.go | 75 +++++++++++++++++++++++ 5 files changed, 126 insertions(+), 101 deletions(-) delete mode 100644 internal/pkg/client/oras/progressbar.go create mode 100644 internal/pkg/client/progress_roundtrip.go diff --git a/CHANGELOG.md b/CHANGELOG.md index 0bdcabdefe..3a81ba3026 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -12,6 +12,7 @@ For older changes see the [archived Singularity change log](https://github.com/a - Make 'apptainer build' work with signed Docker containers. - Stopped binding over the default timezone in the container with the host's timezone, which led to unexpected behavior if the application changed timezones. +- Fixed the issue that ORAS push and pull do not show progress bar. ## v1.3.0 - \[2024-03-12\] diff --git a/internal/pkg/client/oras/oras.go b/internal/pkg/client/oras/oras.go index fa0a1178a3..d373f305c9 100644 --- a/internal/pkg/client/oras/oras.go +++ b/internal/pkg/client/oras/oras.go @@ -29,11 +29,12 @@ import ( "github.com/google/go-containerregistry/pkg/v1/empty" "github.com/google/go-containerregistry/pkg/v1/layout" "github.com/google/go-containerregistry/pkg/v1/remote" + "golang.org/x/term" ) // DownloadImage downloads a SIF image specified by an oci reference to a file using the included credentials -func DownloadImage(ctx context.Context, path, ref string, ociAuth *ocitypes.DockerAuthConfig, noHTTPS bool) error { - im, err := remoteImage(ref, ociAuth, noHTTPS) +func DownloadImage(ctx context.Context, path, ref string, ociAuth *ocitypes.DockerAuthConfig, noHTTPS bool, pb *client.DownloadProgressBar) error { + im, err := remoteImage(ref, ociAuth, noHTTPS, pb) if err != nil { return err } @@ -87,26 +88,7 @@ func DownloadImage(ctx context.Context, path, ref string, ociAuth *ocitypes.Dock } defer outFile.Close() - // progressbar - in, out := io.Pipe() - mwriter := io.MultiWriter(outFile, out) - wpb := &writerWithProgressBar{ - Writer: mwriter, - ProgressBar: &client.DownloadProgressBar{}, - } - wpb.Init(layer.Size) - - go func() { - _, err := io.Copy(io.Discard, in) - if err != nil { - pb.Abort(true) - in.CloseWithError(err) - } - pb.Wait() - in.Close() - }() - - _, err = io.Copy(wpb, blob) + _, err = io.Copy(outFile, blob) if err != nil { return err } @@ -146,10 +128,36 @@ func UploadImage(ctx context.Context, path, ref string, ociAuth *ocitypes.Docker return err } - authOptn := AuthOptn(ociAuth) - updates := make(chan v1.Update, 1) - go showProgressBar(updates) - return remote.Write(ir, im, authOptn, remote.WithUserAgent(useragent.Value()), remote.WithProgress(updates)) + remoteOpts := []remote.Option{AuthOptn(ociAuth), remote.WithUserAgent(useragent.Value())} + if term.IsTerminal(2) { + pb := &client.DownloadProgressBar{} + progChan := make(chan v1.Update, 1) + go func() { + var total int64 + soFar := int64(0) + for { + // The following is concurrency-safe because this is the only + // goroutine that's going to be reading progChan updates. + update := <-progChan + if update.Error != nil { + pb.Abort(false) + return + } + if update.Total != total { + pb.Init(update.Total) + total = update.Total + } + pb.IncrBy(int(update.Complete - soFar)) + soFar = update.Complete + if soFar >= total { + pb.Wait() + return + } + } + }() + remoteOpts = append(remoteOpts, remote.WithProgress(progChan)) + } + return remote.Write(ir, im, remoteOpts...) } // ensureSIF checks for a SIF image at filepath and returns an error if it is not, or an error is encountered @@ -169,7 +177,7 @@ func ensureSIF(filepath string) error { // RefHash returns the digest of the SIF layer of the OCI manifest for supplied ref func RefHash(ctx context.Context, ref string, ociAuth *ocitypes.DockerAuthConfig, noHTTPS bool) (v1.Hash, error) { - im, err := remoteImage(ref, ociAuth, noHTTPS) + im, err := remoteImage(ref, ociAuth, noHTTPS, nil) if err != nil { return v1.Hash{}, err } @@ -227,7 +235,7 @@ func sha256sum(r io.Reader) (result string, nBytes int64, err error) { } // remoteImage returns a v1.Image for the provided remote ref. -func remoteImage(ref string, ociAuth *ocitypes.DockerAuthConfig, noHTTPS bool) (v1.Image, error) { +func remoteImage(ref string, ociAuth *ocitypes.DockerAuthConfig, noHTTPS bool, pb *client.DownloadProgressBar) (v1.Image, error) { ref = strings.TrimPrefix(ref, "oras://") ref = strings.TrimPrefix(ref, "//") @@ -240,8 +248,12 @@ func remoteImage(ref string, ociAuth *ocitypes.DockerAuthConfig, noHTTPS bool) ( if err != nil { return nil, fmt.Errorf("invalid reference %q: %w", ref, err) } - authOptn := AuthOptn(ociAuth) - im, err := remote.Image(ir, authOptn) + remoteOpts := []remote.Option{AuthOptn(ociAuth)} + if pb != nil { + rt := client.NewRoundTripper(nil, pb) + remoteOpts = append(remoteOpts, remote.WithTransport(rt)) + } + im, err := remote.Image(ir, remoteOpts...) if err != nil { return nil, err } diff --git a/internal/pkg/client/oras/progressbar.go b/internal/pkg/client/oras/progressbar.go deleted file mode 100644 index 42c82474f6..0000000000 --- a/internal/pkg/client/oras/progressbar.go +++ /dev/null @@ -1,69 +0,0 @@ -// Copyright (c) Contributors to the Apptainer project, established as -// Apptainer a Series of LF Projects LLC. -// For website terms of use, trademark policy, privacy policy and other -// project policies see https://lfprojects.org/policies -// Copyright (c) 2020, Control Command Inc. All rights reserved. -// Copyright (c) 2020-2022, Sylabs Inc. All rights reserved. -// This software is licensed under a 3-clause BSD license. Please consult the -// LICENSE.md file distributed with the sources of this project regarding your -// rights to use or distribute this software. - -// This file is to add progress bar support for oras protocol. -package oras - -import ( - "io" - - "github.com/apptainer/apptainer/internal/pkg/client" - libClient "github.com/apptainer/container-library-client/client" - v1 "github.com/google/go-containerregistry/pkg/v1" -) - -var pb *progressBar - -type progressBar struct { - libClient.ProgressBar - previous int64 -} - -func showProgressBar(updates chan v1.Update) error { - for update := range updates { - if update.Error != nil { - if pb != nil { - pb.Abort(true) - } - return update.Error - } - - if update.Complete == update.Total { - break - } - - if pb == nil { - pb = &progressBar{ - ProgressBar: &client.DownloadProgressBar{}, - previous: 0, - } - pb.Init(update.Total) - } - pb.IncrBy(int(update.Complete - pb.previous)) - pb.previous = update.Complete - } - - return nil -} - -type writerWithProgressBar struct { - io.Writer - libClient.ProgressBar -} - -func (w *writerWithProgressBar) Write(p []byte) (n int, err error) { - n, err = w.Writer.Write(p) - if err != nil { - w.Abort(true) - return n, err - } - w.IncrBy(n) - return n, err -} diff --git a/internal/pkg/client/oras/pull.go b/internal/pkg/client/oras/pull.go index 1d19fc01f8..58edc0a154 100644 --- a/internal/pkg/client/oras/pull.go +++ b/internal/pkg/client/oras/pull.go @@ -15,9 +15,11 @@ import ( "os" "github.com/apptainer/apptainer/internal/pkg/cache" + "github.com/apptainer/apptainer/internal/pkg/client" "github.com/apptainer/apptainer/internal/pkg/util/fs" "github.com/apptainer/apptainer/pkg/sylog" ocitypes "github.com/containers/image/v5/types" + "golang.org/x/term" ) // pull will pull an oras image into the cache if directTo="", or a specific file if directTo is set. @@ -27,9 +29,13 @@ func pull(ctx context.Context, imgCache *cache.Handle, directTo, pullFrom string return "", fmt.Errorf("failed to get checksum for %s: %s", pullFrom, err) } + var pb *client.DownloadProgressBar + if term.IsTerminal(2) { + pb = &client.DownloadProgressBar{} + } if directTo != "" { sylog.Infof("Downloading oras image") - if err := DownloadImage(ctx, directTo, pullFrom, ociAuth, noHTTPS); err != nil { + if err := DownloadImage(ctx, directTo, pullFrom, ociAuth, noHTTPS, pb); err != nil { return "", fmt.Errorf("unable to Download Image: %v", err) } imagePath = directTo @@ -43,7 +49,7 @@ func pull(ctx context.Context, imgCache *cache.Handle, directTo, pullFrom string if !cacheEntry.Exists { sylog.Infof("Downloading oras image") - if err := DownloadImage(ctx, cacheEntry.TmpPath, pullFrom, ociAuth, noHTTPS); err != nil { + if err := DownloadImage(ctx, cacheEntry.TmpPath, pullFrom, ociAuth, noHTTPS, pb); err != nil { return "", fmt.Errorf("unable to Download Image: %v", err) } if cacheFileHash, err := ImageHash(cacheEntry.TmpPath); err != nil { diff --git a/internal/pkg/client/progress_roundtrip.go b/internal/pkg/client/progress_roundtrip.go new file mode 100644 index 0000000000..b00a128ab1 --- /dev/null +++ b/internal/pkg/client/progress_roundtrip.go @@ -0,0 +1,75 @@ +// Copyright (c) Contributors to the Apptainer project, established as +// Apptainer a Series of LF Projects LLC. +// For website terms of use, trademark policy, privacy policy and other +// project policies see https://lfprojects.org/policies + +// Copyright (c) 2023, Sylabs Inc. All rights reserved. +// This software is licensed under a 3-clause BSD license. Please consult the +// LICENSE.md file distributed with the sources of this project regarding your +// rights to use or distribute this software. + +package client + +import ( + "io" + "net/http" +) + +const contentSizeThreshold = 1024 + +type RoundTripper struct { + inner http.RoundTripper + pb *DownloadProgressBar +} + +func NewRoundTripper(inner http.RoundTripper, pb *DownloadProgressBar) *RoundTripper { + if inner == nil { + inner = http.DefaultTransport + } + + rt := RoundTripper{ + inner: inner, + pb: pb, + } + + return &rt +} + +type rtReadCloser struct { + inner io.ReadCloser + pb *DownloadProgressBar +} + +func (r *rtReadCloser) Read(p []byte) (int, error) { + return r.inner.Read(p) +} + +func (r *rtReadCloser) Close() error { + err := r.inner.Close() + if err == nil { + r.pb.Wait() + } else { + r.pb.Abort(false) + } + + return err +} + +func (t *RoundTripper) RoundTrip(req *http.Request) (*http.Response, error) { + if t.pb != nil && req.Body != nil && req.ContentLength >= contentSizeThreshold { + t.pb.Init(req.ContentLength) + req.Body = &rtReadCloser{ + inner: t.pb.bar.ProxyReader(req.Body), + pb: t.pb, + } + } + resp, err := t.inner.RoundTrip(req) + if t.pb != nil && resp != nil && resp.Body != nil && resp.ContentLength >= contentSizeThreshold { + t.pb.Init(resp.ContentLength) + resp.Body = &rtReadCloser{ + inner: t.pb.bar.ProxyReader(resp.Body), + pb: t.pb, + } + } + return resp, err +}