Skip to content

Commit 6939e9b

Browse files
committed
Adding multi-host indexing
Signed-off-by: Aaron Liang <aaronliang@google.com>
1 parent e92c46b commit 6939e9b

File tree

7 files changed

+235
-31
lines changed

7 files changed

+235
-31
lines changed

ray-operator/controllers/ray/common/pod.go

Lines changed: 16 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@ import (
1717

1818
rayv1 "github.com/ray-project/kuberay/ray-operator/apis/ray/v1"
1919
"github.com/ray-project/kuberay/ray-operator/controllers/ray/utils"
20+
"github.com/ray-project/kuberay/ray-operator/pkg/features"
2021
)
2122

2223
const (
@@ -244,7 +245,7 @@ func getEnableProbesInjection() bool {
244245
}
245246

246247
// DefaultWorkerPodTemplate sets the config values
247-
func DefaultWorkerPodTemplate(ctx context.Context, instance rayv1.RayCluster, workerSpec rayv1.WorkerGroupSpec, podName string, fqdnRayIP string, headPort string) corev1.PodTemplateSpec {
248+
func DefaultWorkerPodTemplate(ctx context.Context, instance rayv1.RayCluster, workerSpec rayv1.WorkerGroupSpec, podName string, fqdnRayIP string, headPort string, replicaGrpName string, numHostIndex int) corev1.PodTemplateSpec {
248249
podTemplate := workerSpec.Template
249250
podTemplate.GenerateName = podName
250251
// Pods created by RayCluster should be restricted to the namespace of the RayCluster.
@@ -315,6 +316,11 @@ func DefaultWorkerPodTemplate(ctx context.Context, instance rayv1.RayCluster, wo
315316
podTemplate.Labels = make(map[string]string)
316317
}
317318
podTemplate.Labels = labelPod(rayv1.WorkerNode, instance.Name, workerSpec.GroupName, workerSpec.Template.ObjectMeta.Labels)
319+
// Add additional labels for RayMultihostIndexing
320+
multihostIndexingEnabled := features.Enabled(features.RayMulithostIndexing) && workerSpec.NumOfHosts > 1
321+
if multihostIndexingEnabled {
322+
podTemplate.Labels = addMultihostIndexingPodLabels(podTemplate.Labels, replicaGrpName, numHostIndex)
323+
}
318324
workerSpec.RayStartParams = setMissingRayStartParams(ctx, workerSpec.RayStartParams, rayv1.WorkerNode, headPort, fqdnRayIP)
319325

320326
initTemplateAnnotations(instance, &podTemplate)
@@ -628,6 +634,15 @@ func labelPod(rayNodeType rayv1.RayNodeType, rayClusterName string, groupName st
628634
return labels
629635
}
630636

637+
// addMultihostIndexingPodLabels returns labels that contain RayMultihostIndexing feature labels
638+
func addMultihostIndexingPodLabels(currentLabels map[string]string, replicaGrpName string, numHostIndex int) map[string]string {
639+
labels := currentLabels
640+
labels[utils.RayWorkerReplicaIndexKey] = replicaGrpName
641+
labels[utils.RayHostIndexKey] = strconv.Itoa(numHostIndex)
642+
643+
return labels
644+
}
645+
631646
func setInitContainerEnvVars(container *corev1.Container, fqdnRayIP string) {
632647
if len(container.Env) == 0 {
633648
container.Env = []corev1.EnvVar{}

ray-operator/controllers/ray/common/pod_test.go

Lines changed: 34 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@ import (
2020

2121
rayv1 "github.com/ray-project/kuberay/ray-operator/apis/ray/v1"
2222
"github.com/ray-project/kuberay/ray-operator/controllers/ray/utils"
23+
"github.com/ray-project/kuberay/ray-operator/pkg/features"
2324
)
2425

2526
var testMemoryLimit = resource.MustParse("1Gi")
@@ -686,7 +687,7 @@ func TestBuildPod(t *testing.T) {
686687
worker := cluster.Spec.WorkerGroupSpecs[0]
687688
podName = cluster.Name + utils.DashSymbol + string(rayv1.WorkerNode) + utils.DashSymbol + worker.GroupName + utils.DashSymbol + utils.FormatInt32(0)
688689
fqdnRayIP := utils.GenerateFQDNServiceName(ctx, *cluster, cluster.Namespace)
689-
podTemplateSpec = DefaultWorkerPodTemplate(ctx, *cluster, worker, podName, fqdnRayIP, "6379")
690+
podTemplateSpec = DefaultWorkerPodTemplate(ctx, *cluster, worker, podName, fqdnRayIP, "6379", "", 0)
690691
pod = BuildPod(ctx, podTemplateSpec, rayv1.WorkerNode, worker.RayStartParams, "6379", false, utils.GetCRDType(""), fqdnRayIP, defaultContainerEnvs)
691692

692693
// Check resources
@@ -760,7 +761,7 @@ func TestBuildPod_WithNoCPULimits(t *testing.T) {
760761
worker := cluster.Spec.WorkerGroupSpecs[0]
761762
podName = cluster.Name + utils.DashSymbol + string(rayv1.WorkerNode) + utils.DashSymbol + worker.GroupName + utils.DashSymbol + utils.FormatInt32(0)
762763
fqdnRayIP := utils.GenerateFQDNServiceName(ctx, *cluster, cluster.Namespace)
763-
podTemplateSpec = DefaultWorkerPodTemplate(ctx, *cluster, worker, podName, fqdnRayIP, "6379")
764+
podTemplateSpec = DefaultWorkerPodTemplate(ctx, *cluster, worker, podName, fqdnRayIP, "6379", "", 0)
764765
pod = BuildPod(ctx, podTemplateSpec, rayv1.WorkerNode, worker.RayStartParams, "6379", false, utils.GetCRDType(""), fqdnRayIP, nil)
765766
expectedCommandArg = splitAndSort("ulimit -n 65536; ray start --block --dashboard-agent-listen-port=52365 --memory=1073741824 --num-cpus=2 --num-gpus=3 --address=raycluster-sample-head-svc.default.svc.cluster.local:6379 --port=6379 --metrics-export-port=8080")
766767
actualCommandArg = splitAndSort(pod.Spec.Containers[0].Args[0])
@@ -791,7 +792,7 @@ func TestBuildPod_WithOverwriteCommand(t *testing.T) {
791792
worker := cluster.Spec.WorkerGroupSpecs[0]
792793
podName = cluster.Name + utils.DashSymbol + string(rayv1.WorkerNode) + utils.DashSymbol + worker.GroupName + utils.DashSymbol + utils.FormatInt32(0)
793794
fqdnRayIP := utils.GenerateFQDNServiceName(ctx, *cluster, cluster.Namespace)
794-
podTemplateSpec = DefaultWorkerPodTemplate(ctx, *cluster, worker, podName, fqdnRayIP, "6379")
795+
podTemplateSpec = DefaultWorkerPodTemplate(ctx, *cluster, worker, podName, fqdnRayIP, "6379", "", 0)
795796
workerPod := BuildPod(ctx, podTemplateSpec, rayv1.WorkerNode, worker.RayStartParams, "6379", false, utils.GetCRDType(""), fqdnRayIP, nil)
796797
workerContainer := workerPod.Spec.Containers[utils.RayContainerIndex]
797798
assert.Equal(t, []string{"I am worker"}, workerContainer.Command)
@@ -846,7 +847,7 @@ func TestBuildPod_WithCreatedByRayService(t *testing.T) {
846847
worker := cluster.Spec.WorkerGroupSpecs[0]
847848
podName = cluster.Name + utils.DashSymbol + string(rayv1.WorkerNode) + utils.DashSymbol + worker.GroupName + utils.DashSymbol + utils.FormatInt32(0)
848849
fqdnRayIP := utils.GenerateFQDNServiceName(ctx, *cluster, cluster.Namespace)
849-
podTemplateSpec = DefaultWorkerPodTemplate(ctx, *cluster, worker, podName, fqdnRayIP, "6379")
850+
podTemplateSpec = DefaultWorkerPodTemplate(ctx, *cluster, worker, podName, fqdnRayIP, "6379", "", 0)
850851
pod = BuildPod(ctx, podTemplateSpec, rayv1.WorkerNode, worker.RayStartParams, "6379", false, utils.RayServiceCRD, fqdnRayIP, nil)
851852

852853
val, ok = pod.Labels[utils.RayClusterServingServiceLabelKey]
@@ -902,7 +903,7 @@ func TestBuildPod_WithLoginBash(t *testing.T) {
902903
worker := cluster.Spec.WorkerGroupSpecs[0]
903904
podName = cluster.Name + utils.DashSymbol + string(rayv1.WorkerNode) + utils.DashSymbol + worker.GroupName + utils.DashSymbol + utils.FormatInt32(0)
904905
fqdnRayIP := utils.GenerateFQDNServiceName(ctx, *cluster, cluster.Namespace)
905-
podTemplateSpec = DefaultWorkerPodTemplate(ctx, *cluster, worker, podName, fqdnRayIP, "6379")
906+
podTemplateSpec = DefaultWorkerPodTemplate(ctx, *cluster, worker, podName, fqdnRayIP, "6379", "", 0)
906907
workerPod := BuildPod(ctx, podTemplateSpec, rayv1.WorkerNode, worker.RayStartParams, "6379", false, utils.RayServiceCRD, fqdnRayIP, nil)
907908

908909
// Verify worker container command
@@ -1165,11 +1166,33 @@ func TestDefaultWorkerPodTemplateWithName(t *testing.T) {
11651166
expectedWorker := *worker.DeepCopy()
11661167

11671168
// Pass a deep copy of worker (*worker.DeepCopy()) to prevent "worker" from updating.
1168-
podTemplateSpec := DefaultWorkerPodTemplate(ctx, *cluster, *worker.DeepCopy(), podName, fqdnRayIP, "6379")
1169+
podTemplateSpec := DefaultWorkerPodTemplate(ctx, *cluster, *worker.DeepCopy(), podName, fqdnRayIP, "6379", "", 0)
11691170
assert.Empty(t, podTemplateSpec.ObjectMeta.Name)
11701171
assert.Equal(t, expectedWorker, worker)
11711172
}
11721173

1174+
func TestDeafultWorkerPodTemplateWithReplicaGrpAndIndex(t *testing.T) {
1175+
ctx := context.Background()
1176+
1177+
cluster := instance.DeepCopy()
1178+
1179+
fqdnRayIP := utils.GenerateFQDNServiceName(ctx, *cluster, cluster.Namespace)
1180+
worker := cluster.Spec.WorkerGroupSpecs[0]
1181+
1182+
features.SetFeatureGateDuringTest(t, features.RayMulithostIndexing, true)
1183+
1184+
worker.Template.ObjectMeta.Name = "ray-worker-test"
1185+
worker.NumOfHosts = 4
1186+
podName := cluster.Name + utils.DashSymbol + string(rayv1.WorkerNode) + utils.DashSymbol + worker.GroupName + utils.DashSymbol + utils.FormatInt32(0)
1187+
groupReplicaName := utils.GenerateRayWorkerReplicaGroupName(worker.GroupName)
1188+
1189+
// Pass a deep copy of worker (*worker.DeepCopy()) to prevent "worker" from updating.
1190+
podTemplateSpec := DefaultWorkerPodTemplate(ctx, *cluster, *worker.DeepCopy(), podName, fqdnRayIP, "6379", groupReplicaName, 2)
1191+
assert.Empty(t, podTemplateSpec.ObjectMeta.Name)
1192+
assert.Equal(t, podTemplateSpec.Labels[utils.RayWorkerReplicaIndexKey], groupReplicaName)
1193+
assert.Equal(t, "2", podTemplateSpec.Labels[utils.RayHostIndexKey])
1194+
}
1195+
11731196
func containerPortExists(ports []corev1.ContainerPort, containerPort int32) error {
11741197
name := utils.MetricsPortName
11751198
for _, port := range ports {
@@ -1212,7 +1235,7 @@ func TestDefaultWorkerPodTemplateWithConfigurablePorts(t *testing.T) {
12121235
worker := cluster.Spec.WorkerGroupSpecs[0]
12131236
podName := cluster.Name + utils.DashSymbol + string(rayv1.WorkerNode) + utils.DashSymbol + worker.GroupName + utils.DashSymbol + utils.FormatInt32(0)
12141237
fqdnRayIP := utils.GenerateFQDNServiceName(ctx, *cluster, cluster.Namespace)
1215-
podTemplateSpec := DefaultWorkerPodTemplate(ctx, *cluster, worker, podName, fqdnRayIP, "6379")
1238+
podTemplateSpec := DefaultWorkerPodTemplate(ctx, *cluster, worker, podName, fqdnRayIP, "6379", "", 0)
12161239
// DefaultWorkerPodTemplate will add the default metrics port if user doesn't specify it.
12171240
// Verify the default metrics port exists.
12181241
require.NoError(t, containerPortExists(podTemplateSpec.Spec.Containers[0].Ports, int32(utils.DefaultMetricsPort)))
@@ -1222,7 +1245,7 @@ func TestDefaultWorkerPodTemplateWithConfigurablePorts(t *testing.T) {
12221245
ContainerPort: customMetricsPort,
12231246
}
12241247
cluster.Spec.WorkerGroupSpecs[0].Template.Spec.Containers[0].Ports = []corev1.ContainerPort{metricsPort}
1225-
podTemplateSpec = DefaultWorkerPodTemplate(ctx, *cluster, worker, podName, fqdnRayIP, "6379")
1248+
podTemplateSpec = DefaultWorkerPodTemplate(ctx, *cluster, worker, podName, fqdnRayIP, "6379", "", 0)
12261249
// Verify the custom metrics port exists.
12271250
require.NoError(t, containerPortExists(podTemplateSpec.Spec.Containers[0].Ports, customMetricsPort))
12281251
}
@@ -1261,7 +1284,7 @@ func TestDefaultWorkerPodTemplate_Autoscaling(t *testing.T) {
12611284

12621285
for name, tc := range tests {
12631286
t.Run(name, func(t *testing.T) {
1264-
podTemplateSpec := DefaultWorkerPodTemplate(ctx, tc.cluster, tc.cluster.Spec.WorkerGroupSpecs[0], podName, fqdnRayIP, "6379")
1287+
podTemplateSpec := DefaultWorkerPodTemplate(ctx, tc.cluster, tc.cluster.Spec.WorkerGroupSpecs[0], podName, fqdnRayIP, "6379", "", 0)
12651288
assert.Equal(t, tc.expectedRestartPolicy, podTemplateSpec.Spec.RestartPolicy)
12661289
})
12671290
}
@@ -1277,7 +1300,7 @@ func TestDefaultInitContainer(t *testing.T) {
12771300
expectedResult := len(cluster.Spec.WorkerGroupSpecs[0].Template.Spec.InitContainers) + 1
12781301

12791302
// Pass a deep copy of worker (*worker.DeepCopy()) to prevent "worker" from updating.
1280-
podTemplateSpec := DefaultWorkerPodTemplate(ctx, *cluster, *worker.DeepCopy(), podName, fqdnRayIP, "6379")
1303+
podTemplateSpec := DefaultWorkerPodTemplate(ctx, *cluster, *worker.DeepCopy(), podName, fqdnRayIP, "6379", "", 0)
12811304
numInitContainers := len(podTemplateSpec.Spec.InitContainers)
12821305
assert.Equal(t, expectedResult, numInitContainers, "A default init container is expected to be added.")
12831306

@@ -1336,7 +1359,7 @@ func TestDefaultInitContainerImagePullPolicy(t *testing.T) {
13361359
// set ray container imagePullPolicy
13371360
worker.Template.Spec.Containers[utils.RayContainerIndex].ImagePullPolicy = tc.imagePullPolicy
13381361

1339-
podTemplateSpec := DefaultWorkerPodTemplate(ctx, *cluster, *worker.DeepCopy(), podName, fqdnRayIP, "6379")
1362+
podTemplateSpec := DefaultWorkerPodTemplate(ctx, *cluster, *worker.DeepCopy(), podName, fqdnRayIP, "6379", "", 0)
13401363

13411364
healthCheckContainer := podTemplateSpec.Spec.InitContainers[len(podTemplateSpec.Spec.InitContainers)-1]
13421365
assert.Equal(t, tc.expectedPullPolicy, healthCheckContainer.ImagePullPolicy, "The ImagePullPolicy of the init container should be the same as the Ray container.")

0 commit comments

Comments
 (0)