
Commit b4e4c80

EtiennePerot authored and gvisor-bot committed on Dec 10, 2024
Internal change (diffbased).
PiperOrigin-RevId: 704502884
1 parent a55b3b2 commit b4e4c80

File tree

10 files changed: +304 -122 lines changed
 

‎test/kubernetes/benchmarks/httpbench/httpbench.go

Lines changed: 5 additions & 3 deletions
@@ -317,17 +317,19 @@ func getMeasurements(data string, onlyReport []MetricType, wantPercentiles []int
 		return false
 	}
 	var metricValues []benchmetric.MetricValue
-	var totalRequests int
+	totalRequests := 0
+	totalRequestsFound := false
 	for _, line := range strings.Split(data, "\n") {
 		if match := wrk2TotalRequestsRe.FindStringSubmatch(line); match != nil {
 			gotRequests, err := strconv.ParseInt(strings.ReplaceAll(match[1], ",", ""), 10, 64)
 			if err != nil {
 				return 0, nil, fmt.Errorf("failed to parse %q from line %q: %v", match[1], line, err)
 			}
-			if totalRequests != 0 {
+			if totalRequestsFound {
 				return 0, nil, fmt.Errorf("found multiple lines matching 'total requests' regex: %d vs %d (%q)", totalRequests, gotRequests, line)
 			}
 			totalRequests = int(gotRequests)
+			totalRequestsFound = true
 			continue
 		}
 		if match := wrk2LatencyPercentileRE.FindStringSubmatch(line); match != nil {
@@ -375,7 +377,7 @@ func getMeasurements(data string, onlyReport []MetricType, wantPercentiles []int
 			continue
 		}
 	}
-	if totalRequests == 0 {
+	if !totalRequestsFound {
 		return 0, nil, fmt.Errorf("could not find total requests in output: %q", data)
 	}
 	return totalRequests, metricValues, nil

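Note on the httpbench change above: it replaces a zero-value sentinel with an explicit totalRequestsFound flag, so a run that legitimately parses zero total requests is no longer conflated with a run where the "total requests" line never appeared. Below is a minimal standalone sketch of the same pattern; the names and line format are illustrative, not taken from the benchmark code.

package main

import (
	"fmt"
	"strconv"
	"strings"
)

// parseTotal scans output for a "total requests: N" line. It reports
// found=false when no such line exists, even if N would have been 0.
func parseTotal(output string) (total int, found bool, err error) {
	for _, line := range strings.Split(output, "\n") {
		rest, ok := strings.CutPrefix(line, "total requests: ")
		if !ok {
			continue
		}
		if found {
			return 0, false, fmt.Errorf("duplicate 'total requests' line: %q", line)
		}
		total, err = strconv.Atoi(strings.TrimSpace(rest))
		if err != nil {
			return 0, false, fmt.Errorf("bad count in %q: %w", line, err)
		}
		found = true
	}
	return total, found, nil
}

func main() {
	total, found, err := parseTotal("latency: 1ms\ntotal requests: 0\n")
	fmt.Println(total, found, err) // 0 true <nil>: zero is a valid, found result.
}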
‎test/kubernetes/benchmarks/nginx.go

Lines changed: 3 additions & 3 deletions
@@ -48,9 +48,9 @@ var (
 	// The test expects that it contains the files to be served at /local,
 	// and will serve files out of `nginxServingDir`.
 	nginxCommand      = []string{"nginx", "-c", "/etc/nginx/nginx.conf"}
-	nginxDocKibibytes = []int{1, 10, 100, 10240}
-	threads           = []int{1, 8, 64, 1000}
-	targetQPS         = []int{1, 8, 64, httpbench.InfiniteQPS}
+	nginxDocKibibytes = []int{1, 10240}
+	threads           = []int{1, 8, 1000}
+	targetQPS         = []int{1, 64, httpbench.InfiniteQPS}
 	wantPercentiles   = []int{50, 95, 99}
 )

‎test/kubernetes/benchmarks/postgresql.go

Lines changed: 1 addition & 1 deletion
@@ -46,7 +46,7 @@ const (
 )
 
 var (
-	numConnections = []int{1, 2, 6, 16, 32, 64}
+	numConnections = []int{1, 2, 12, 64}
 )
 
 // BenchmarkPostgresPGBench runs a PostgreSQL pgbench test.

‎test/kubernetes/benchmarks/redis.go

Lines changed: 2 additions & 2 deletions
@@ -49,9 +49,9 @@ const (
 )
 
 var (
-	numConnections     = []int{1, 2, 4, 8, 16, 32}
+	numConnections     = []int{1, 4, 32}
 	latencyPercentiles = []int{50, 95, 99}
-	operations         = []string{"SET", "GET", "MSET", "LPUSH", "LRANGE_500"}
+	operations         = []string{"GET", "MSET", "LRANGE_500"}
 )
 
 // BenchmarkRedis runs the Redis performance benchmark using redis-benchmark.

‎test/kubernetes/benchmarks/stablediffusion.go

Lines changed: 2 additions & 2 deletions
@@ -34,7 +34,7 @@ import (
 
 const (
 	// Container image for Stable Diffusion XL.
-	stableDiffusionImage = k8s.ImageRepoPrefix + "gpu/stable-diffusion-xl"
+	stableDiffusionImage = k8s.ImageRepoPrefix + "gpu/stable-diffusion-xl:latest"
 )
 
 // kubernetesPodRunner implements `stablediffusion.ContainerRunner`.
@@ -171,7 +171,7 @@ func RunStableDiffusionXL(ctx context.Context, t *testing.T, k8sCtx k8sctx.Kuber
 			t.Skipf("refiner failed in previous benchmark; skipping benchmark with refiner")
 		}
 	}
-	testCtx, testCancel := context.WithTimeout(ctx, 15*time.Minute)
+	testCtx, testCancel := context.WithTimeout(ctx, 50*time.Minute)
 	defer testCancel()
 	prompt := &stablediffusion.XLPrompt{
 		Query: test.query,

‎test/kubernetes/benchmarks/wordpress.go

Lines changed: 2 additions & 2 deletions
@@ -52,8 +52,8 @@ const (
 )
 
 var (
-	threads         = []int{1, 8, 64, 1000}
-	targetQPS       = []int{1, 8, 64, httpbench.InfiniteQPS}
+	threads         = []int{1, 8, 1000}
+	targetQPS       = []int{1, 64, httpbench.InfiniteQPS}
 	wantPercentiles = []int{50, 95, 99}
 )

‎test/kubernetes/testcluster/BUILD

Lines changed: 2 additions & 0 deletions
@@ -8,6 +8,7 @@ package(
 go_library(
     name = "testcluster",
     srcs = [
+        "client.go",
         "objects.go",
         "testcluster.go",
     ],
@@ -16,6 +17,7 @@ go_library(
     ],
     deps = [
         "//pkg/log",
+        "//pkg/rand",
         "//pkg/sync",
         "//test/kubernetes:test_range_config_go_proto",
         "@io_k8s_api//apps/v1:go_default_library",

‎test/kubernetes/testcluster/client.go

Lines changed: 190 additions & 0 deletions
@@ -0,0 +1,190 @@
+// Copyright 2024 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package testcluster
+
+import (
+	"context"
+	"encoding/hex"
+	"fmt"
+	"io"
+	"time"
+
+	"gvisor.dev/gvisor/pkg/log"
+	"gvisor.dev/gvisor/pkg/rand"
+	"k8s.io/client-go/kubernetes"
+)
+
+// KubernetesReq is a function that performs a request with a Kubernetes
+// client.
+type KubernetesReq func(context.Context, kubernetes.Interface) error
+
+// KubernetesClient is an interface that wraps Kubernetes requests.
+type KubernetesClient interface {
+	// Do performs a request with a Kubernetes client.
+	Do(context.Context, KubernetesReq) error
+}
+
+// simpleClient is a KubernetesClient that wraps a simple Kubernetes client.
+// The `Do` function simply calls the function with the given `client`.
+type simpleClient struct {
+	client kubernetes.Interface
+}
+
+// Do implements `KubernetesClient.Do`.
+func (sc *simpleClient) Do(ctx context.Context, fn KubernetesReq) error {
+	return fn(ctx, sc.client)
+}
+
+// retryableClient is a KubernetesClient that can retry requests by creating
+// *new instances* of Kubernetes clients, rather than just retrying requests.
+type retryableClient struct {
+	// client is a Kubernetes client factory, used to create new instances of
+	// Kubernetes clients and to determine whether a request should be retried.
+	client UnstableClient
+
+	// clientCh is a channel used to share Kubernetes clients between multiple
+	// requests.
+	clientCh chan kubernetes.Interface
+}
+
+// UnstableClient is a Kubernetes client factory that can create new instances
+// of Kubernetes clients and determine whether a request should be retried.
+type UnstableClient interface {
+	// Client creates a new instance of a Kubernetes client.
+	// This function may also block (in a context-respecting manner)
+	// in order to implement backoff between Kubernetes client creation
+	// attempts.
+	Client(context.Context) (kubernetes.Interface, error)
+
+	// RetryError returns whether the given error should be retried.
+	// numAttempt is the number of attempts made so far.
+	// This function may also block (in a context-respecting manner)
+	// in order to implement backoff between request retries.
+	RetryError(ctx context.Context, err error, numAttempt int) bool
+}
+
+// NewRetryableClient creates a new retryable Kubernetes client.
+// It takes an `UnstableClient` as input, which is used to create new
+// instances of Kubernetes clients as needed, and to determine whether
+// a request should be retried.
+// This can be safely used concurrently, in which case additional
+// Kubernetes clients will be created as needed, and reused when
+// possible (but never garbage-collected, unless they start emitting
+// retriable errors).
+// It will immediately create an initial Kubernetes client from the
+// `UnstableClient` as the initial client to use.
+func NewRetryableClient(ctx context.Context, client UnstableClient) (KubernetesClient, error) {
+	initialClient, err := client.Client(ctx)
+	if err != nil {
+		return nil, fmt.Errorf("cannot get initial client: %w", err)
+	}
+	clientCh := make(chan kubernetes.Interface, 128)
+	clientCh <- initialClient
+	return &retryableClient{client: client, clientCh: clientCh}, nil
+}
+
+// getClient returns a Kubernetes client.
+// It will either return the client from the clientCh, or create a new one
+// if none are available.
+func (rc *retryableClient) getClient(ctx context.Context) (kubernetes.Interface, error) {
+	select {
+	case client := <-rc.clientCh:
+		return client, nil
+	default:
+		client, err := rc.client.Client(ctx)
+		if err != nil {
+			return nil, fmt.Errorf("cannot get client: %w", err)
+		}
+		return client, nil
+	}
+}
+
+// putClient puts a Kubernetes client back into the `clientCh`.
+func (rc *retryableClient) putClient(client kubernetes.Interface) {
+	select {
+	case rc.clientCh <- client:
+	default:
+		// If full, just spawn a goroutine to put it back when possible.
+		go func() { rc.clientCh <- client }()
+	}
+}
+
+// Do implements `KubernetesClient.Do`.
+// It retries the request if the error is retryable.
+func (rc *retryableClient) Do(ctx context.Context, fn KubernetesReq) error {
+	client, err := rc.getClient(ctx)
+	if err != nil {
+		return fmt.Errorf("cannot get client: %w", err)
+	}
+	if err = fn(ctx, client); err == nil || !rc.client.RetryError(ctx, err, 0) { // Happy path.
+		rc.putClient(client)
+		return err
+	}
+
+	// We generate a random ID here to distinguish between multiple retriable
+	// operations in the logs.
+	var operationIDBytes [8]byte
+	if _, err := io.ReadFull(rand.Reader, operationIDBytes[:]); err != nil {
+		return fmt.Errorf("cannot read random bytes: %w", err)
+	}
+	operationID := hex.EncodeToString(operationIDBytes[:])
+
+	logger := log.BasicRateLimitedLogger(30 * time.Second)
+	deadline, hasDeadline := ctx.Deadline()
+	if hasDeadline {
+		logger.Infof("Retryable operation [%s] @ %s failed on initial attempt with retryable error (%v); retrying until %v...", operationID, time.Now().Format(time.TimeOnly), err, deadline)
+	} else {
+		logger.Infof("Retryable operation [%s] @ %s failed on initial attempt with retryable error (%v); retrying...", operationID, time.Now().Format(time.TimeOnly), err)
+	}
+	lastErr := err
+	numAttempt := 1
+	for ctx.Err() == nil {
+		numAttempt++
+		client, err := rc.getClient(ctx)
+		if err != nil {
+			return fmt.Errorf("cannot get client: %w", err)
+		}
+		if err = fn(ctx, client); err == nil || !rc.client.RetryError(ctx, err, numAttempt) {
+			// We don't use `logger` here because we want to make sure it is logged
+			// so that the logs reflect that the operation succeeded upon a retry.
+			// Otherwise the logs can be confusing because it may seem that we are
+			// still in the retry loop.
+			if err == nil {
+				log.Infof("Retryable operation [%s] @ %s succeeded on attempt %d.", operationID, time.Now().Format(time.TimeOnly), numAttempt)
+			} else {
+				log.Infof("Retryable operation [%s] @ %s attempt %d returned non-retryable error: %v.", operationID, time.Now().Format(time.TimeOnly), numAttempt, err)
+			}
+			rc.putClient(client)
+			return err
+		}
+		logger.Infof("Retryable operation [%s] @ %s failed on attempt %d (retryable error: %v); will retry again...", operationID, time.Now().Format(time.TimeOnly), numAttempt, err)
+		lastErr = err
+	}
+	log.Infof("Retryable operation [%s] @ %s failed after %d attempts with retryable error (%v) but context was cancelled (%v); bailing out.", operationID, time.Now().Format(time.TimeOnly), numAttempt, lastErr, ctx.Err())
+	return lastErr
+}
+
+// request wraps a function that takes a KubernetesClient and returns a value of
+// type T. It is useful for functions that return more than just an error,
+// e.g. lookup functions that return a pod info or other Kubernetes resources.
+func request[T any](ctx context.Context, client KubernetesClient, fn func(context.Context, kubernetes.Interface) (T, error)) (T, error) {
+	var result T
+	err := client.Do(ctx, func(ctx context.Context, client kubernetes.Interface) error {
+		var err error
+		result, err = fn(ctx, client)
+		return err
+	})
+	return result, err
+}

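The new client.go above defines the UnstableClient factory/retry-policy interface, but this commit does not include a concrete implementation of it. The sketch below is a hypothetical implementation, assuming a client-go rest.Config is available; the backoffClient name and its retry policy are illustrative and not part of the commit.

package testcluster

import (
	"context"
	"errors"
	"net"
	"time"

	"k8s.io/client-go/kubernetes"
	"k8s.io/client-go/rest"
)

// backoffClient is a hypothetical UnstableClient: it rebuilds the clientset
// from a rest.Config on demand and retries transient network errors with a
// capped, linearly growing backoff.
type backoffClient struct {
	config *rest.Config
}

// Client implements UnstableClient.Client by building a fresh clientset.
func (b *backoffClient) Client(ctx context.Context) (kubernetes.Interface, error) {
	return kubernetes.NewForConfig(b.config)
}

// RetryError implements UnstableClient.RetryError. It retries timeouts only,
// sleeping longer after each attempt (capped at 30s) and respecting ctx.
func (b *backoffClient) RetryError(ctx context.Context, err error, numAttempt int) bool {
	var netErr net.Error
	if !errors.As(err, &netErr) || !netErr.Timeout() {
		return false // Not obviously transient; do not retry.
	}
	backoff := time.Duration(numAttempt+1) * time.Second
	if backoff > 30*time.Second {
		backoff = 30 * time.Second
	}
	select {
	case <-ctx.Done():
		return false
	case <-time.After(backoff):
		return true
	}
}

A caller could then wrap it with kc, err := NewRetryableClient(ctx, &backoffClient{config: cfg}) and hand kc to NewTestClusterFromKubernetesClient, which this commit adds in testcluster.go below.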
‎test/kubernetes/testcluster/objects.go

Lines changed: 15 additions & 85 deletions
@@ -19,7 +19,6 @@ import (
 	"errors"
 	"fmt"
 	"reflect"
-	"strconv"
 
 	cspb "google.golang.org/genproto/googleapis/container/v1"
 	"google.golang.org/protobuf/proto"
@@ -181,16 +180,14 @@ type RuntimeType string
 
 // List of known runtime types.
 const (
-	RuntimeTypeGVisor            = RuntimeType("gvisor")
-	RuntimeTypeUnsandboxed       = RuntimeType("runc")
-	RuntimeTypeGVisorNvidia      = RuntimeType("gvisor-nvidia")
-	RuntimeTypeGVisorTPU         = RuntimeType("gvisor-tpu")
-	RuntimeTypeUnsandboxedNvidia = RuntimeType("runc-nvidia")
-	RuntimeTypeUnsandboxedTPU    = RuntimeType("runc-tpu")
+	RuntimeTypeGVisor         = RuntimeType("gvisor")
+	RuntimeTypeUnsandboxed    = RuntimeType("runc")
+	RuntimeTypeGVisorTPU      = RuntimeType("gvisor-tpu")
+	RuntimeTypeUnsandboxedTPU = RuntimeType("runc-tpu")
 )
 
 // ApplyNodepool modifies the nodepool to configure it to use the runtime.
-func (t RuntimeType) ApplyNodepool(nodepool *cspb.NodePool, accelType AcceleratorType, accelShape string, accelRes string) {
+func (t RuntimeType) ApplyNodepool(nodepool *cspb.NodePool) {
 	if nodepool.GetConfig().GetLabels() == nil {
 		nodepool.GetConfig().Labels = map[string]string{}
 	}
@@ -204,81 +201,27 @@ func (t RuntimeType) ApplyNodepool(nodepool *cspb.NodePool, accelType Accelerato
 	case RuntimeTypeUnsandboxed:
 		nodepool.GetConfig().Labels[NodepoolRuntimeKey] = string(RuntimeTypeUnsandboxed)
 		// Do nothing.
-	case RuntimeTypeGVisorNvidia:
-		nodepool.Config.SandboxConfig = &cspb.SandboxConfig{
-			Type: cspb.SandboxConfig_GVISOR,
-		}
-		accelCount, err := strconv.Atoi(accelShape)
-		if err != nil {
-			panic(fmt.Sprintf("GPU count must be a valid number, got %v", accelShape))
-		}
-		if accelCount == 0 {
-			panic("GPU count needs to be >=1")
-		}
-		nodepool.Config.MachineType = DefaultNvidiaMachineType
-		nodepool.Config.Accelerators = []*cspb.AcceleratorConfig{
-			{
-				AcceleratorType:  string(accelType),
-				AcceleratorCount: int64(accelCount),
-			},
-		}
-		nodepool.Config.Labels[NodepoolRuntimeKey] = string(RuntimeTypeGVisorNvidia)
-		nodepool.Config.Labels[NodepoolNumAcceleratorsKey] = strconv.Itoa(accelCount)
 	case RuntimeTypeGVisorTPU:
-		nodepool.Config.MachineType = TPUAcceleratorMachineTypeMap[accelType]
-		if err := setNodePlacementPolicyCompact(nodepool, accelShape); err != nil {
-			panic(fmt.Sprintf("failed to set node placement policy: %v", err))
-		}
 		nodepool.Config.Labels[gvisorNodepoolKey] = gvisorRuntimeClass
 		nodepool.Config.Labels[NodepoolRuntimeKey] = string(RuntimeTypeGVisorTPU)
-		nodepool.Config.Labels[NodepoolTPUTopologyKey] = accelShape
 		nodepool.Config.Taints = append(nodepool.Config.Taints, &cspb.NodeTaint{
 			Key:    gvisorNodepoolKey,
 			Value:  gvisorRuntimeClass,
 			Effect: cspb.NodeTaint_NO_SCHEDULE,
 		})
-	case RuntimeTypeUnsandboxedNvidia:
-		accelCount, err := strconv.Atoi(accelShape)
-		if err != nil {
-			panic(fmt.Sprintf("GPU count must be a valid number, got %v", accelShape))
-		}
-		if accelCount == 0 {
-			panic("GPU count needs to be >=1")
-		}
-		nodepool.Config.MachineType = DefaultNvidiaMachineType
-		nodepool.Config.Accelerators = []*cspb.AcceleratorConfig{
-			{
-				AcceleratorType:  string(accelType),
-				AcceleratorCount: int64(accelCount),
-			},
-		}
-		nodepool.Config.Labels[NodepoolRuntimeKey] = string(RuntimeTypeUnsandboxedNvidia)
-		nodepool.Config.Labels[NodepoolNumAcceleratorsKey] = strconv.Itoa(accelCount)
 	case RuntimeTypeUnsandboxedTPU:
-		nodepool.Config.MachineType = TPUAcceleratorMachineTypeMap[accelType]
-		if err := setNodePlacementPolicyCompact(nodepool, accelShape); err != nil {
-			panic(fmt.Sprintf("failed to set node placement policy: %v", err))
-		}
 		nodepool.Config.Labels[NodepoolRuntimeKey] = string(RuntimeTypeUnsandboxedTPU)
-		nodepool.Config.Labels[NodepoolTPUTopologyKey] = accelShape
 	default:
 		panic(fmt.Sprintf("unsupported runtime %q", t))
 	}
-	if accelRes != "" {
-		nodepool.Config.ReservationAffinity = &cspb.ReservationAffinity{
-			ConsumeReservationType: cspb.ReservationAffinity_SPECIFIC_RESERVATION,
-			Key:                    "compute.googleapis.com/reservation-name",
-			Values:                 []string{accelRes},
-		}
-	}
 }
 
-// setNodePlacementPolicyCompact sets the node placement policy to COMPACT
+// SetNodePlacementPolicyCompact sets the node placement policy to COMPACT
 // and with the given TPU topology.
 // This is done by reflection because the NodePool_PlacementPolicy proto
 // message isn't available in the latest exported version of the genproto API.
 // This is only used for TPU nodepools so not critical for most benchmarks.
-func setNodePlacementPolicyCompact(nodepool *cspb.NodePool, tpuTopology string) error {
+func SetNodePlacementPolicyCompact(nodepool *cspb.NodePool, tpuTopology string) error {
 	placementPolicyField := reflect.ValueOf(nodepool).Elem().FieldByName("PlacementPolicy")
 	if !placementPolicyField.IsValid() {
 		return errors.New("nodepool does not have a PlacementPolicy field")
@@ -305,7 +248,15 @@ func (t RuntimeType) ApplyPodSpec(podSpec *v13.PodSpec) {
 	case RuntimeTypeGVisor:
 		podSpec.RuntimeClassName = proto.String(gvisorRuntimeClass)
 		podSpec.NodeSelector[NodepoolRuntimeKey] = string(RuntimeTypeGVisor)
+		podSpec.Tolerations = append(podSpec.Tolerations, v13.Toleration{
+			Key:      "nvidia.com/gpu",
+			Operator: v13.TolerationOpExists,
+		})
 	case RuntimeTypeUnsandboxed:
+		podSpec.Tolerations = append(podSpec.Tolerations, v13.Toleration{
+			Key:      "nvidia.com/gpu",
+			Operator: v13.TolerationOpExists,
+		})
 		// Allow the pod to schedule on gVisor nodes as well.
 		// This enables the use of `--test-nodepool-runtime=runc` to run
 		// unsandboxed benchmarks on gVisor test clusters.
@@ -315,34 +266,13 @@ func (t RuntimeType) ApplyPodSpec(podSpec *v13.PodSpec) {
 			Operator: v13.TolerationOpEqual,
 			Value:    gvisorRuntimeClass,
 		})
-	case RuntimeTypeGVisorNvidia:
-		podSpec.RuntimeClassName = proto.String(gvisorRuntimeClass)
-		podSpec.NodeSelector[NodepoolRuntimeKey] = string(RuntimeTypeGVisorNvidia)
-		podSpec.Tolerations = append(podSpec.Tolerations, v13.Toleration{
-			Key:      "nvidia.com/gpu",
-			Operator: v13.TolerationOpExists,
-		})
 	case RuntimeTypeGVisorTPU:
 		podSpec.RuntimeClassName = proto.String(gvisorRuntimeClass)
 		podSpec.NodeSelector[NodepoolRuntimeKey] = string(RuntimeTypeGVisorTPU)
 		podSpec.Tolerations = append(podSpec.Tolerations, v13.Toleration{
 			Key:      "google.com/tpu",
 			Operator: v13.TolerationOpExists,
 		})
-	case RuntimeTypeUnsandboxedNvidia:
-		podSpec.Tolerations = append(podSpec.Tolerations, v13.Toleration{
-			Key:      "nvidia.com/gpu",
-			Operator: v13.TolerationOpExists,
-		})
-		// Allow the pod to schedule on gVisor nodes as well.
-		// This enables the use of `--test-nodepool-runtime=runc-nvidia` to run
-		// unsandboxed benchmarks on gVisor test clusters.
-		podSpec.Tolerations = append(podSpec.Tolerations, v13.Toleration{
-			Effect:   v13.TaintEffectNoSchedule,
-			Key:      gvisorNodepoolKey,
-			Operator: v13.TolerationOpEqual,
-			Value:    gvisorRuntimeClass,
-		})
 	case RuntimeTypeUnsandboxedTPU:
 		podSpec.Tolerations = append(podSpec.Tolerations, v13.Toleration{
 			Key:      "google.com/tpu",
 			Operator: v13.TolerationOpExists,

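With the objects.go hunks above, ApplyNodepool no longer takes accelerator arguments, the Nvidia runtime variants are removed, and the TPU placement helper is exported as SetNodePlacementPolicyCompact, so TPU shape setup moves to callers. Those call sites are not part of this diff; the snippet below is a hypothetical caller-side sketch (the configureTPUNodepool helper, its topology parameter, and the import path are assumptions) of how a harness might now combine the two calls.

package example

import (
	"fmt"

	cspb "google.golang.org/genproto/googleapis/container/v1"

	"gvisor.dev/gvisor/test/kubernetes/testcluster"
)

// configureTPUNodepool applies the compact TPU placement policy first, then
// the runtime labels and taints, mirroring what ApplyNodepool previously did
// in a single call.
func configureTPUNodepool(nodepool *cspb.NodePool, topology string) error {
	if err := testcluster.SetNodePlacementPolicyCompact(nodepool, topology); err != nil {
		return fmt.Errorf("setting compact placement policy: %w", err)
	}
	testcluster.RuntimeTypeGVisorTPU.ApplyNodepool(nodepool)
	return nil
}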
‎test/kubernetes/testcluster/testcluster.go

Lines changed: 82 additions & 24 deletions
@@ -140,7 +140,8 @@ const (
 // TestCluster wraps clusters with their individual ClientSets so that helper methods can be called.
 type TestCluster struct {
 	clusterName string
-	client      kubernetes.Interface
+
+	client KubernetesClient
 
 	// testNodepoolRuntimeOverride, if set, overrides the runtime used for pods
 	// running on the test nodepool. If unset, the test nodepool's default
@@ -209,6 +210,12 @@ func NewTestClusterFromProto(ctx context.Context, cluster *testpb.Cluster) (*Tes
 
 // NewTestClusterFromClient returns a new TestCluster client with a given client.
 func NewTestClusterFromClient(clusterName string, client kubernetes.Interface) *TestCluster {
+	return NewTestClusterFromKubernetesClient(clusterName, &simpleClient{client})
+}
+
+// NewTestClusterFromKubernetesClient returns a new TestCluster client with a
+// given KubernetesClient.
+func NewTestClusterFromKubernetesClient(clusterName string, client KubernetesClient) *TestCluster {
 	return &TestCluster{
 		clusterName: clusterName,
 		client:      client,
@@ -248,17 +255,24 @@ func (t *TestCluster) OverrideTestNodepoolRuntime(testRuntime RuntimeType) {
 
 // createNamespace creates a namespace.
 func (t *TestCluster) createNamespace(ctx context.Context, namespace *v13.Namespace) (*v13.Namespace, error) {
-	return t.client.CoreV1().Namespaces().Create(ctx, namespace, v1.CreateOptions{})
+	return request(ctx, t.client, func(ctx context.Context, client kubernetes.Interface) (*v13.Namespace, error) {
+		return client.CoreV1().Namespaces().Create(ctx, namespace, v1.CreateOptions{})
+	})
 }
 
 // getNamespace returns the given namespace in the cluster if it exists.
 func (t *TestCluster) getNamespace(ctx context.Context, namespaceName string) (*v13.Namespace, error) {
-	return t.client.CoreV1().Namespaces().Get(ctx, namespaceName, v1.GetOptions{})
+	return request(ctx, t.client, func(ctx context.Context, client kubernetes.Interface) (*v13.Namespace, error) {
+		return client.CoreV1().Namespaces().Get(ctx, namespaceName, v1.GetOptions{})
+	})
 }
 
 // deleteNamespace is a helper method to delete a namespace.
 func (t *TestCluster) deleteNamespace(ctx context.Context, namespaceName string) error {
-	if err := t.client.CoreV1().Namespaces().Delete(ctx, namespaceName, v1.DeleteOptions{}); err != nil {
+	err := t.client.Do(ctx, func(ctx context.Context, client kubernetes.Interface) error {
+		return client.CoreV1().Namespaces().Delete(ctx, namespaceName, v1.DeleteOptions{})
+	})
+	if err != nil {
 		return err
 	}
 	// Wait for the namespace to disappear or for the context to expire.
@@ -282,7 +296,9 @@ func (t *TestCluster) getNodePool(ctx context.Context, nodepoolType NodePoolType
 	t.nodepoolsMu.Lock()
 	defer t.nodepoolsMu.Unlock()
 	if t.nodepools == nil {
-		nodes, err := t.client.CoreV1().Nodes().List(ctx, v1.ListOptions{})
+		nodes, err := request(ctx, t.client, func(ctx context.Context, client kubernetes.Interface) (*v13.NodeList, error) {
+			return client.CoreV1().Nodes().List(ctx, v1.ListOptions{})
+		})
 		if err != nil {
 			return nil, fmt.Errorf("cannot list nodes: %w", err)
 		}
@@ -363,30 +379,39 @@ func (t *TestCluster) HasGVisorTestRuntime(ctx context.Context) (bool, error) {
 	if err != nil {
 		return false, err
 	}
-	return testNodePool.runtime == RuntimeTypeGVisor || testNodePool.runtime == RuntimeTypeGVisorNvidia, nil
+	return testNodePool.runtime == RuntimeTypeGVisor || testNodePool.runtime == RuntimeTypeGVisorTPU, nil
 }
 
 // CreatePod is a helper to create a pod.
 func (t *TestCluster) CreatePod(ctx context.Context, pod *v13.Pod) (*v13.Pod, error) {
 	if pod.GetObjectMeta().GetNamespace() == "" {
 		pod.SetNamespace(NamespaceDefault)
 	}
-	return t.client.CoreV1().Pods(pod.GetNamespace()).Create(ctx, pod, v1.CreateOptions{})
+	return request(ctx, t.client, func(ctx context.Context, client kubernetes.Interface) (*v13.Pod, error) {
+		return client.CoreV1().Pods(pod.GetNamespace()).Create(ctx, pod, v1.CreateOptions{})
+	})
 }
 
 // GetPod is a helper method to Get a pod's metadata.
 func (t *TestCluster) GetPod(ctx context.Context, pod *v13.Pod) (*v13.Pod, error) {
-	return t.client.CoreV1().Pods(pod.GetNamespace()).Get(ctx, pod.GetName(), v1.GetOptions{})
+	return request(ctx, t.client, func(ctx context.Context, client kubernetes.Interface) (*v13.Pod, error) {
+		return client.CoreV1().Pods(pod.GetNamespace()).Get(ctx, pod.GetName(), v1.GetOptions{})
+	})
 }
 
 // ListPods is a helper method to List pods in a cluster.
 func (t *TestCluster) ListPods(ctx context.Context, namespace string) (*v13.PodList, error) {
-	return t.client.CoreV1().Pods(namespace).List(ctx, v1.ListOptions{})
+	return request(ctx, t.client, func(ctx context.Context, client kubernetes.Interface) (*v13.PodList, error) {
+		return client.CoreV1().Pods(namespace).List(ctx, v1.ListOptions{})
+	})
 }
 
 // DeletePod is a helper method to delete a pod.
 func (t *TestCluster) DeletePod(ctx context.Context, pod *v13.Pod) error {
-	if err := t.client.CoreV1().Pods(pod.GetNamespace()).Delete(ctx, pod.GetName(), v1.DeleteOptions{}); err != nil {
+	err := t.client.Do(ctx, func(ctx context.Context, client kubernetes.Interface) error {
+		return client.CoreV1().Pods(pod.GetNamespace()).Delete(ctx, pod.GetName(), v1.DeleteOptions{})
+	})
+	if err != nil {
 		return err
 	}
 	// Wait for the pod to disappear or for the context to expire.
@@ -406,7 +431,9 @@ func (t *TestCluster) DeletePod(ctx context.Context, pod *v13.Pod) error {
 // GetLogReader gets an io.ReadCloser from which logs can be read. It is the caller's
 // responsibility to close it.
 func (t *TestCluster) GetLogReader(ctx context.Context, pod *v13.Pod, opts v13.PodLogOptions) (io.ReadCloser, error) {
-	return t.client.CoreV1().Pods(pod.GetNamespace()).GetLogs(pod.GetName(), &opts).Stream(ctx)
+	return request(ctx, t.client, func(ctx context.Context, client kubernetes.Interface) (io.ReadCloser, error) {
+		return client.CoreV1().Pods(pod.GetNamespace()).GetLogs(pod.GetName(), &opts).Stream(ctx)
+	})
 }
 
 // ReadPodLogs reads logs from a pod.
@@ -602,22 +629,36 @@ func (t *TestCluster) ContainerDurationSecondsByName(ctx context.Context, pod *v
 
 // CreateService is a helper method to create a service in a cluster.
 func (t *TestCluster) CreateService(ctx context.Context, service *v13.Service) (*v13.Service, error) {
-	return t.client.CoreV1().Services(service.GetNamespace()).Create(ctx, service, v1.CreateOptions{})
+	return request(ctx, t.client, func(ctx context.Context, client kubernetes.Interface) (*v13.Service, error) {
+		return client.CoreV1().Services(service.GetNamespace()).Create(ctx, service, v1.CreateOptions{})
+	})
+}
+
+// GetService is a helper method to get a service in a cluster.
+func (t *TestCluster) GetService(ctx context.Context, service *v13.Service) (*v13.Service, error) {
+	return request(ctx, t.client, func(ctx context.Context, client kubernetes.Interface) (*v13.Service, error) {
+		return client.CoreV1().Services(service.GetNamespace()).Get(ctx, service.GetName(), v1.GetOptions{})
+	})
 }
 
 // ListServices is a helper method to List services in a cluster.
 func (t *TestCluster) ListServices(ctx context.Context, namespace string) (*v13.ServiceList, error) {
-	return t.client.CoreV1().Services(namespace).List(ctx, v1.ListOptions{})
+	return request(ctx, t.client, func(ctx context.Context, client kubernetes.Interface) (*v13.ServiceList, error) {
+		return client.CoreV1().Services(namespace).List(ctx, v1.ListOptions{})
+	})
 }
 
 // DeleteService is a helper to delete a given service.
 func (t *TestCluster) DeleteService(ctx context.Context, service *v13.Service) error {
-	if err := t.client.CoreV1().Services(service.GetNamespace()).Delete(ctx, service.GetName(), v1.DeleteOptions{}); err != nil {
+	err := t.client.Do(ctx, func(ctx context.Context, client kubernetes.Interface) error {
+		return client.CoreV1().Services(service.GetNamespace()).Delete(ctx, service.GetName(), v1.DeleteOptions{})
+	})
+	if err != nil {
 		return err
 	}
 	// Wait for the service to disappear or for the context to expire.
 	for ctx.Err() == nil {
-		if _, err := t.client.CoreV1().Services(service.GetNamespace()).Get(ctx, service.GetName(), v1.GetOptions{}); err != nil {
+		if _, err := t.GetService(ctx, service); err != nil {
 			return nil
 		}
 		select {
@@ -639,7 +680,7 @@ func (t *TestCluster) WaitForServiceReady(ctx context.Context, service *v13.Serv
 		case <-ctx.Done():
 			return fmt.Errorf("context expired waiting for service %q: %w (last: %v)", service.GetName(), ctx.Err(), lastService)
 		case <-pollCh.C:
-			s, err := t.client.CoreV1().Services(service.GetNamespace()).Get(ctx, service.GetName(), v1.GetOptions{})
+			s, err := t.GetService(ctx, service)
 			if err != nil {
 				return fmt.Errorf("cannot look up service %q: %w", service.GetName(), err)
 			}
@@ -662,25 +703,40 @@ func (t *TestCluster) CreatePersistentVolume(ctx context.Context, volume *v13.Pe
 	if volume.GetObjectMeta().GetNamespace() == "" {
 		volume.SetNamespace(NamespaceDefault)
 	}
-	return t.client.CoreV1().PersistentVolumeClaims(volume.GetNamespace()).Create(ctx, volume, v1.CreateOptions{})
+	return request(ctx, t.client, func(ctx context.Context, client kubernetes.Interface) (*v13.PersistentVolumeClaim, error) {
+		return client.CoreV1().PersistentVolumeClaims(volume.GetNamespace()).Create(ctx, volume, v1.CreateOptions{})
+	})
 }
 
 // DeletePersistentVolume deletes a persistent volume.
 func (t *TestCluster) DeletePersistentVolume(ctx context.Context, volume *v13.PersistentVolumeClaim) error {
-	return t.client.CoreV1().PersistentVolumeClaims(volume.GetNamespace()).Delete(ctx, volume.GetName(), v1.DeleteOptions{})
+	return t.client.Do(ctx, func(ctx context.Context, client kubernetes.Interface) error {
+		return client.CoreV1().PersistentVolumeClaims(volume.GetNamespace()).Delete(ctx, volume.GetName(), v1.DeleteOptions{})
+	})
 }
 
 // CreateDaemonset creates a daemonset with default options.
 func (t *TestCluster) CreateDaemonset(ctx context.Context, ds *appsv1.DaemonSet) (*appsv1.DaemonSet, error) {
 	if ds.GetObjectMeta().GetNamespace() == "" {
 		ds.SetNamespace(NamespaceDefault)
 	}
-	return t.client.AppsV1().DaemonSets(ds.GetNamespace()).Create(ctx, ds, v1.CreateOptions{})
+	return request(ctx, t.client, func(ctx context.Context, client kubernetes.Interface) (*appsv1.DaemonSet, error) {
+		return client.AppsV1().DaemonSets(ds.GetNamespace()).Create(ctx, ds, v1.CreateOptions{})
+	})
+}
+
+// GetDaemonset gets a daemonset.
+func (t *TestCluster) GetDaemonset(ctx context.Context, ds *appsv1.DaemonSet) (*appsv1.DaemonSet, error) {
+	return request(ctx, t.client, func(ctx context.Context, client kubernetes.Interface) (*appsv1.DaemonSet, error) {
+		return client.AppsV1().DaemonSets(ds.GetNamespace()).Get(ctx, ds.GetName(), v1.GetOptions{})
+	})
 }
 
 // DeleteDaemonset deletes a daemonset from this cluster.
 func (t *TestCluster) DeleteDaemonset(ctx context.Context, ds *appsv1.DaemonSet) error {
-	return t.client.AppsV1().DaemonSets(ds.GetNamespace()).Delete(ctx, ds.GetName(), v1.DeleteOptions{})
+	return t.client.Do(ctx, func(ctx context.Context, client kubernetes.Interface) error {
+		return client.AppsV1().DaemonSets(ds.GetNamespace()).Delete(ctx, ds.GetName(), v1.DeleteOptions{})
+	})
 }
 
 // GetPodsInDaemonSet returns the list of pods of the given DaemonSet.
@@ -689,7 +745,9 @@ func (t *TestCluster) GetPodsInDaemonSet(ctx context.Context, ds *appsv1.DaemonS
 	if appLabel, found := ds.Spec.Template.Labels[k8sApp]; found {
 		listOptions.LabelSelector = fmt.Sprintf("%s=%s", k8sApp, appLabel)
 	}
-	pods, err := t.client.CoreV1().Pods(ds.ObjectMeta.Namespace).List(ctx, listOptions)
+	pods, err := request(ctx, t.client, func(ctx context.Context, client kubernetes.Interface) (*v13.PodList, error) {
+		return client.CoreV1().Pods(ds.ObjectMeta.Namespace).List(ctx, listOptions)
+	})
 	if err != nil {
 		return nil, err
 	}
@@ -709,7 +767,7 @@ func (t *TestCluster) WaitForDaemonset(ctx context.Context, ds *appsv1.DaemonSet
 	defer pollCh.Stop()
 	// Poll-based loop to wait for the DaemonSet to be ready.
 	for {
-		d, err := t.client.AppsV1().DaemonSets(ds.GetNamespace()).Get(ctx, ds.GetName(), v1.GetOptions{})
+		d, err := t.GetDaemonset(ctx, ds)
 		if err != nil {
 			return fmt.Errorf("failed to get daemonset %q: %v", ds.GetName(), err)
 		}
@@ -778,7 +836,7 @@ func (t *TestCluster) StreamDaemonSetLogs(ctx context.Context, ds *appsv1.Daemon
 		if _, seen := nodesSeen[pod.Spec.NodeName]; seen {
 			continue // Node already seen.
 		}
-		logReader, err := t.client.CoreV1().Pods(pod.GetNamespace()).GetLogs(pod.GetName(), &opts).Stream(ctx)
+		logReader, err := t.GetLogReader(ctx, &pod, opts)
 		if err != nil {
 			// This can happen if the container hasn't run yet, for example
 			// because other init containers that run earlier are still executing.
@@ -813,7 +871,7 @@ Outer:
 			}
 			break Outer
 		case <-timeTicker.C:
-			d, err := t.client.AppsV1().DaemonSets(ds.GetNamespace()).Get(ctx, ds.GetName(), v1.GetOptions{})
+			d, err := t.GetDaemonset(ctx, ds)
 			if err != nil {
 				loopError = fmt.Errorf("failed to get DaemonSet: %v", err)
 				break Outer