Skip to content

fix(cli): avoid panic while updating progress #1597

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Mar 20, 2024
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 4 additions & 1 deletion Makefile
Original file line number Diff line number Diff line change
@@ -61,7 +61,10 @@ help:
prepare: git-env install-tools go-vendor ## Initialize the go environment

.PHONY: test
test: prepare ## Run all go-sdk tests
test: prepare test-only ## Run all go-sdk tests

.PHONY: test-only
test-only: ## Run all go-sdk tests only (without prepare)
$(eval PACKAGES := $(shell go list ./... | grep -v integration))
gotestsum -f testname --rerun-fails=3 --packages="$(PACKAGES)" \
-- -v -cover -run=$(regex) -coverprofile=$(COVERAGEOUT) $(PACKAGES)
29 changes: 16 additions & 13 deletions cli/cmd/component.go
Original file line number Diff line number Diff line change
@@ -432,6 +432,15 @@ func installComponent(args []string) (err error) {
return
}

installedVersion := component.InstalledVersion()
if installedVersion != nil {
return errors.Errorf(
"component %s is already installed. To upgrade run '%s'",
color.HiYellowString(componentName),
color.HiGreenString("lacework component upgrade %s", componentName),
)
}

cli.OutputChecklist(successIcon, fmt.Sprintf("Component %s found\n", component.Name))

cli.StartProgress(fmt.Sprintf("Staging component %s...", componentName))
@@ -1139,7 +1148,7 @@ func downloadProgress(complete chan int8, path string, sizeB int64) {
spinnerSuffix string = ""
)

if !cli.nonInteractive {
if cli.spinner != nil {
spinnerSuffix = cli.spinner.Suffix
}

@@ -1163,24 +1172,18 @@ func downloadProgress(complete chan int8, path string, sizeB int64) {
mb := float64(size) / (1 << 20)

if mb > previous {
if !cli.nonInteractive {
cli.spinner.Suffix = fmt.Sprintf("%s Downloaded: %.0fmb", spinnerSuffix, mb)
} else {
cli.OutputHuman("..Downloaded: %.0fmb\n", mb)
}

sizeString := fmt.Sprintf("%.0fmb", mb)
cli.Log.Infow("downloading component", "size", sizeString)
cli.StartProgress(fmt.Sprintf("%s (%s downloaded)", spinnerSuffix, sizeString))
previous = mb
}
} else {
percent := float64(size) / float64(sizeB) * 100

if percent > previous {
if !cli.nonInteractive {
cli.spinner.Suffix = fmt.Sprintf("%s Downloaded: %.0f%s", spinnerSuffix, percent, "%")
} else {
cli.OutputHuman("..Downloaded: %.0f%s\n", percent, "%")
}

percentString := fmt.Sprintf("%.0f%%", percent)
cli.Log.Infow("downloading component", "percent", percentString)
cli.StartProgress(fmt.Sprintf("%s (%s downloaded)", spinnerSuffix, percentString))
previous = percent
}
}
5 changes: 5 additions & 0 deletions integration/component_test.go
Original file line number Diff line number Diff line change
@@ -109,6 +109,11 @@ func TestCDKComponentInstall(t *testing.T) {
out = run(t, dir, "component-example")
assert.Contains(t, out, "component")

outBytes, errBytes, exitcode := LaceworkCLIWithHome(dir, "component", "install", "component-example")
assert.NotContains(t, outBytes.String(), "Installation completed.", "STDOUT should be empty")
assert.Contains(t, errBytes.String(), "already installed")
assert.Equal(t, 1, exitcode, "EXITCODE is not the expected one")

cleanup(dir)
}

40 changes: 18 additions & 22 deletions lwcomponent/catalog.go
Original file line number Diff line number Diff line change
@@ -73,15 +73,13 @@ func (c *Catalog) GetComponent(name string) (*CDKComponent, error) {
return &component, nil
}

func (c *Catalog) ListComponentVersions(component *CDKComponent) (versions []*semver.Version, err error) {
func (c *Catalog) ListComponentVersions(component *CDKComponent) ([]*semver.Version, error) {
if component.ApiInfo == nil {
err = errors.Errorf("component '%s' api info not available", component.Name)
return
return nil, errors.Errorf("component '%s' api info not available", component.Name)
}

versions = component.ApiInfo.AllVersions
if versions != nil {
return
if component.ApiInfo.AllVersions != nil {
return component.ApiInfo.AllVersions, nil
}

return listComponentVersions(c.client, component.ApiInfo.Id)
@@ -180,7 +178,7 @@ func (c *Catalog) Stage(
return
}

func (c *Catalog) Verify(component *CDKComponent) (err error) {
func (c *Catalog) Verify(component *CDKComponent) error {
path := filepath.Join(component.stage.Directory(), component.Name)

if operatingSystem == "windows" {
@@ -189,12 +187,12 @@ func (c *Catalog) Verify(component *CDKComponent) (err error) {

data, err := os.ReadFile(path)
if err != nil {
return
return err
}

sig, err := component.stage.Signature()
if err != nil {
return
return err
}

rootPublicKey := minisign.PublicKey{}
@@ -205,29 +203,29 @@ func (c *Catalog) Verify(component *CDKComponent) (err error) {
return verifySignature(rootPublicKey, data, sig)
}

func (c *Catalog) Install(component *CDKComponent) (err error) {
func (c *Catalog) Install(component *CDKComponent) error {
if component.stage == nil {
return errors.Errorf("component '%s' not staged", component.Name)
}

componentDir, err := componentDirectory(component.Name)
if err != nil {
return
return err
}

err = os.MkdirAll(componentDir, os.ModePerm)
if err != nil {
return
return err
}

err = component.stage.Commit(componentDir)
if err != nil {
return
return err
}

component.HostInfo, err = NewHostInfo(componentDir, component.Description, component.Type)
if err != nil {
return
return err
}

path := filepath.Join(componentDir, component.Name)
@@ -243,27 +241,25 @@ func (c *Catalog) Install(component *CDKComponent) (err error) {
}
}

return
return nil
}

// Delete a CDKComponent
//
// Remove the Component install directory and all sub-directory. This function will not return an
// error if the Component is not installed.
func (c *Catalog) Delete(component *CDKComponent) (err error) {
func (c *Catalog) Delete(component *CDKComponent) error {
componentDir, err := componentDirectory(component.Name)
if err != nil {
return
return err
}

_, err = os.Stat(componentDir)
if err != nil {
return errors.Errorf("component not installed. Try running 'lacework component install %s'", component.Name)
}

os.RemoveAll(componentDir)

return
return os.RemoveAll(componentDir)
}

func NewCatalog(
@@ -412,7 +408,7 @@ func LoadLocalComponents() (components map[string]CDKComponent, err error) {
return
}

func listComponentVersions(client *api.Client, componentId int32) (versions []*semver.Version, err error) {
func listComponentVersions(client *api.Client, componentId int32) ([]*semver.Version, error) {
response, err := client.V2.Components.ListComponentVersions(componentId, operatingSystem, architecture)
if err != nil {
return nil, err
@@ -424,7 +420,7 @@ func listComponentVersions(client *api.Client, componentId int32) (versions []*s
rawVersions = response.Data[0].Versions
}

versions = make([]*semver.Version, len(rawVersions))
versions := make([]*semver.Version, len(rawVersions))

for idx, v := range rawVersions {
ver, err := semver.NewVersion(v)
18 changes: 8 additions & 10 deletions lwcomponent/cdk_component.go
Original file line number Diff line number Diff line change
@@ -124,34 +124,32 @@ func (c *CDKComponent) EnterDevMode() error {
return nil
}

func (c *CDKComponent) InstalledVersion() (version *semver.Version) {
var err error

func (c *CDKComponent) InstalledVersion() *semver.Version {
if c.HostInfo != nil {
version, err = c.HostInfo.Version()
version, err := c.HostInfo.Version()
if err == nil {
return
return version
}

if componentDir, err := c.Dir(); err == nil {
if devInfo, err := newDevInfo(componentDir); err == nil {
version, err = semver.NewVersion(devInfo.Version)
if err == nil {
return
return version
}
}
}
}

return
return nil
}

func (c *CDKComponent) LatestVersion() (version *semver.Version) {
func (c *CDKComponent) LatestVersion() *semver.Version {
if c.ApiInfo != nil {
version = c.ApiInfo.Version
return c.ApiInfo.Version
}

return
return nil
}

func (c *CDKComponent) PrintSummary() []string {
30 changes: 23 additions & 7 deletions lwcomponent/http.go
Original file line number Diff line number Diff line change
@@ -19,7 +19,6 @@
package lwcomponent

import (
"fmt"
"os"
"strconv"
"time"
@@ -32,12 +31,12 @@ const (
DefaultMaxRetry = 3
)

var log = lwlogger.New("INFO")
var log = lwlogger.New("INFO").Sugar()

// Retry 3 times (4 requests total)
// Resty default RetryWaitTime is 100ms
// Exponential backoff to a maximum of RetryWaitTime of 2s
func DownloadFile(path string, url string) (err error) {
func DownloadFile(path string, url string) error {
client := resty.New()

download_timeout := os.Getenv("CDK_DOWNLOAD_TIMEOUT_MINUTES")
@@ -52,14 +51,31 @@ func DownloadFile(path string, url string) (err error) {
client.SetRetryCount(DefaultMaxRetry)

client.OnError(func(req *resty.Request, err error) {
fields := []interface{}{
"raw_error", err,
}

if v, ok := err.(*resty.ResponseError); ok {
log.Warn(fmt.Sprintf("Failed to download component: %s: %s", v.Response.Body(), v.Err))

fields = append(fields, "response_body", string(v.Response.Body()))

if v.Response.Request != nil {
trace := v.Response.Request.TraceInfo()
fields = append(fields, "trace_info", trace)
}

if v.Err != nil {
fields = append(fields, "response_error", v.Err.Error())
}
}

log.Warn(fmt.Sprintf("Failed to download component: %s", err.Error()))
log.Warnw("Failed to download component", fields...)
})

_, err = client.R().SetOutput(path).Get(url)
_, err := client.R().
EnableTrace().
SetOutput(path).
Get(url)

return
return err
}
33 changes: 27 additions & 6 deletions lwcomponent/http_test.go
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
package lwcomponent_test
package lwcomponent

import (
"fmt"
@@ -7,8 +7,10 @@ import (
"os"
"testing"

"github.com/lacework/go-sdk/lwcomponent"
"github.com/stretchr/testify/assert"

capturer "github.com/lacework/go-sdk/internal/capturer"
"github.com/lacework/go-sdk/lwlogger"
)

func TestDownloadFile(t *testing.T) {
@@ -33,7 +35,7 @@ func TestDownloadFile(t *testing.T) {
})

t.Run("happy path", func(t *testing.T) {
err = lwcomponent.DownloadFile(file.Name(), fmt.Sprintf("%s%s", server.URL, urlPath))
err = DownloadFile(file.Name(), fmt.Sprintf("%s%s", server.URL, urlPath))
assert.Nil(t, err)

buf, err := os.ReadFile(file.Name())
@@ -53,14 +55,33 @@ func TestDownloadFile(t *testing.T) {
}
})

err = lwcomponent.DownloadFile(file.Name(), fmt.Sprintf("%s%s", server.URL, "/err"))
logsCaptured := capturer.CaptureOutput(func() {
log = lwlogger.New("INFO").Sugar()
err = DownloadFile(file.Name(), fmt.Sprintf("%s%s", server.URL, "/err"))
})
assert.NotNil(t, err)
assert.Equal(t, lwcomponent.DefaultMaxRetry+1, count)
assert.Equal(t, DefaultMaxRetry+1, count)

assert.Contains(t, logsCaptured, "WARN RESTY Get")
assert.Contains(t, logsCaptured, "/err\": EOF")
assert.Contains(t, logsCaptured, "Attempt 4") // the fifth attempt will error
assert.Contains(t, logsCaptured, "ERROR RESTY Get")
assert.Contains(t, logsCaptured, "Failed to download component")
assert.Contains(t, logsCaptured, "trace_info")
})

t.Run("url error", func(t *testing.T) {
err = lwcomponent.DownloadFile(file.Name(), "")
logsCaptured := capturer.CaptureOutput(func() {
log = lwlogger.New("INFO").Sugar()
err = DownloadFile(file.Name(), "")
})
assert.NotNil(t, err)
assert.False(t, os.IsTimeout(err))

assert.Contains(t, logsCaptured, "WARN RESTY Get")
assert.Contains(t, logsCaptured, "Attempt 4") // the fifth attempt will error
assert.Contains(t, logsCaptured, "ERROR RESTY Get")
assert.Contains(t, logsCaptured, "Failed to download component")
assert.Contains(t, logsCaptured, "trace_info")
})
}
49 changes: 19 additions & 30 deletions lwcomponent/staging.go
Original file line number Diff line number Diff line change
@@ -93,53 +93,42 @@ func (s *stageTarGz) Filename() string {
return filepath.Base(s.artifactUrl.Path)
}

func (s *stageTarGz) Download(progressClosure func(filepath string, sizeB int64)) (err error) {
func (s *stageTarGz) Download(progressClosure func(filepath string, sizeB int64)) error {
fileName := filepath.Base(s.artifactUrl.Path)

path := filepath.Join(s.dir, fileName)

_, err = os.Create(path)
if err != nil {
return
if _, err := os.Create(path); err != nil {
return err
}

go progressClosure(path, s.size*1024)

err = DownloadFile(path, s.artifactUrl.String())
if err != nil {
return
}

return
return DownloadFile(path, s.artifactUrl.String())
}

func (s *stageTarGz) Signature() (sig []byte, err error) {
_, err = os.Stat(s.dir)
func (s *stageTarGz) Signature() ([]byte, error) {
_, err := os.Stat(s.dir)
if os.IsNotExist(err) {
err = errors.New("component not staged")
return
return nil, errors.New("component not staged")
}

path := filepath.Join(s.dir, SignatureFile)
if !file.FileExists(path) {
err = errors.New("missing .signature file")
return
return nil, errors.New("missing .signature file")
}

sig, err = os.ReadFile(path)
sig, err := os.ReadFile(path)
if err != nil {
return
return sig, err
}

// Artifact signature may or may not be b64encoded
decoded_sig, err := base64.StdEncoding.DecodeString(string(sig))
if err == nil {
sig = decoded_sig
return decoded_sig, nil
}

err = nil

return
return sig, nil
}

func (s *stageTarGz) Unpack() (err error) {
@@ -164,10 +153,10 @@ func (s *stageTarGz) Unpack() (err error) {
return nil
}

func (s *stageTarGz) Validate() (err error) {
func (s *stageTarGz) Validate() error {
data, err := os.ReadFile(filepath.Join(s.dir, VersionFile))
if err != nil {
return
return err
}

version := string(data)
@@ -191,7 +180,7 @@ func (s *stageTarGz) Validate() (err error) {
return errors.Errorf("missing file '%s'", path)
}

return
return nil
}

// Inflate GZip file.
@@ -221,10 +210,10 @@ func gunzip(source string, target string) (err error) {
return
}

func unTar(tarball string, dir string) (err error) {
func unTar(tarball string, dir string) error {
reader, err := os.Open(tarball)
if err != nil {
return
return err
}
defer reader.Close()

@@ -242,7 +231,7 @@ func unTar(tarball string, dir string) (err error) {

info := header.FileInfo()
if info.IsDir() {
if err = os.MkdirAll(path, info.Mode()); err != nil {
if err := os.MkdirAll(path, info.Mode()); err != nil {
return err
}
continue
@@ -260,5 +249,5 @@ func unTar(tarball string, dir string) (err error) {
}
}

return
return nil
}