From a356189e4b8067ed3f98730ae6371e70d794f394 Mon Sep 17 00:00:00 2001 From: David Xia Date: Wed, 5 Feb 2025 14:31:17 -0500 Subject: [PATCH] [feat][kubectl-plugin] add `scale` command to scale a RayCluster's worker group. closes #110 ## Example Usage ```console $ kubectl ray scale cluster -h (base) Scale a Ray cluster's worker group. Usage: ray scale cluster (WORKERGROUP) (-c/--ray-cluster RAYCLUSTER) (-r/--replicas N) [flags] Examples: # Scale a Ray cluster's worker group to 3 replicas kubectl ray scale cluster my-workergroup --ray-cluster my-raycluster --replicas 3 $ kubectl ray scale default-group --ray-cluster NONEXISTENT --replicas 0 Error: failed to scale worker group default-group in Ray cluster NONEXISTENT in namespace default: rayclusters.ray.io "NONEXISTENT" not found $ kubectl ray scale DEADBEEF --ray-cluster dxia-test --replicas 1 Error: worker group DEADBEEF not found in Ray cluster dxia-test in namespace default. Available worker groups: default-group, another-group, yet-another-group $ kubectl ray scale default-group --ray-cluster dxia-test --replicas 3 Scaled worker group default-group in Ray cluster dxia-test in namespace default from 0 to 3 replicas $ kubectl ray scale default-group --ray-cluster dxia-test --replicas 1 Scaled worker group default-group in Ray cluster dxia-test in namespace default from 3 to 1 replicas $ kubectl ray scale default-group --ray-cluster dxia-test --replicas -1 Error: must specify -r/--replicas with a non-negative integer ``` Signed-off-by: David Xia --- kubectl-plugin/pkg/cmd/ray.go | 2 + kubectl-plugin/pkg/cmd/scale/scale.go | 26 ++ kubectl-plugin/pkg/cmd/scale/scale_cluster.go | 144 +++++++++++ .../pkg/cmd/scale/scale_cluster_test.go | 244 ++++++++++++++++++ 4 files changed, 416 insertions(+) create mode 100644 kubectl-plugin/pkg/cmd/scale/scale.go create mode 100644 kubectl-plugin/pkg/cmd/scale/scale_cluster.go create mode 100644 kubectl-plugin/pkg/cmd/scale/scale_cluster_test.go diff --git a/kubectl-plugin/pkg/cmd/ray.go b/kubectl-plugin/pkg/cmd/ray.go index dfae10b3d1e..649d36d7310 100644 --- a/kubectl-plugin/pkg/cmd/ray.go +++ b/kubectl-plugin/pkg/cmd/ray.go @@ -10,6 +10,7 @@ import ( "github.com/ray-project/kuberay/kubectl-plugin/pkg/cmd/get" "github.com/ray-project/kuberay/kubectl-plugin/pkg/cmd/job" "github.com/ray-project/kuberay/kubectl-plugin/pkg/cmd/log" + "github.com/ray-project/kuberay/kubectl-plugin/pkg/cmd/scale" "github.com/ray-project/kuberay/kubectl-plugin/pkg/cmd/session" "github.com/ray-project/kuberay/kubectl-plugin/pkg/cmd/version" ) @@ -35,6 +36,7 @@ func NewRayCommand(streams genericiooptions.IOStreams) *cobra.Command { cmd.AddCommand(version.NewVersionCommand(streams)) cmd.AddCommand(create.NewCreateCommand(streams)) cmd.AddCommand(kubectlraydelete.NewDeleteCommand(streams)) + cmd.AddCommand(scale.NewScaleCommand(streams)) return cmd } diff --git a/kubectl-plugin/pkg/cmd/scale/scale.go b/kubectl-plugin/pkg/cmd/scale/scale.go new file mode 100644 index 00000000000..9c080bc04c2 --- /dev/null +++ b/kubectl-plugin/pkg/cmd/scale/scale.go @@ -0,0 +1,26 @@ +package scale + +import ( + "fmt" + "strings" + + "github.com/spf13/cobra" + "k8s.io/cli-runtime/pkg/genericclioptions" +) + +func NewScaleCommand(streams genericclioptions.IOStreams) *cobra.Command { + cmd := &cobra.Command{ + Use: "scale", + Short: "Scale a Ray resource", + SilenceUsage: true, + Run: func(cmd *cobra.Command, args []string) { + if len(args) > 0 { + fmt.Println(fmt.Errorf("unknown command(s) %q", strings.Join(args, " "))) + } + cmd.HelpFunc()(cmd, args) 
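+			// Unknown arguments are reported above but are not fatal; help is
+			// always printed so the valid subcommands remain visible.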
+ }, + } + + cmd.AddCommand(NewScaleClusterCommand(streams)) + return cmd +} diff --git a/kubectl-plugin/pkg/cmd/scale/scale_cluster.go b/kubectl-plugin/pkg/cmd/scale/scale_cluster.go new file mode 100644 index 00000000000..5dfb9d70c9b --- /dev/null +++ b/kubectl-plugin/pkg/cmd/scale/scale_cluster.go @@ -0,0 +1,144 @@ +package scale + +import ( + "context" + "fmt" + "io" + "os" + "strings" + + "github.com/ray-project/kuberay/kubectl-plugin/pkg/util" + "github.com/ray-project/kuberay/kubectl-plugin/pkg/util/client" + "github.com/spf13/cobra" + "k8s.io/cli-runtime/pkg/genericclioptions" + "k8s.io/kubectl/pkg/util/templates" + + metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" + cmdutil "k8s.io/kubectl/pkg/cmd/util" +) + +type ScaleClusterOptions struct { + configFlags *genericclioptions.ConfigFlags + ioStreams *genericclioptions.IOStreams + replicas *int32 + workerGroup string + cluster string +} + +var ( + scaleLong = templates.LongDesc(` + Scale a Ray cluster's worker group. + `) + + scaleExample = templates.Examples(` + # Scale a Ray cluster's worker group to 3 replicas + kubectl ray scale cluster my-workergroup --ray-cluster my-raycluster --replicas 3 + `) +) + +func NewScaleClusterOptions(streams genericclioptions.IOStreams) *ScaleClusterOptions { + return &ScaleClusterOptions{ + configFlags: genericclioptions.NewConfigFlags(true), + ioStreams: &streams, + replicas: new(int32), + } +} + +func NewScaleClusterCommand(streams genericclioptions.IOStreams) *cobra.Command { + options := NewScaleClusterOptions(streams) + cmdFactory := cmdutil.NewFactory(options.configFlags) + + cmd := &cobra.Command{ + Use: "cluster (WORKERGROUP) (-c/--ray-cluster RAYCLUSTER) (-r/--replicas N)", + Short: "Scale a Ray cluster's worker group", + Long: scaleLong, + Example: scaleExample, + SilenceUsage: true, + Args: cobra.ExactArgs(1), + RunE: func(cmd *cobra.Command, args []string) error { + if err := options.Complete(args); err != nil { + return err + } + if err := options.Validate(); err != nil { + return err + } + k8sClient, err := client.NewClient(cmdFactory) + if err != nil { + return fmt.Errorf("failed to create client: %w", err) + } + return options.Run(cmd.Context(), k8sClient, os.Stdout) + }, + } + + cmd.Flags().StringVarP(&options.cluster, "ray-cluster", "c", "", "Ray cluster of the worker group") + cobra.CheckErr(cmd.MarkFlagRequired("ray-cluster")) + cmd.Flags().Int32VarP(options.replicas, "replicas", "r", -1, "Desired number of replicas in worker group") + options.configFlags.AddFlags(cmd.Flags()) + return cmd +} + +func (options *ScaleClusterOptions) Complete(args []string) error { + if *options.configFlags.Namespace == "" { + *options.configFlags.Namespace = "default" + } + + options.workerGroup = args[0] + + return nil +} + +func (options *ScaleClusterOptions) Validate() error { + config, err := options.configFlags.ToRawKubeConfigLoader().RawConfig() + if err != nil { + return fmt.Errorf("error retrieving raw config: %w", err) + } + + if !util.HasKubectlContext(config, options.configFlags) { + return fmt.Errorf("no context is currently set, use %q or %q to select a new one", "--context", "kubectl config use-context ") + } + + if options.cluster == "" { + return fmt.Errorf("must specify -c/--ray-cluster") + } + + if options.replicas == nil || *options.replicas < 0 { + return fmt.Errorf("must specify -r/--replicas with a non-negative integer") + } + + return nil +} + +func (options *ScaleClusterOptions) Run(ctx context.Context, k8sClient client.Client, writer io.Writer) error { + cluster, err := 
k8sClient.RayClient().RayV1().RayClusters(*options.configFlags.Namespace).Get(ctx, options.cluster, metav1.GetOptions{}) + if err != nil { + return fmt.Errorf("failed to scale worker group %s in Ray cluster %s in namespace %s: %w", options.workerGroup, options.cluster, *options.configFlags.Namespace, err) + } + + // find the index of the worker group + var workerGroups []string + workerGroupIndex := -1 + for i, workerGroup := range cluster.Spec.WorkerGroupSpecs { + workerGroups = append(workerGroups, workerGroup.GroupName) + if workerGroup.GroupName == options.workerGroup { + workerGroupIndex = i + } + } + if workerGroupIndex == -1 { + return fmt.Errorf("worker group %s not found in Ray cluster %s in namespace %s. Available worker groups: %s", options.workerGroup, options.cluster, *options.configFlags.Namespace, strings.Join(workerGroups, ", ")) + } + + previousReplicas := *cluster.Spec.WorkerGroupSpecs[workerGroupIndex].Replicas + if previousReplicas == *options.replicas { + fmt.Fprintf(writer, "worker group %s in Ray cluster %s in namespace %s already has %d replicas. Skipping\n", options.workerGroup, options.cluster, *options.configFlags.Namespace, previousReplicas) + return nil + } + + cluster.Spec.WorkerGroupSpecs[workerGroupIndex].Replicas = options.replicas + _, err = k8sClient.RayClient().RayV1().RayClusters(*options.configFlags.Namespace).Update(ctx, cluster, metav1.UpdateOptions{}) + if err != nil { + return fmt.Errorf("failed to scale worker group %s in Ray cluster %s in namespace %s: %w", options.workerGroup, options.cluster, *options.configFlags.Namespace, err) + } + + fmt.Fprintf(writer, "Scaled worker group %s in Ray cluster %s in namespace %s from %d to %d replicas\n", options.workerGroup, options.cluster, *options.configFlags.Namespace, previousReplicas, *options.replicas) + return nil +} diff --git a/kubectl-plugin/pkg/cmd/scale/scale_cluster_test.go b/kubectl-plugin/pkg/cmd/scale/scale_cluster_test.go new file mode 100644 index 00000000000..6bf8640a18f --- /dev/null +++ b/kubectl-plugin/pkg/cmd/scale/scale_cluster_test.go @@ -0,0 +1,244 @@ +package scale + +import ( + "bytes" + "context" + "fmt" + "testing" + + "github.com/ray-project/kuberay/kubectl-plugin/pkg/util" + "github.com/ray-project/kuberay/kubectl-plugin/pkg/util/client" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" + "k8s.io/apimachinery/pkg/runtime" + "k8s.io/cli-runtime/pkg/genericclioptions" + kubefake "k8s.io/client-go/kubernetes/fake" + "k8s.io/utils/ptr" + + rayv1 "github.com/ray-project/kuberay/ray-operator/apis/ray/v1" + rayClientFake "github.com/ray-project/kuberay/ray-operator/pkg/client/clientset/versioned/fake" +) + +func TestRayScaleClusterComplete(t *testing.T) { + tests := []struct { + name string + namespace string + expectedNamespace string + args []string + }{ + { + name: "namespace should be set to 'default' if not specified", + args: []string{"my-workergroup"}, + expectedNamespace: "default", + }, + { + name: "namespace and worker group should be set correctly", + args: []string{"my-workergroup"}, + namespace: "DEADBEEF", + expectedNamespace: "DEADBEEF", + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + testStreams, _, _, _ := genericclioptions.NewTestIOStreams() + fakeScaleClusterOptions := NewScaleClusterOptions(testStreams) + if tc.namespace != "" { + fakeScaleClusterOptions.configFlags.Namespace = &tc.namespace + } + + err := 
fakeScaleClusterOptions.Complete([]string{"my-workergroup"}) + + require.NoError(t, err) + assert.Equal(t, tc.expectedNamespace, *fakeScaleClusterOptions.configFlags.Namespace) + assert.Equal(t, "my-workergroup", fakeScaleClusterOptions.workerGroup) + }) + } +} + +func TestRayScaleClusterValidate(t *testing.T) { + testStreams, _, _, _ := genericclioptions.NewTestIOStreams() + + kubeConfigWithCurrentContext, err := util.CreateTempKubeConfigFile(t, "test-context") + require.NoError(t, err) + + kubeConfigWithoutCurrentContext, err := util.CreateTempKubeConfigFile(t, "") + require.NoError(t, err) + + tests := []struct { + name string + opts *ScaleClusterOptions + expect string + expectError string + }{ + { + name: "should error when no context is set", + opts: &ScaleClusterOptions{ + configFlags: &genericclioptions.ConfigFlags{ + KubeConfig: &kubeConfigWithoutCurrentContext, + }, + ioStreams: &testStreams, + }, + expectError: "no context is currently set, use \"--context\" or \"kubectl config use-context \" to select a new one", + }, + { + name: "should error when no RayCluster is set", + opts: &ScaleClusterOptions{ + configFlags: &genericclioptions.ConfigFlags{ + KubeConfig: &kubeConfigWithCurrentContext, + }, + ioStreams: &testStreams, + }, + expectError: "must specify -c/--ray-cluster", + }, + { + name: "should error when no replicas are set", + opts: &ScaleClusterOptions{ + configFlags: &genericclioptions.ConfigFlags{ + KubeConfig: &kubeConfigWithCurrentContext, + }, + ioStreams: &testStreams, + cluster: "test-cluster", + }, + expectError: "must specify -r/--replicas with a non-negative integer", + }, + { + name: "should error when replicas is negative", + opts: &ScaleClusterOptions{ + configFlags: &genericclioptions.ConfigFlags{ + KubeConfig: &kubeConfigWithCurrentContext, + }, + ioStreams: &testStreams, + cluster: "test-cluster", + replicas: ptr.To(int32(-1)), + }, + expectError: "must specify -r/--replicas with a non-negative integer", + }, + { + name: "successful validation call", + opts: &ScaleClusterOptions{ + configFlags: &genericclioptions.ConfigFlags{ + KubeConfig: &kubeConfigWithCurrentContext, + }, + ioStreams: &testStreams, + cluster: "test-cluster", + replicas: ptr.To(int32(4)), + }, + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + err := tc.opts.Validate() + if tc.expectError != "" { + assert.EqualError(t, err, tc.expectError) + } else { + assert.NoError(t, err) + } + }) + } +} + +func TestRayScaleClusterRun(t *testing.T) { + testStreams, _, _, _ := genericclioptions.NewTestIOStreams() + + testNamespace, workerGroup, cluster := "test-context", "worker-group-1", "my-cluster" + desiredReplicas := int32(3) + + tests := []struct { + name string + expectedOutput string + expectedError string + rayClusters []runtime.Object + }{ + { + name: "should error when cluster doesn't exist", + rayClusters: []runtime.Object{}, + expectedError: "failed to scale worker group", + }, + { + name: "should error when worker group doesn't exist", + rayClusters: []runtime.Object{ + &rayv1.RayCluster{ + ObjectMeta: metav1.ObjectMeta{ + Name: cluster, + Namespace: testNamespace, + }, + Spec: rayv1.RayClusterSpec{ + WorkerGroupSpecs: []rayv1.WorkerGroupSpec{}, + }, + }, + }, + expectedError: fmt.Sprintf("worker group %s not found", workerGroup), + }, + { + name: "should not do anything when the desired replicas is the same as the current replicas", + rayClusters: []runtime.Object{ + &rayv1.RayCluster{ + ObjectMeta: metav1.ObjectMeta{ + Name: cluster, + Namespace: 
testNamespace, + }, + Spec: rayv1.RayClusterSpec{ + WorkerGroupSpecs: []rayv1.WorkerGroupSpec{ + { + GroupName: workerGroup, + Replicas: &desiredReplicas, + }, + }, + }, + }, + }, + expectedOutput: fmt.Sprintf("already has %d replicas", desiredReplicas), + }, + { + name: "should succeed when arguments are valid", + rayClusters: []runtime.Object{ + &rayv1.RayCluster{ + ObjectMeta: metav1.ObjectMeta{ + Name: cluster, + Namespace: testNamespace, + }, + Spec: rayv1.RayClusterSpec{ + WorkerGroupSpecs: []rayv1.WorkerGroupSpec{ + { + GroupName: workerGroup, + Replicas: ptr.To(int32(1)), + }, + }, + }, + }, + }, + expectedOutput: fmt.Sprintf("Scaled worker group %s", workerGroup), + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + fakeScaleClusterOptions := ScaleClusterOptions{ + configFlags: &genericclioptions.ConfigFlags{ + Namespace: &testNamespace, + }, + ioStreams: &testStreams, + replicas: &desiredReplicas, + workerGroup: workerGroup, + cluster: cluster, + } + + kubeClientSet := kubefake.NewClientset() + rayClient := rayClientFake.NewSimpleClientset(tc.rayClusters...) + k8sClients := client.NewClientForTesting(kubeClientSet, rayClient) + + var buf bytes.Buffer + err := fakeScaleClusterOptions.Run(context.Background(), k8sClients, &buf) + + if tc.expectedError == "" { + require.NoError(t, err) + assert.Contains(t, buf.String(), tc.expectedOutput) + } else { + assert.ErrorContains(t, err, tc.expectedError) + } + }) + } +}
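
A possible follow-up, not included in this patch: the success case in `TestRayScaleClusterRun` currently only checks the command's output string. It could also read the RayCluster back from the fake clientset and verify that the new replica count was actually persisted. A minimal sketch, reusing the identifiers already defined in the test above (`rayClient`, `testNamespace`, `cluster`, `desiredReplicas`):

```go
// Sketch only: assumes placement inside the tc.expectedError == "" branch of
// TestRayScaleClusterRun, after the existing assert.Contains check.
updated, err := rayClient.RayV1().RayClusters(testNamespace).Get(context.Background(), cluster, metav1.GetOptions{})
require.NoError(t, err)
// The success case seeds the worker group with 1 replica and scales it to
// desiredReplicas (3), so the persisted spec should now match.
assert.Equal(t, desiredReplicas, *updated.Spec.WorkerGroupSpecs[0].Replicas)
```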