From cb764ad747e6f1b980a837506ceac17d096d9e3f Mon Sep 17 00:00:00 2001 From: Jeev B Date: Sun, 29 Dec 2024 13:04:04 -0800 Subject: [PATCH] Fix custom gpu resource name specification --- Makefile | 4 ++++ .../go/tasks/pluginmachinery/flytek8s/utils.go | 4 +++- .../go/tasks/pluginmachinery/flytek8s/utils_test.go | 13 +++++++++++++ 3 files changed, 20 insertions(+), 1 deletion(-) diff --git a/Makefile b/Makefile index eacc4c69ae..cfa174c55c 100644 --- a/Makefile +++ b/Makefile @@ -18,6 +18,10 @@ cmd/single/dist: export FLYTECONSOLE_VERSION ?= latest cmd/single/dist: script/get_flyteconsole_dist.sh +.PHONY: run +run: cmd/single/dist + POD_NAMESPACE=flyte go run -tags console cmd/main.go start --config ~/.flyte/flyte-single-binary-local.yaml + .PHONY: compile compile: cmd/single/dist go build -tags console -v -o flyte -ldflags=$(LD_FLAGS) ./cmd/ diff --git a/flyteplugins/go/tasks/pluginmachinery/flytek8s/utils.go b/flyteplugins/go/tasks/pluginmachinery/flytek8s/utils.go index fab4f84997..e4a206bbab 100644 --- a/flyteplugins/go/tasks/pluginmachinery/flytek8s/utils.go +++ b/flyteplugins/go/tasks/pluginmachinery/flytek8s/utils.go @@ -7,6 +7,7 @@ import ( "github.com/flyteorg/flyte/flyteidl/gen/pb-go/flyteidl/core" pluginmachinery_core "github.com/flyteorg/flyte/flyteplugins/go/tasks/pluginmachinery/core" + "github.com/flyteorg/flyte/flyteplugins/go/tasks/pluginmachinery/flytek8s/config" ) func ToK8sEnvVar(env []*core.KeyValuePair) []v1.EnvVar { @@ -20,6 +21,7 @@ func ToK8sEnvVar(env []*core.KeyValuePair) []v1.EnvVar { // TODO we should modify the container resources to contain a map of enum values? // Also we should probably create tolerations / taints, but we could do that as a post process func ToK8sResourceList(resources []*core.Resources_ResourceEntry) (v1.ResourceList, error) { + gpuResourceName := config.GetK8sPluginConfig().GpuResourceName k8sResources := make(v1.ResourceList, len(resources)) for _, r := range resources { rVal := r.GetValue() @@ -38,7 +40,7 @@ func ToK8sResourceList(resources []*core.Resources_ResourceEntry) (v1.ResourceLi } case core.Resources_GPU: if !v.IsZero() { - k8sResources[ResourceNvidiaGPU] = v + k8sResources[gpuResourceName] = v } case core.Resources_EPHEMERAL_STORAGE: if !v.IsZero() { diff --git a/flyteplugins/go/tasks/pluginmachinery/flytek8s/utils_test.go b/flyteplugins/go/tasks/pluginmachinery/flytek8s/utils_test.go index 07e519cc80..1576ccd606 100644 --- a/flyteplugins/go/tasks/pluginmachinery/flytek8s/utils_test.go +++ b/flyteplugins/go/tasks/pluginmachinery/flytek8s/utils_test.go @@ -9,6 +9,7 @@ import ( "github.com/flyteorg/flyte/flyteidl/gen/pb-go/flyteidl/core" "github.com/flyteorg/flyte/flyteplugins/go/tasks/pluginmachinery/core/mocks" + "github.com/flyteorg/flyte/flyteplugins/go/tasks/pluginmachinery/flytek8s/config" ) func TestToK8sEnvVar(t *testing.T) { @@ -44,6 +45,18 @@ func TestToK8sResourceList(t *testing.T) { assert.Equal(t, resource.MustParse("1024Mi"), r[v1.ResourceMemory]) assert.Equal(t, resource.MustParse("1024Mi"), r[v1.ResourceEphemeralStorage]) } + { + gpuResourceName := v1.ResourceName("amd.com/gpu") + cfg := config.GetK8sPluginConfig() + cfg.GpuResourceName = gpuResourceName + assert.NoError(t, config.SetK8sPluginConfig(cfg)) + r, err := ToK8sResourceList([]*core.Resources_ResourceEntry{ + {Name: core.Resources_GPU, Value: "1"}, + }) + assert.NoError(t, err) + assert.NotEmpty(t, r) + assert.Equal(t, resource.MustParse("1"), r[gpuResourceName]) + } { r, err := ToK8sResourceList([]*core.Resources_ResourceEntry{}) assert.NoError(t, err)