diff --git a/pkg/wait/wait.go b/pkg/wait/wait.go new file mode 100644 index 000000000..674862ab7 --- /dev/null +++ b/pkg/wait/wait.go @@ -0,0 +1,77 @@ +// Copyright 2024 Nutanix. All rights reserved. +// SPDX-License-Identifier: Apache-2.0 +package wait + +import ( + "context" + "fmt" + "time" + + apierrors "k8s.io/apimachinery/pkg/api/errors" + "k8s.io/apimachinery/pkg/util/wait" + "sigs.k8s.io/controller-runtime/pkg/client" +) + +// CheckFailedError is used to determine whether the wait failed because wraps an error returned by a failed check. +type CheckFailedError struct { + cause error +} + +func (e *CheckFailedError) Error() string { + return fmt.Sprintf("check failed: %s", e.cause) +} + +func (e *CheckFailedError) Is(target error) bool { + _, ok := target.(*CheckFailedError) + return ok +} + +func (e *CheckFailedError) Unwrap() error { + return e.cause +} + +type ForObjectInput[T client.Object] struct { + Reader client.Reader + Target T + Check func(ctx context.Context, obj T) (bool, error) + Interval time.Duration + Timeout time.Duration +} + +func ForObject[T client.Object]( + ctx context.Context, + input ForObjectInput[T], +) error { + key := client.ObjectKeyFromObject(input.Target) + + var getErr error + waitErr := wait.PollUntilContextTimeout( + ctx, + input.Interval, + input.Timeout, + true, + func(checkCtx context.Context) (bool, error) { + if getErr = input.Reader.Get(checkCtx, key, input.Target); getErr != nil { + if apierrors.IsNotFound(getErr) { + return false, nil + } + return false, getErr + } + + if ok, err := input.Check(checkCtx, input.Target); err != nil { + return false, &CheckFailedError{cause: err} + } else { + // Retry if check fails. + return ok, nil + } + }) + + if wait.Interrupted(waitErr) { + if getErr != nil { + return fmt.Errorf("%w; last get error: %w", waitErr, getErr) + } + return fmt.Errorf("%w: check never passed", waitErr) + } + // waitErr is a CheckFailedError + return waitErr +} diff --git a/pkg/wait/wait_test.go b/pkg/wait/wait_test.go new file mode 100644 index 000000000..e72a9e56a --- /dev/null +++ b/pkg/wait/wait_test.go @@ -0,0 +1,177 @@ +// Copyright 2024 Nutanix. All rights reserved. +// SPDX-License-Identifier: Apache-2.0 +package wait + +import ( + "context" + "errors" + "fmt" + "testing" + "time" + + corev1 "k8s.io/api/core/v1" + apierrors "k8s.io/apimachinery/pkg/api/errors" + v1 "k8s.io/apimachinery/pkg/apis/meta/v1" + "k8s.io/apimachinery/pkg/util/wait" + "sigs.k8s.io/controller-runtime/pkg/client" + "sigs.k8s.io/controller-runtime/pkg/client/fake" +) + +var brokenReaderError = errors.New("broken") + +type brokenReader struct{} + +func (r *brokenReader) Get( + ctx context.Context, + key client.ObjectKey, + obj client.Object, + opts ...client.GetOption, +) error { + return brokenReaderError +} + +func (r *brokenReader) List( + ctx context.Context, + list client.ObjectList, + opts ...client.ListOption, +) error { + return brokenReaderError +} + +var _ client.Reader = &brokenReader{} + +func TestWait(t *testing.T) { + tests := []struct { + name string + // We use the corev1.Namespace concrete type for the test, because we want to + // verify behavior for a concrete type, and because the Wait function is + // generic, and will behave identically for all concrete types. + input ForObjectInput[*corev1.Namespace] + errCheck func(error) bool + }{ + { + name: "time out while get does not find object; report get error", + input: ForObjectInput[*corev1.Namespace]{ + Reader: fake.NewFakeClient(), + Check: func(_ context.Context, _ *corev1.Namespace) (bool, error) { + return true, nil + }, + Interval: time.Nanosecond, + Timeout: time.Millisecond, + Target: &corev1.Namespace{ + TypeMeta: v1.TypeMeta{ + Kind: "Namespace", + APIVersion: "v1", + }, + ObjectMeta: v1.ObjectMeta{ + Name: "example", + }, + }, + }, + errCheck: func(err error) bool { + return wait.Interrupted(err) && + apierrors.IsNotFound(err) + }, + }, + { + name: "return immediately when get fails; report get error", + input: ForObjectInput[*corev1.Namespace]{ + Reader: &brokenReader{}, + Check: func(_ context.Context, _ *corev1.Namespace) (bool, error) { + return true, nil + }, + Interval: time.Nanosecond, + Timeout: time.Millisecond, + Target: &corev1.Namespace{ + TypeMeta: v1.TypeMeta{ + Kind: "Namespace", + APIVersion: "v1", + }, + ObjectMeta: v1.ObjectMeta{ + Name: "example", + }, + }, + }, + errCheck: func(err error) bool { + return !wait.Interrupted(err) && + !apierrors.IsNotFound(err) && + errors.Is(err, brokenReaderError) + }, + }, + { + name: "time out while check returns false; no check error to report", + input: ForObjectInput[*corev1.Namespace]{ + Reader: fake.NewFakeClient( + &corev1.Namespace{ + TypeMeta: v1.TypeMeta{ + Kind: "Namespace", + APIVersion: "v1", + }, + ObjectMeta: v1.ObjectMeta{ + Name: "example", + }, + }, + ), + Check: func(_ context.Context, _ *corev1.Namespace) (bool, error) { + return false, nil + }, + Interval: time.Nanosecond, + Timeout: time.Millisecond, + Target: &corev1.Namespace{ + TypeMeta: v1.TypeMeta{ + Kind: "Namespace", + APIVersion: "v1", + }, + ObjectMeta: v1.ObjectMeta{ + Name: "example", + }, + }, + }, + errCheck: wait.Interrupted, + }, + { + name: "return immediately when check returns an error; report the error", + input: ForObjectInput[*corev1.Namespace]{ + Reader: fake.NewFakeClient( + &corev1.Namespace{ + TypeMeta: v1.TypeMeta{ + Kind: "Namespace", + APIVersion: "v1", + }, + ObjectMeta: v1.ObjectMeta{ + Name: "example", + }, + }, + ), + Check: func(_ context.Context, _ *corev1.Namespace) (bool, error) { + return false, fmt.Errorf("condition failed") + }, + Interval: time.Nanosecond, + Timeout: time.Millisecond, Target: &corev1.Namespace{ + TypeMeta: v1.TypeMeta{ + Kind: "Namespace", + APIVersion: "v1", + }, + ObjectMeta: v1.ObjectMeta{ + Name: "example", + }, + }, + }, + errCheck: func(err error) bool { + return errors.Is(err, &CheckFailedError{}) && + !wait.Interrupted(err) + }, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + err := ForObject( + context.Background(), + tt.input, + ) + if !tt.errCheck(err) { + t.Errorf("error did not pass check: %s", err) + } + }) + } +}