Skip to content

Commit

Permalink
Error redesign (#78)
Browse files Browse the repository at this point in the history
Signed-off-by: Kimmo Lehto <klehto@mirantis.com>

Better error design and usable error categories, such as ErrCantConnect which is a clear indicator that Connect() should not be retried. The oddly specific single-use errors are now wrapped under more meaningful top level errors.
  • Loading branch information
kke authored Nov 23, 2022
1 parent 026d077 commit 81e018d
Show file tree
Hide file tree
Showing 27 changed files with 408 additions and 258 deletions.
23 changes: 20 additions & 3 deletions cmd/rigtest/rigtest.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package main

import (
"errors"
"flag"
"fmt"
goos "os"
Expand Down Expand Up @@ -45,6 +46,18 @@ func (h *Host) LoadOS() error {
return nil
}

func retry(fn func() error) error {
var err error
for i := 0; i < 3; i++ {
err = fn()
if err == nil {
return nil
}
time.Sleep(2 * time.Second)
}
return nil
}

func main() {
dh := flag.String("host", "127.0.0.1", "target host [+ :port], can give multiple comma separated")
usr := flag.String("user", "root", "user name")
Expand Down Expand Up @@ -116,9 +129,13 @@ func main() {
}

for _, h := range hosts {
if err := h.Connect(); err != nil {
panic(err)
}
err := retry(func() error {
err := h.Connect()
if errors.Is(err, rig.ErrCantConnect) {
panic(err)
}
return err
})

if err := h.LoadOS(); err != nil {
panic(err)
Expand Down
25 changes: 11 additions & 14 deletions connection.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,19 +3,16 @@
package rig

import (
"errors"
"fmt"
"strings"

"github.com/alessio/shellescape"
"github.com/creasty/defaults"
"github.com/google/shlex"
"github.com/k0sproject/rig/exec"
"github.com/k0sproject/rig/log"
)

// ErrNotConnected is returned when a connection is used when it is not connected
var ErrNotConnected = errors.New("not connected")

type client interface {
Connect() error
Disconnect()
Expand Down Expand Up @@ -126,7 +123,7 @@ func (c *Connection) IsConnected() bool {

func (c *Connection) checkConnected() error {
if !c.IsConnected() {
return fmt.Errorf("%s: %w", c, ErrNotConnected)
return ErrNotConnected
}

return nil
Expand Down Expand Up @@ -159,7 +156,7 @@ func (c Connection) Exec(cmd string, opts ...exec.Option) error {
}

if err := c.client.Exec(cmd, opts...); err != nil {
return fmt.Errorf("client exec: %w", err)
return ErrCommandFailed.Wrapf("client exec: %w", err)
}

return nil
Expand All @@ -180,12 +177,15 @@ func (c Connection) ExecOutput(cmd string, opts ...exec.Option) (string, error)
// Connect to the host and identify the operating system and sudo capability
func (c *Connection) Connect() error {
if c.client == nil {
_ = defaults.Set(c)
if err := defaults.Set(c); err != nil {
return ErrValidationFailed.Wrapf("set defaults: %w", err)
}
}

if err := c.client.Connect(); err != nil {
c.client = nil
return fmt.Errorf("client connect: %w", err)
log.Debugf("%s: failed to connect: %v", c, err)
return ErrNotConnected.Wrapf("client connect: %w", err)
}

if c.OSVersion == nil {
Expand Down Expand Up @@ -262,13 +262,10 @@ func (c *Connection) configureSudo() {
}
}

// ErrNoSudo is returned when the connection does not have sudo capability but it is required
var ErrNoSudo = errors.New("user is not an administrator and passwordless access elevation has not been configured")

// Sudo formats a command string to be run with elevated privileges
func (c Connection) Sudo(cmd string) (string, error) {
if c.sudofunc == nil {
return "", ErrNoSudo
return "", ErrSudoRequired.Wrapf("user is not an administrator and passwordless access elevation has not been configured")
}

return c.sudofunc(cmd), nil
Expand All @@ -295,7 +292,7 @@ func (c Connection) ExecInteractive(cmd string) error {
}

if err := c.client.ExecInteractive(cmd); err != nil {
return fmt.Errorf("client exec interactive: %w", err)
return ErrCommandFailed.Wrapf("client exec interactive: %w", err)
}

return nil
Expand All @@ -317,7 +314,7 @@ func (c Connection) Upload(src, dst string, opts ...exec.Option) error {
}

if err := c.client.Upload(src, dst, opts...); err != nil {
return fmt.Errorf("client upload: %w", err)
return ErrUploadFailed.Wrap(err)
}

return nil
Expand Down
2 changes: 0 additions & 2 deletions connection_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -70,8 +70,6 @@ func TestHostFunctions(t *testing.T) {
require.NoError(t, defaults.Set(&h))
require.Equal(t, "SSH", h.Protocol())
require.Equal(t, "127.0.0.1", h.Address())
_ = h.Connect()
h.Disconnect()
}

func TestOutputWriter(t *testing.T) {
Expand Down
20 changes: 20 additions & 0 deletions errors.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
package rig

import (
"github.com/k0sproject/rig/errstring"
)

var (
ErrOS = errstring.New("local os") // ErrOS is returned when an action fails on local OS
ErrInvalidPath = errstring.New("invalid path") // ErrInvalidPath is returned when a path is invalid
ErrValidationFailed = errstring.New("validation failed") // ErrValidationFailed is returned when a validation fails
ErrSudoRequired = errstring.New("sudo required") // ErrSudoRequired is returned when sudo is required
ErrNotFound = errstring.New("not found") // ErrNotFound is returned when a resource is not found
ErrNotImplemented = errstring.New("not implemented") // ErrNotImplemented is returned when a feature is not implemented
ErrNotSupported = errstring.New("not supported") // ErrNotSupported is returned when a feature is not supported
ErrAuthFailed = errstring.New("authentication failed") // ErrAuthFailed is returned when authentication fails
ErrUploadFailed = errstring.New("upload failed") // ErrUploadFailed is returned when an upload fails
ErrNotConnected = errstring.New("not connected") // ErrNotConnected is returned when a connection is not established
ErrCantConnect = errstring.New("can't connect") // ErrCantConnect is returned when a connection is not established and retrying will fail
ErrCommandFailed = errstring.New("command failed") // ErrCommandFailed is returned when a command fails
)
75 changes: 75 additions & 0 deletions errors_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,75 @@
package rig

import (
"errors"
"testing"

"github.com/stretchr/testify/require"
)

func TestErrorStringer(t *testing.T) {
type testCase struct {
name string
err error
expected string
}

for _, scenario := range []testCase{
{
name: "non-wrapped error",
err: ErrOS,
expected: "local os",
},
{
name: "error wrapped in error",
err: ErrOS.Wrap(ErrInvalidPath),
expected: "local os: invalid path",
},
{
name: "string wrapped error",
err: ErrOS.Wrapf("test"),
expected: "local os: test",
},
{
name: "double wrapped string error",
err: ErrOS.Wrapf("test: %w", ErrInvalidPath),
expected: "local os: test: invalid path",
},
} {
t.Run(scenario.name, func(t *testing.T) {
require.Error(t, scenario.err)
require.Equal(t, scenario.expected, scenario.err.Error())
})
}
}

func TestUnwrap(t *testing.T) {
err := ErrOS.Wrap(ErrInvalidPath)
require.Equal(t, ErrInvalidPath, errors.Unwrap(err))
}

func TestErrorsIs(t *testing.T) {
err := ErrOS.Wrap(ErrInvalidPath.Wrap(ErrNotFound))
require.True(t, errors.Is(err, ErrOS))
require.True(t, errors.Is(err, ErrInvalidPath))
require.True(t, errors.Is(err, ErrNotFound))
require.False(t, errors.Is(err, ErrNotConnected))
}

type testErr struct {
msg string
}

func (t *testErr) Error() string {
return "foo " + t.msg
}

func TestErrorsAs(t *testing.T) {
err := ErrOS.Wrap(ErrInvalidPath.Wrap(&testErr{msg: "test"}))
var cmp *testErr
require.True(t, errors.As(err, &cmp))
require.Equal(t, "local os: invalid path: foo test", err.Error())
if errors.As(err, &cmp) {
require.Equal(t, "foo test", cmp.Error())
}
}
61 changes: 61 additions & 0 deletions errstring/errstring.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,61 @@
// Package errstring defines a simple error struct
package errstring

import (
"fmt"
)

// Error is the base type for rig errors
type Error struct {
msg string
}

// Error implements the error interface
func (e *Error) Error() string {
return e.msg
}

func (e *Error) Unwrap() error {
return nil
}

// New creates a new error
func New(msg string) *Error {
return &Error{msg}
}

// Wrap wraps another error with this error
func (e *Error) Wrap(errB error) error {
return &wrappedError{
errA: e,
errB: errB,
}
}

// Wrapf is a shortcut for Wrap(fmt.Errorf("...", ...))
func (e *Error) Wrapf(msg string, args ...any) error {
return &wrappedError{
errA: e,
errB: fmt.Errorf(msg, args...), //nolint:goerr113
}
}

type wrappedError struct {
errA error
errB error
}

func (e *wrappedError) Error() string {
return e.errA.Error() + ": " + e.errB.Error()
}

func (e *wrappedError) Is(err error) bool {
if err == nil {
return false
}
return e.errA == err //nolint:goerr113
}

func (e *wrappedError) Unwrap() error {
return e.errB
}
8 changes: 8 additions & 0 deletions exec/errors.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
package exec

import "github.com/k0sproject/rig/errstring"

var (
ErrRemote = errstring.New("remote exec error") // ErrRemote is returned when an action fails on remote host
ErrSudo = errstring.New("sudo error") // ErrSudo is returned when wrapping a command with sudo fails
)
2 changes: 1 addition & 1 deletion exec/exec.go
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,7 @@ func (o *Options) Command(cmd string) (string, error) {

out, err := o.host.Sudo(cmd)
if err != nil {
return "", fmt.Errorf("failed to sudo: %w", err)
return "", ErrSudo.Wrap(err)
}
return out, nil
}
Expand Down
6 changes: 3 additions & 3 deletions localhost.go
Original file line number Diff line number Diff line change
Expand Up @@ -123,7 +123,7 @@ func (c *Localhost) Exec(cmd string, opts ...exec.Option) error {
func (c *Localhost) command(cmd string, o *exec.Options) (*osexec.Cmd, error) {
cmd, err := o.Command(cmd)
if err != nil {
return nil, fmt.Errorf("failed to build command: %w", err)
return nil, fmt.Errorf("build command: %w", err)
}

if c.IsWindows() {
Expand All @@ -148,7 +148,7 @@ func (c *Localhost) Upload(src, dst string, opts ...exec.Option) error {

inFile, err := os.Open(src)
if err != nil {
return fmt.Errorf("failed to open source file %s: %w", src, err)
return ErrInvalidPath.Wrapf("failed to open local file %s: %w", src, err)
}
defer inFile.Close()

Expand All @@ -159,7 +159,7 @@ func (c *Localhost) Upload(src, dst string, opts ...exec.Option) error {
defer out.Close()
_, err = io.Copy(out, inFile)
if err != nil {
return fmt.Errorf("failed to copy file %s to %s: %w", src, dst, err)
return fmt.Errorf("failed to copy local file %s to remote %s: %w", src, dst, err)
}
return nil
}
Expand Down
Loading

0 comments on commit 81e018d

Please sign in to comment.