Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: Add waiter for object #777

Merged
merged 2 commits into from
Jul 3, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
77 changes: 77 additions & 0 deletions pkg/wait/wait.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,77 @@
// Copyright 2024 Nutanix. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
package wait

import (
"context"
"fmt"
"time"

apierrors "k8s.io/apimachinery/pkg/api/errors"
"k8s.io/apimachinery/pkg/util/wait"
"sigs.k8s.io/controller-runtime/pkg/client"
)

// CheckFailedError is used to determine whether the wait failed because wraps an error returned by a failed check.
type CheckFailedError struct {
cause error
}

func (e *CheckFailedError) Error() string {
return fmt.Sprintf("check failed: %s", e.cause)
}

func (e *CheckFailedError) Is(target error) bool {
_, ok := target.(*CheckFailedError)
return ok
}

func (e *CheckFailedError) Unwrap() error {
return e.cause
}

type ForObjectInput[T client.Object] struct {
Reader client.Reader
Target T
Check func(ctx context.Context, obj T) (bool, error)
Interval time.Duration
Timeout time.Duration
}

func ForObject[T client.Object](
ctx context.Context,
input ForObjectInput[T],
) error {
key := client.ObjectKeyFromObject(input.Target)

var getErr error
waitErr := wait.PollUntilContextTimeout(
ctx,
input.Interval,
input.Timeout,
true,
func(checkCtx context.Context) (bool, error) {
if getErr = input.Reader.Get(checkCtx, key, input.Target); getErr != nil {
if apierrors.IsNotFound(getErr) {
return false, nil
}
return false, getErr
}

if ok, err := input.Check(checkCtx, input.Target); err != nil {
return false, &CheckFailedError{cause: err}
} else {
// Retry if check fails.
return ok, nil
}
})

if wait.Interrupted(waitErr) {
if getErr != nil {
return fmt.Errorf("%w; last get error: %w", waitErr, getErr)
}
return fmt.Errorf("%w: check never passed", waitErr)
}
// waitErr is a CheckFailedError
return waitErr
}
177 changes: 177 additions & 0 deletions pkg/wait/wait_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,177 @@
// Copyright 2024 Nutanix. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
package wait

import (
"context"
"errors"
"fmt"
"testing"
"time"

corev1 "k8s.io/api/core/v1"
apierrors "k8s.io/apimachinery/pkg/api/errors"
v1 "k8s.io/apimachinery/pkg/apis/meta/v1"
"k8s.io/apimachinery/pkg/util/wait"
"sigs.k8s.io/controller-runtime/pkg/client"
"sigs.k8s.io/controller-runtime/pkg/client/fake"
)

var errBrokenReader = errors.New("broken")

type brokenReader struct{}

func (r *brokenReader) Get(
ctx context.Context,
key client.ObjectKey,
obj client.Object,
opts ...client.GetOption,
) error {
return errBrokenReader
}

func (r *brokenReader) List(
ctx context.Context,
list client.ObjectList,
opts ...client.ListOption,
) error {
return errBrokenReader
}

var _ client.Reader = &brokenReader{}

func TestWait(t *testing.T) {
tests := []struct {
name string
// We use the corev1.Namespace concrete type for the test, because we want to
// verify behavior for a concrete type, and because the Wait function is
// generic, and will behave identically for all concrete types.
input ForObjectInput[*corev1.Namespace]
errCheck func(error) bool
}{
{
name: "time out while get does not find object; report get error",
input: ForObjectInput[*corev1.Namespace]{
Reader: fake.NewFakeClient(),
Check: func(_ context.Context, _ *corev1.Namespace) (bool, error) {
return true, nil
},
Interval: time.Nanosecond,
Timeout: time.Millisecond,
Target: &corev1.Namespace{
TypeMeta: v1.TypeMeta{
Kind: "Namespace",
APIVersion: "v1",
},
ObjectMeta: v1.ObjectMeta{
Name: "example",
},
},
},
errCheck: func(err error) bool {
return wait.Interrupted(err) &&
apierrors.IsNotFound(err)
},
},
{
name: "return immediately when get fails; report get error",
input: ForObjectInput[*corev1.Namespace]{
Reader: &brokenReader{},
Check: func(_ context.Context, _ *corev1.Namespace) (bool, error) {
return true, nil
},
Interval: time.Nanosecond,
Timeout: time.Millisecond,
Target: &corev1.Namespace{
TypeMeta: v1.TypeMeta{
Kind: "Namespace",
APIVersion: "v1",
},
ObjectMeta: v1.ObjectMeta{
Name: "example",
},
},
},
errCheck: func(err error) bool {
return !wait.Interrupted(err) &&
!apierrors.IsNotFound(err) &&
errors.Is(err, errBrokenReader)
},
},
{
name: "time out while check returns false; no check error to report",
input: ForObjectInput[*corev1.Namespace]{
Reader: fake.NewFakeClient(
&corev1.Namespace{
TypeMeta: v1.TypeMeta{
Kind: "Namespace",
APIVersion: "v1",
},
ObjectMeta: v1.ObjectMeta{
Name: "example",
},
},
),
Check: func(_ context.Context, _ *corev1.Namespace) (bool, error) {
return false, nil
},
Interval: time.Nanosecond,
Timeout: time.Millisecond,
Target: &corev1.Namespace{
TypeMeta: v1.TypeMeta{
Kind: "Namespace",
APIVersion: "v1",
},
ObjectMeta: v1.ObjectMeta{
Name: "example",
},
},
},
errCheck: wait.Interrupted,
},
{
name: "return immediately when check returns an error; report the error",
input: ForObjectInput[*corev1.Namespace]{
Reader: fake.NewFakeClient(
&corev1.Namespace{
TypeMeta: v1.TypeMeta{
Kind: "Namespace",
APIVersion: "v1",
},
ObjectMeta: v1.ObjectMeta{
Name: "example",
},
},
),
Check: func(_ context.Context, _ *corev1.Namespace) (bool, error) {
return false, fmt.Errorf("condition failed")
},
Interval: time.Nanosecond,
Timeout: time.Millisecond, Target: &corev1.Namespace{
TypeMeta: v1.TypeMeta{
Kind: "Namespace",
APIVersion: "v1",
},
ObjectMeta: v1.ObjectMeta{
Name: "example",
},
},
},
errCheck: func(err error) bool {
return errors.Is(err, &CheckFailedError{}) &&
!wait.Interrupted(err)
},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
err := ForObject(
context.Background(),
tt.input,
)
if !tt.errCheck(err) {
t.Errorf("error did not pass check: %s", err)
}
})
}
}
Loading