Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix(cli): avoid panic while updating progress #1597

Merged
merged 2 commits into from
Mar 20, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 4 additions & 1 deletion Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,10 @@ help:
prepare: git-env install-tools go-vendor ## Initialize the go environment

.PHONY: test
test: prepare ## Run all go-sdk tests
test: prepare test-only ## Run all go-sdk tests

.PHONY: test-only
test-only: ## Run all go-sdk tests only (without prepare)
$(eval PACKAGES := $(shell go list ./... | grep -v integration))
gotestsum -f testname --rerun-fails=3 --packages="$(PACKAGES)" \
-- -v -cover -run=$(regex) -coverprofile=$(COVERAGEOUT) $(PACKAGES)
Expand Down
29 changes: 16 additions & 13 deletions cli/cmd/component.go
Original file line number Diff line number Diff line change
Expand Up @@ -432,6 +432,15 @@ func installComponent(args []string) (err error) {
return
}

installedVersion := component.InstalledVersion()
if installedVersion != nil {
return errors.Errorf(
"component %s is already installed. To upgrade run '%s'",
color.HiYellowString(componentName),
color.HiGreenString("lacework component upgrade %s", componentName),
)
}

cli.OutputChecklist(successIcon, fmt.Sprintf("Component %s found\n", component.Name))

cli.StartProgress(fmt.Sprintf("Staging component %s...", componentName))
Expand Down Expand Up @@ -1139,7 +1148,7 @@ func downloadProgress(complete chan int8, path string, sizeB int64) {
spinnerSuffix string = ""
)

if !cli.nonInteractive {
if cli.spinner != nil {
spinnerSuffix = cli.spinner.Suffix
}

Expand All @@ -1163,24 +1172,18 @@ func downloadProgress(complete chan int8, path string, sizeB int64) {
mb := float64(size) / (1 << 20)

if mb > previous {
if !cli.nonInteractive {
cli.spinner.Suffix = fmt.Sprintf("%s Downloaded: %.0fmb", spinnerSuffix, mb)
} else {
cli.OutputHuman("..Downloaded: %.0fmb\n", mb)
}

sizeString := fmt.Sprintf("%.0fmb", mb)
cli.Log.Infow("downloading component", "size", sizeString)
cli.StartProgress(fmt.Sprintf("%s (%s downloaded)", spinnerSuffix, sizeString))
previous = mb
}
} else {
percent := float64(size) / float64(sizeB) * 100

if percent > previous {
if !cli.nonInteractive {
cli.spinner.Suffix = fmt.Sprintf("%s Downloaded: %.0f%s", spinnerSuffix, percent, "%")
} else {
cli.OutputHuman("..Downloaded: %.0f%s\n", percent, "%")
}

percentString := fmt.Sprintf("%.0f%%", percent)
cli.Log.Infow("downloading component", "percent", percentString)
cli.StartProgress(fmt.Sprintf("%s (%s downloaded)", spinnerSuffix, percentString))
previous = percent
}
}
Expand Down
5 changes: 5 additions & 0 deletions integration/component_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -109,6 +109,11 @@ func TestCDKComponentInstall(t *testing.T) {
out = run(t, dir, "component-example")
assert.Contains(t, out, "component")

outBytes, errBytes, exitcode := LaceworkCLIWithHome(dir, "component", "install", "component-example")
assert.NotContains(t, outBytes.String(), "Installation completed.", "STDOUT should be empty")
assert.Contains(t, errBytes.String(), "already installed")
assert.Equal(t, 1, exitcode, "EXITCODE is not the expected one")

cleanup(dir)
}

Expand Down
40 changes: 18 additions & 22 deletions lwcomponent/catalog.go
Original file line number Diff line number Diff line change
Expand Up @@ -73,15 +73,13 @@ func (c *Catalog) GetComponent(name string) (*CDKComponent, error) {
return &component, nil
}

func (c *Catalog) ListComponentVersions(component *CDKComponent) (versions []*semver.Version, err error) {
func (c *Catalog) ListComponentVersions(component *CDKComponent) ([]*semver.Version, error) {
if component.ApiInfo == nil {
err = errors.Errorf("component '%s' api info not available", component.Name)
return
return nil, errors.Errorf("component '%s' api info not available", component.Name)
}

versions = component.ApiInfo.AllVersions
if versions != nil {
return
if component.ApiInfo.AllVersions != nil {
return component.ApiInfo.AllVersions, nil
}

return listComponentVersions(c.client, component.ApiInfo.Id)
Expand Down Expand Up @@ -180,7 +178,7 @@ func (c *Catalog) Stage(
return
}

func (c *Catalog) Verify(component *CDKComponent) (err error) {
func (c *Catalog) Verify(component *CDKComponent) error {
path := filepath.Join(component.stage.Directory(), component.Name)

if operatingSystem == "windows" {
Expand All @@ -189,12 +187,12 @@ func (c *Catalog) Verify(component *CDKComponent) (err error) {

data, err := os.ReadFile(path)
if err != nil {
return
return err
}

sig, err := component.stage.Signature()
if err != nil {
return
return err
}

rootPublicKey := minisign.PublicKey{}
Expand All @@ -205,29 +203,29 @@ func (c *Catalog) Verify(component *CDKComponent) (err error) {
return verifySignature(rootPublicKey, data, sig)
}

func (c *Catalog) Install(component *CDKComponent) (err error) {
func (c *Catalog) Install(component *CDKComponent) error {
if component.stage == nil {
return errors.Errorf("component '%s' not staged", component.Name)
}

componentDir, err := componentDirectory(component.Name)
if err != nil {
return
return err
}

err = os.MkdirAll(componentDir, os.ModePerm)
if err != nil {
return
return err
}

err = component.stage.Commit(componentDir)
if err != nil {
return
return err
}

component.HostInfo, err = NewHostInfo(componentDir, component.Description, component.Type)
if err != nil {
return
return err
}

path := filepath.Join(componentDir, component.Name)
Expand All @@ -243,27 +241,25 @@ func (c *Catalog) Install(component *CDKComponent) (err error) {
}
}

return
return nil
}

// Delete a CDKComponent
//
// Remove the Component install directory and all sub-directory. This function will not return an
// error if the Component is not installed.
func (c *Catalog) Delete(component *CDKComponent) (err error) {
func (c *Catalog) Delete(component *CDKComponent) error {
componentDir, err := componentDirectory(component.Name)
if err != nil {
return
return err
}

_, err = os.Stat(componentDir)
if err != nil {
return errors.Errorf("component not installed. Try running 'lacework component install %s'", component.Name)
}

os.RemoveAll(componentDir)

return
return os.RemoveAll(componentDir)
}

func NewCatalog(
Expand Down Expand Up @@ -412,7 +408,7 @@ func LoadLocalComponents() (components map[string]CDKComponent, err error) {
return
}

func listComponentVersions(client *api.Client, componentId int32) (versions []*semver.Version, err error) {
func listComponentVersions(client *api.Client, componentId int32) ([]*semver.Version, error) {
response, err := client.V2.Components.ListComponentVersions(componentId, operatingSystem, architecture)
if err != nil {
return nil, err
Expand All @@ -424,7 +420,7 @@ func listComponentVersions(client *api.Client, componentId int32) (versions []*s
rawVersions = response.Data[0].Versions
}

versions = make([]*semver.Version, len(rawVersions))
versions := make([]*semver.Version, len(rawVersions))

for idx, v := range rawVersions {
ver, err := semver.NewVersion(v)
Expand Down
18 changes: 8 additions & 10 deletions lwcomponent/cdk_component.go
Original file line number Diff line number Diff line change
Expand Up @@ -124,34 +124,32 @@ func (c *CDKComponent) EnterDevMode() error {
return nil
}

func (c *CDKComponent) InstalledVersion() (version *semver.Version) {
var err error

func (c *CDKComponent) InstalledVersion() *semver.Version {
if c.HostInfo != nil {
version, err = c.HostInfo.Version()
version, err := c.HostInfo.Version()
if err == nil {
return
return version
}

if componentDir, err := c.Dir(); err == nil {
if devInfo, err := newDevInfo(componentDir); err == nil {
version, err = semver.NewVersion(devInfo.Version)
if err == nil {
return
return version
}
}
}
}

return
return nil
}

func (c *CDKComponent) LatestVersion() (version *semver.Version) {
func (c *CDKComponent) LatestVersion() *semver.Version {
if c.ApiInfo != nil {
version = c.ApiInfo.Version
return c.ApiInfo.Version
}

return
return nil
}

func (c *CDKComponent) PrintSummary() []string {
Expand Down
30 changes: 23 additions & 7 deletions lwcomponent/http.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,6 @@
package lwcomponent

import (
"fmt"
"os"
"strconv"
"time"
Expand All @@ -32,12 +31,12 @@ const (
DefaultMaxRetry = 3
)

var log = lwlogger.New("INFO")
var log = lwlogger.New("INFO").Sugar()

// Retry 3 times (4 requests total)
// Resty default RetryWaitTime is 100ms
// Exponential backoff to a maximum of RetryWaitTime of 2s
func DownloadFile(path string, url string) (err error) {
func DownloadFile(path string, url string) error {
client := resty.New()

download_timeout := os.Getenv("CDK_DOWNLOAD_TIMEOUT_MINUTES")
Expand All @@ -52,14 +51,31 @@ func DownloadFile(path string, url string) (err error) {
client.SetRetryCount(DefaultMaxRetry)

client.OnError(func(req *resty.Request, err error) {
fields := []interface{}{
"raw_error", err,
}

if v, ok := err.(*resty.ResponseError); ok {
log.Warn(fmt.Sprintf("Failed to download component: %s: %s", v.Response.Body(), v.Err))

fields = append(fields, "response_body", string(v.Response.Body()))

if v.Response.Request != nil {
trace := v.Response.Request.TraceInfo()
fields = append(fields, "trace_info", trace)
}

if v.Err != nil {
fields = append(fields, "response_error", v.Err.Error())
}
}

log.Warn(fmt.Sprintf("Failed to download component: %s", err.Error()))
log.Warnw("Failed to download component", fields...)
})

_, err = client.R().SetOutput(path).Get(url)
_, err := client.R().
EnableTrace().
SetOutput(path).
Get(url)

return
return err
}
33 changes: 27 additions & 6 deletions lwcomponent/http_test.go
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
package lwcomponent_test
package lwcomponent

import (
"fmt"
Expand All @@ -7,8 +7,10 @@ import (
"os"
"testing"

"github.com/lacework/go-sdk/lwcomponent"
"github.com/stretchr/testify/assert"

capturer "github.com/lacework/go-sdk/internal/capturer"
"github.com/lacework/go-sdk/lwlogger"
)

func TestDownloadFile(t *testing.T) {
Expand All @@ -33,7 +35,7 @@ func TestDownloadFile(t *testing.T) {
})

t.Run("happy path", func(t *testing.T) {
err = lwcomponent.DownloadFile(file.Name(), fmt.Sprintf("%s%s", server.URL, urlPath))
err = DownloadFile(file.Name(), fmt.Sprintf("%s%s", server.URL, urlPath))
assert.Nil(t, err)

buf, err := os.ReadFile(file.Name())
Expand All @@ -53,14 +55,33 @@ func TestDownloadFile(t *testing.T) {
}
})

err = lwcomponent.DownloadFile(file.Name(), fmt.Sprintf("%s%s", server.URL, "/err"))
logsCaptured := capturer.CaptureOutput(func() {
log = lwlogger.New("INFO").Sugar()
err = DownloadFile(file.Name(), fmt.Sprintf("%s%s", server.URL, "/err"))
})
assert.NotNil(t, err)
assert.Equal(t, lwcomponent.DefaultMaxRetry+1, count)
assert.Equal(t, DefaultMaxRetry+1, count)

assert.Contains(t, logsCaptured, "WARN RESTY Get")
assert.Contains(t, logsCaptured, "/err\": EOF")
assert.Contains(t, logsCaptured, "Attempt 4") // the fifth attempt will error
assert.Contains(t, logsCaptured, "ERROR RESTY Get")
assert.Contains(t, logsCaptured, "Failed to download component")
assert.Contains(t, logsCaptured, "trace_info")
})

t.Run("url error", func(t *testing.T) {
err = lwcomponent.DownloadFile(file.Name(), "")
logsCaptured := capturer.CaptureOutput(func() {
log = lwlogger.New("INFO").Sugar()
err = DownloadFile(file.Name(), "")
})
assert.NotNil(t, err)
assert.False(t, os.IsTimeout(err))

assert.Contains(t, logsCaptured, "WARN RESTY Get")
assert.Contains(t, logsCaptured, "Attempt 4") // the fifth attempt will error
assert.Contains(t, logsCaptured, "ERROR RESTY Get")
assert.Contains(t, logsCaptured, "Failed to download component")
assert.Contains(t, logsCaptured, "trace_info")
})
}
Loading
Loading