Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix the issue that oras push and pull do not show progress bar #32

Closed
wants to merge 1 commit into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ For older changes see the [archived Singularity change log](https://github.com/a
- Make 'apptainer build' work with signed Docker containers.
- Stopped binding over the default timezone in the container with the host's timezone,
which led to unexpected behavior if the application changed timezones.
- Fixed the issue that ORAS push and pull do not show progress bar.

## v1.3.0 - \[2024-03-12\]

Expand Down
72 changes: 42 additions & 30 deletions internal/pkg/client/oras/oras.go
Original file line number Diff line number Diff line change
Expand Up @@ -29,11 +29,12 @@ import (
"github.com/google/go-containerregistry/pkg/v1/empty"
"github.com/google/go-containerregistry/pkg/v1/layout"
"github.com/google/go-containerregistry/pkg/v1/remote"
"golang.org/x/term"
)

// DownloadImage downloads a SIF image specified by an oci reference to a file using the included credentials
func DownloadImage(ctx context.Context, path, ref string, ociAuth *ocitypes.DockerAuthConfig, noHTTPS bool) error {
im, err := remoteImage(ref, ociAuth, noHTTPS)
func DownloadImage(ctx context.Context, path, ref string, ociAuth *ocitypes.DockerAuthConfig, noHTTPS bool, pb *client.DownloadProgressBar) error {
im, err := remoteImage(ref, ociAuth, noHTTPS, pb)
if err != nil {
return err
}
Expand Down Expand Up @@ -87,26 +88,7 @@ func DownloadImage(ctx context.Context, path, ref string, ociAuth *ocitypes.Dock
}
defer outFile.Close()

// progressbar
in, out := io.Pipe()
mwriter := io.MultiWriter(outFile, out)
wpb := &writerWithProgressBar{
Writer: mwriter,
ProgressBar: &client.DownloadProgressBar{},
}
wpb.Init(layer.Size)

go func() {
_, err := io.Copy(io.Discard, in)
if err != nil {
pb.Abort(true)
in.CloseWithError(err)
}
pb.Wait()
in.Close()
}()

_, err = io.Copy(wpb, blob)
_, err = io.Copy(outFile, blob)
if err != nil {
return err
}
Expand Down Expand Up @@ -146,10 +128,36 @@ func UploadImage(ctx context.Context, path, ref string, ociAuth *ocitypes.Docker
return err
}

authOptn := AuthOptn(ociAuth)
updates := make(chan v1.Update, 1)
go showProgressBar(updates)
return remote.Write(ir, im, authOptn, remote.WithUserAgent(useragent.Value()), remote.WithProgress(updates))
remoteOpts := []remote.Option{AuthOptn(ociAuth), remote.WithUserAgent(useragent.Value())}
if term.IsTerminal(2) {
pb := &client.DownloadProgressBar{}
progChan := make(chan v1.Update, 1)
go func() {
var total int64
soFar := int64(0)
for {
// The following is concurrency-safe because this is the only
// goroutine that's going to be reading progChan updates.
update := <-progChan
if update.Error != nil {
pb.Abort(false)
return
}
if update.Total != total {
pb.Init(update.Total)
total = update.Total
}
pb.IncrBy(int(update.Complete - soFar))
soFar = update.Complete
if soFar >= total {
pb.Wait()
return
}
}
}()
remoteOpts = append(remoteOpts, remote.WithProgress(progChan))
}
return remote.Write(ir, im, remoteOpts...)
}

// ensureSIF checks for a SIF image at filepath and returns an error if it is not, or an error is encountered
Expand All @@ -169,7 +177,7 @@ func ensureSIF(filepath string) error {

// RefHash returns the digest of the SIF layer of the OCI manifest for supplied ref
func RefHash(ctx context.Context, ref string, ociAuth *ocitypes.DockerAuthConfig, noHTTPS bool) (v1.Hash, error) {
im, err := remoteImage(ref, ociAuth, noHTTPS)
im, err := remoteImage(ref, ociAuth, noHTTPS, nil)
if err != nil {
return v1.Hash{}, err
}
Expand Down Expand Up @@ -227,7 +235,7 @@ func sha256sum(r io.Reader) (result string, nBytes int64, err error) {
}

// remoteImage returns a v1.Image for the provided remote ref.
func remoteImage(ref string, ociAuth *ocitypes.DockerAuthConfig, noHTTPS bool) (v1.Image, error) {
func remoteImage(ref string, ociAuth *ocitypes.DockerAuthConfig, noHTTPS bool, pb *client.DownloadProgressBar) (v1.Image, error) {
ref = strings.TrimPrefix(ref, "oras://")
ref = strings.TrimPrefix(ref, "//")

Expand All @@ -240,8 +248,12 @@ func remoteImage(ref string, ociAuth *ocitypes.DockerAuthConfig, noHTTPS bool) (
if err != nil {
return nil, fmt.Errorf("invalid reference %q: %w", ref, err)
}
authOptn := AuthOptn(ociAuth)
im, err := remote.Image(ir, authOptn)
remoteOpts := []remote.Option{AuthOptn(ociAuth)}
if pb != nil {
rt := client.NewRoundTripper(nil, pb)
remoteOpts = append(remoteOpts, remote.WithTransport(rt))
}
im, err := remote.Image(ir, remoteOpts...)
if err != nil {
return nil, err
}
Expand Down
69 changes: 0 additions & 69 deletions internal/pkg/client/oras/progressbar.go

This file was deleted.

10 changes: 8 additions & 2 deletions internal/pkg/client/oras/pull.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,9 +15,11 @@ import (
"os"

"github.com/apptainer/apptainer/internal/pkg/cache"
"github.com/apptainer/apptainer/internal/pkg/client"
"github.com/apptainer/apptainer/internal/pkg/util/fs"
"github.com/apptainer/apptainer/pkg/sylog"
ocitypes "github.com/containers/image/v5/types"
"golang.org/x/term"
)

// pull will pull an oras image into the cache if directTo="", or a specific file if directTo is set.
Expand All @@ -27,9 +29,13 @@ func pull(ctx context.Context, imgCache *cache.Handle, directTo, pullFrom string
return "", fmt.Errorf("failed to get checksum for %s: %s", pullFrom, err)
}

var pb *client.DownloadProgressBar
if term.IsTerminal(2) {
pb = &client.DownloadProgressBar{}
}
if directTo != "" {
sylog.Infof("Downloading oras image")
if err := DownloadImage(ctx, directTo, pullFrom, ociAuth, noHTTPS); err != nil {
if err := DownloadImage(ctx, directTo, pullFrom, ociAuth, noHTTPS, pb); err != nil {
return "", fmt.Errorf("unable to Download Image: %v", err)
}
imagePath = directTo
Expand All @@ -43,7 +49,7 @@ func pull(ctx context.Context, imgCache *cache.Handle, directTo, pullFrom string
if !cacheEntry.Exists {
sylog.Infof("Downloading oras image")

if err := DownloadImage(ctx, cacheEntry.TmpPath, pullFrom, ociAuth, noHTTPS); err != nil {
if err := DownloadImage(ctx, cacheEntry.TmpPath, pullFrom, ociAuth, noHTTPS, pb); err != nil {
return "", fmt.Errorf("unable to Download Image: %v", err)
}
if cacheFileHash, err := ImageHash(cacheEntry.TmpPath); err != nil {
Expand Down
75 changes: 75 additions & 0 deletions internal/pkg/client/progress_roundtrip.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,75 @@
// Copyright (c) Contributors to the Apptainer project, established as
// Apptainer a Series of LF Projects LLC.
// For website terms of use, trademark policy, privacy policy and other
// project policies see https://lfprojects.org/policies

// Copyright (c) 2023, Sylabs Inc. All rights reserved.
// This software is licensed under a 3-clause BSD license. Please consult the
// LICENSE.md file distributed with the sources of this project regarding your
// rights to use or distribute this software.

package client

import (
"io"
"net/http"
)

const contentSizeThreshold = 1024

type RoundTripper struct {
inner http.RoundTripper
pb *DownloadProgressBar
}

func NewRoundTripper(inner http.RoundTripper, pb *DownloadProgressBar) *RoundTripper {
if inner == nil {
inner = http.DefaultTransport
}

rt := RoundTripper{
inner: inner,
pb: pb,
}

return &rt
}

type rtReadCloser struct {
inner io.ReadCloser
pb *DownloadProgressBar
}

func (r *rtReadCloser) Read(p []byte) (int, error) {
return r.inner.Read(p)
}

func (r *rtReadCloser) Close() error {
err := r.inner.Close()
if err == nil {
r.pb.Wait()
} else {
r.pb.Abort(false)
}

return err
}

func (t *RoundTripper) RoundTrip(req *http.Request) (*http.Response, error) {
if t.pb != nil && req.Body != nil && req.ContentLength >= contentSizeThreshold {
t.pb.Init(req.ContentLength)
req.Body = &rtReadCloser{
inner: t.pb.bar.ProxyReader(req.Body),
pb: t.pb,
}
}
resp, err := t.inner.RoundTrip(req)
if t.pb != nil && resp != nil && resp.Body != nil && resp.ContentLength >= contentSizeThreshold {
t.pb.Init(resp.ContentLength)
resp.Body = &rtReadCloser{
inner: t.pb.bar.ProxyReader(resp.Body),
pb: t.pb,
}
}
return resp, err
}
Loading