Skip to content

Commit

Permalink
fix the issue that oras push and pull do not show progress bar
Browse files Browse the repository at this point in the history
Signed-off-by: jason yang <jasonyangshadow@gmail.com>
  • Loading branch information
JasonYangShadow committed Apr 1, 2024
1 parent 53e1e22 commit 31a605d
Show file tree
Hide file tree
Showing 4 changed files with 126 additions and 32 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ For older changes see the [archived Singularity change log](https://github.com/a
- Make 'apptainer build' work with signed Docker containers.
- Stopped binding over the default timezone in the container with the host's timezone,
which led to unexpected behavior if the application changed timezones.
- Fixed the issue that ORAS push and pull do not show progress bar.

## v1.3.0 - \[2024-03-12\]

Expand Down
72 changes: 42 additions & 30 deletions internal/pkg/client/oras/oras.go
Original file line number Diff line number Diff line change
Expand Up @@ -29,11 +29,12 @@ import (
"github.com/google/go-containerregistry/pkg/v1/empty"
"github.com/google/go-containerregistry/pkg/v1/layout"
"github.com/google/go-containerregistry/pkg/v1/remote"
"golang.org/x/term"
)

// DownloadImage downloads a SIF image specified by an oci reference to a file using the included credentials
func DownloadImage(ctx context.Context, path, ref string, ociAuth *ocitypes.DockerAuthConfig, noHTTPS bool) error {
im, err := remoteImage(ref, ociAuth, noHTTPS)
func DownloadImage(ctx context.Context, path, ref string, ociAuth *ocitypes.DockerAuthConfig, noHTTPS bool, pb *client.DownloadProgressBar) error {
im, err := remoteImage(ref, ociAuth, noHTTPS, pb)
if err != nil {
return err
}
Expand Down Expand Up @@ -87,26 +88,7 @@ func DownloadImage(ctx context.Context, path, ref string, ociAuth *ocitypes.Dock
}
defer outFile.Close()

// progressbar
in, out := io.Pipe()
mwriter := io.MultiWriter(outFile, out)
wpb := &writerWithProgressBar{
Writer: mwriter,
ProgressBar: &client.DownloadProgressBar{},
}
wpb.Init(layer.Size)

go func() {
_, err := io.Copy(io.Discard, in)
if err != nil {
pb.Abort(true)
in.CloseWithError(err)
}
pb.Wait()
in.Close()
}()

_, err = io.Copy(wpb, blob)
_, err = io.Copy(outFile, blob)
if err != nil {
return err
}
Expand Down Expand Up @@ -146,10 +128,36 @@ func UploadImage(ctx context.Context, path, ref string, ociAuth *ocitypes.Docker
return err
}

authOptn := AuthOptn(ociAuth)
updates := make(chan v1.Update, 1)
go showProgressBar(updates)
return remote.Write(ir, im, authOptn, remote.WithUserAgent(useragent.Value()), remote.WithProgress(updates))
remoteOpts := []remote.Option{AuthOptn(ociAuth), remote.WithUserAgent(useragent.Value())}
if term.IsTerminal(2) {
pb := &client.DownloadProgressBar{}
progChan := make(chan v1.Update, 1)
go func() {
var total int64
soFar := int64(0)
for {
// The following is concurrency-safe because this is the only
// goroutine that's going to be reading progChan updates.
update := <-progChan
if update.Error != nil {
pb.Abort(false)
return
}
if update.Total != total {
pb.Init(update.Total)
total = update.Total
}
pb.IncrBy(int(update.Complete - soFar))
soFar = update.Complete
if soFar >= total {
pb.Wait()
return
}
}
}()
remoteOpts = append(remoteOpts, remote.WithProgress(progChan))
}
return remote.Write(ir, im, remoteOpts...)
}

// ensureSIF checks for a SIF image at filepath and returns an error if it is not, or an error is encountered
Expand All @@ -169,7 +177,7 @@ func ensureSIF(filepath string) error {

// RefHash returns the digest of the SIF layer of the OCI manifest for supplied ref
func RefHash(ctx context.Context, ref string, ociAuth *ocitypes.DockerAuthConfig, noHTTPS bool) (v1.Hash, error) {
im, err := remoteImage(ref, ociAuth, noHTTPS)
im, err := remoteImage(ref, ociAuth, noHTTPS, nil)
if err != nil {
return v1.Hash{}, err
}
Expand Down Expand Up @@ -227,7 +235,7 @@ func sha256sum(r io.Reader) (result string, nBytes int64, err error) {
}

// remoteImage returns a v1.Image for the provided remote ref.
func remoteImage(ref string, ociAuth *ocitypes.DockerAuthConfig, noHTTPS bool) (v1.Image, error) {
func remoteImage(ref string, ociAuth *ocitypes.DockerAuthConfig, noHTTPS bool, pb *client.DownloadProgressBar) (v1.Image, error) {
ref = strings.TrimPrefix(ref, "oras://")
ref = strings.TrimPrefix(ref, "//")

Expand All @@ -240,8 +248,12 @@ func remoteImage(ref string, ociAuth *ocitypes.DockerAuthConfig, noHTTPS bool) (
if err != nil {
return nil, fmt.Errorf("invalid reference %q: %w", ref, err)
}
authOptn := AuthOptn(ociAuth)
im, err := remote.Image(ir, authOptn)
remoteOpts := []remote.Option{AuthOptn(ociAuth)}
if pb != nil {
rt := client.NewRoundTripper(nil, pb)
remoteOpts = append(remoteOpts, remote.WithTransport(rt))
}
im, err := remote.Image(ir, remoteOpts...)
if err != nil {
return nil, err
}
Expand Down
10 changes: 8 additions & 2 deletions internal/pkg/client/oras/pull.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,9 +15,11 @@ import (
"os"

"github.com/apptainer/apptainer/internal/pkg/cache"
"github.com/apptainer/apptainer/internal/pkg/client"
"github.com/apptainer/apptainer/internal/pkg/util/fs"
"github.com/apptainer/apptainer/pkg/sylog"
ocitypes "github.com/containers/image/v5/types"
"golang.org/x/term"
)

// pull will pull an oras image into the cache if directTo="", or a specific file if directTo is set.
Expand All @@ -27,9 +29,13 @@ func pull(ctx context.Context, imgCache *cache.Handle, directTo, pullFrom string
return "", fmt.Errorf("failed to get checksum for %s: %s", pullFrom, err)
}

var pb *client.DownloadProgressBar
if term.IsTerminal(2) {
pb = &client.DownloadProgressBar{}
}
if directTo != "" {
sylog.Infof("Downloading oras image")
if err := DownloadImage(ctx, directTo, pullFrom, ociAuth, noHTTPS); err != nil {
if err := DownloadImage(ctx, directTo, pullFrom, ociAuth, noHTTPS, pb); err != nil {
return "", fmt.Errorf("unable to Download Image: %v", err)
}
imagePath = directTo
Expand All @@ -43,7 +49,7 @@ func pull(ctx context.Context, imgCache *cache.Handle, directTo, pullFrom string
if !cacheEntry.Exists {
sylog.Infof("Downloading oras image")

if err := DownloadImage(ctx, cacheEntry.TmpPath, pullFrom, ociAuth, noHTTPS); err != nil {
if err := DownloadImage(ctx, cacheEntry.TmpPath, pullFrom, ociAuth, noHTTPS, pb); err != nil {
return "", fmt.Errorf("unable to Download Image: %v", err)
}
if cacheFileHash, err := ImageHash(cacheEntry.TmpPath); err != nil {
Expand Down
75 changes: 75 additions & 0 deletions internal/pkg/client/progress_roundtrip.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,75 @@
// Copyright (c) Contributors to the Apptainer project, established as
// Apptainer a Series of LF Projects LLC.
// For website terms of use, trademark policy, privacy policy and other
// project policies see https://lfprojects.org/policies

// Copyright (c) 2023, Sylabs Inc. All rights reserved.
// This software is licensed under a 3-clause BSD license. Please consult the
// LICENSE.md file distributed with the sources of this project regarding your
// rights to use or distribute this software.

package client

import (
"io"
"net/http"
)

const contentSizeThreshold = 1024

type RoundTripper struct {
inner http.RoundTripper
pb *DownloadProgressBar
}

func NewRoundTripper(inner http.RoundTripper, pb *DownloadProgressBar) *RoundTripper {
if inner == nil {
inner = http.DefaultTransport
}

rt := RoundTripper{
inner: inner,
pb: pb,
}

return &rt
}

type rtReadCloser struct {
inner io.ReadCloser
pb *DownloadProgressBar
}

func (r *rtReadCloser) Read(p []byte) (int, error) {
return r.inner.Read(p)
}

func (r *rtReadCloser) Close() error {
err := r.inner.Close()
if err == nil {
r.pb.Wait()
} else {
r.pb.Abort(false)
}

return err
}

func (t *RoundTripper) RoundTrip(req *http.Request) (*http.Response, error) {
if t.pb != nil && req.Body != nil && req.ContentLength >= contentSizeThreshold {
t.pb.Init(req.ContentLength)
req.Body = &rtReadCloser{
inner: t.pb.bar.ProxyReader(req.Body),
pb: t.pb,
}
}
resp, err := t.inner.RoundTrip(req)
if t.pb != nil && resp != nil && resp.Body != nil && resp.ContentLength >= contentSizeThreshold {
t.pb.Init(resp.ContentLength)
resp.Body = &rtReadCloser{
inner: t.pb.bar.ProxyReader(resp.Body),
pb: t.pb,
}
}
return resp, err
}

0 comments on commit 31a605d

Please sign in to comment.