Skip to content

Commit

Permalink
Merge pull request #185 from klueska/rename-to-gpu-plugin
Browse files Browse the repository at this point in the history
Add support to selectively decide which device classes to support via helm
  • Loading branch information
klueska authored Oct 21, 2024
2 parents 215a49a + e8e3e43 commit 322a3a7
Show file tree
Hide file tree
Showing 18 changed files with 274 additions and 64 deletions.
18 changes: 15 additions & 3 deletions cmd/nvidia-dra-controller/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ import (
"github.com/prometheus/client_golang/prometheus/promhttp"
"github.com/urfave/cli/v2"

"k8s.io/apimachinery/pkg/util/sets"
"k8s.io/component-base/metrics/legacyregistry"
"k8s.io/klog/v2"

Expand All @@ -49,6 +50,8 @@ type Flags struct {
httpEndpoint string
metricsPath string
profilePath string

deviceClasses sets.Set[string]
}

type Config struct {
Expand Down Expand Up @@ -105,6 +108,12 @@ func newApp() *cli.App {
Destination: &flags.profilePath,
EnvVars: []string{"PPROF_PATH"},
},
&cli.StringSliceFlag{
Name: "device-classes",
Usage: "The supported set of DRA device classes",
Value: cli.NewStringSlice(GpuDeviceType, MigDeviceType, ImexChannelType),
EnvVars: []string{"DEVICE_CLASSES"},
},
}

cliFlags = append(cliFlags, flags.kubeClientConfig.Flags()...)
Expand All @@ -125,6 +134,7 @@ func newApp() *cli.App {
Action: func(c *cli.Context) error {
ctx := c.Context
mux := http.NewServeMux()
flags.deviceClasses = sets.New[string](c.StringSlice("device-classes")...)

clientSets, err := flags.kubeClientConfig.NewClientSets()
if err != nil {
Expand All @@ -144,9 +154,11 @@ func newApp() *cli.App {
}
}

err = StartIMEXManager(ctx, config)
if err != nil {
return fmt.Errorf("start IMEX manager: %w", err)
if flags.deviceClasses.Has(ImexChannelType) {
err = StartIMEXManager(ctx, config)
if err != nil {
return fmt.Errorf("start IMEX manager: %w", err)
}
}

<-ctx.Done()
Expand Down
24 changes: 24 additions & 0 deletions cmd/nvidia-dra-controller/types.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
/*
* Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package main

const (
GpuDeviceType = "gpu"
MigDeviceType = "mig"
ImexChannelType = "imex"
UnknownDeviceType = "unknown"
)
26 changes: 12 additions & 14 deletions cmd/nvidia-dra-plugin/cdi.go
Original file line number Diff line number Diff line change
Expand Up @@ -172,6 +172,13 @@ func (cdi *CDIHandler) CreateStandardDeviceSpecFile(allocatable AllocatableDevic
return fmt.Errorf("failed to get common CDI spec edits: %w", err)
}

// Make sure that NVIDIA_VISIBLE_DEVICES is set to void to avoid the
// nvidia-container-runtime honoring it in addition to the underlying
// runtime honoring CDI.
commonEdits.ContainerEdits.Env = append(
commonEdits.ContainerEdits.Env,
"NVIDIA_VISIBLE_DEVICES=void")

// Generate device specs for all full GPUs and MIG devices.
var deviceSpecs []cdispec.Device
for _, device := range allocatable {
Expand Down Expand Up @@ -223,25 +230,16 @@ func (cdi *CDIHandler) CreateClaimSpecFile(claimUID string, preparedDevices Prep
// Generate claim specific specs for each device.
var deviceSpecs []cdispec.Device
for _, group := range preparedDevices {
// Include this per-device, rather than as a top-level edit so that
// each device spec is never empty and the spec file gets created
// without error.
claimDeviceEdits := cdiapi.ContainerEdits{
ContainerEdits: &cdispec.ContainerEdits{
Env: []string{
"NVIDIA_VISIBLE_DEVICES=void",
},
},
// If there are no edits passed back as prt of the device config state, skip it
if group.ConfigState.containerEdits == nil {
continue
}

// Apply any edits passed back as part of the device config state.
claimDeviceEdits.Append(group.ConfigState.containerEdits)

// Apply edits to all devices.
// Apply any edits passed back as part of the device config state to all devices
for _, device := range group.Devices {
deviceSpec := cdispec.Device{
Name: fmt.Sprintf("%s-%s", claimUID, device.CanonicalName()),
ContainerEdits: *claimDeviceEdits.ContainerEdits,
ContainerEdits: *group.ConfigState.containerEdits.ContainerEdits,
}

deviceSpecs = append(deviceSpecs, deviceSpec)
Expand Down
2 changes: 1 addition & 1 deletion cmd/nvidia-dra-plugin/device_state.go
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@ func NewDeviceState(ctx context.Context, config *Config) (*DeviceState, error) {
return nil, fmt.Errorf("failed to create device library: %w", err)
}

allocatable, err := nvdevlib.enumerateAllPossibleDevices()
allocatable, err := nvdevlib.enumerateAllPossibleDevices(config)
if err != nil {
return nil, fmt.Errorf("error enumerating all possible devices: %w", err)
}
Expand Down
9 changes: 7 additions & 2 deletions cmd/nvidia-dra-plugin/driver.go
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,6 @@ import (

type driver struct {
sync.Mutex
doneCh chan struct{}
client coreclientset.Interface
plugin kubeletplugin.DRAPlugin
state *DeviceState
Expand Down Expand Up @@ -61,6 +60,12 @@ func NewDriver(ctx context.Context, config *Config) (*driver, error) {
}
driver.plugin = plugin

// If not responsible for advertising GPUs or MIG devices, we are done
if !(config.flags.deviceClasses.Has(GpuDeviceType) || config.flags.deviceClasses.Has(MigDeviceType)) {
return driver, nil
}

// Otherwise, enumerate the set of GPU and MIG devices and publish them
var resources kubeletplugin.Resources
for _, device := range state.allocatable {
// Explicitly exclude IMEX channels from being advertised here. They
Expand All @@ -79,7 +84,7 @@ func NewDriver(ctx context.Context, config *Config) (*driver, error) {
}

func (d *driver) Shutdown(ctx context.Context) error {
close(d.doneCh)
d.plugin.Stop()
return nil
}

Expand Down
10 changes: 10 additions & 0 deletions cmd/nvidia-dra-plugin/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ import (

"github.com/urfave/cli/v2"

"k8s.io/apimachinery/pkg/util/sets"
"k8s.io/klog/v2"

"github.com/NVIDIA/k8s-dra-driver/internal/info"
Expand All @@ -50,6 +51,7 @@ type Flags struct {
containerDriverRoot string
hostDriverRoot string
nvidiaCTKPath string
deviceClasses sets.Set[string]
}

type Config struct {
Expand Down Expand Up @@ -112,6 +114,12 @@ func newApp() *cli.App {
Destination: &flags.nvidiaCTKPath,
EnvVars: []string{"NVIDIA_CTK_PATH"},
},
&cli.StringSliceFlag{
Name: "device-classes",
Usage: "The supported set of DRA device classes",
Value: cli.NewStringSlice(GpuDeviceType, MigDeviceType, ImexChannelType),
EnvVars: []string{"DEVICE_CLASSES"},
},
}
cliFlags = append(cliFlags, flags.kubeClientConfig.Flags()...)
cliFlags = append(cliFlags, flags.loggingConfig.Flags()...)
Expand All @@ -130,6 +138,8 @@ func newApp() *cli.App {
},
Action: func(c *cli.Context) error {
ctx := c.Context
flags.deviceClasses = sets.New[string](c.StringSlice("device-classes")...)

clientSets, err := flags.kubeClientConfig.NewClientSets()
if err != nil {
return fmt.Errorf("create client: %w", err)
Expand Down
68 changes: 52 additions & 16 deletions cmd/nvidia-dra-plugin/nvlib.go
Original file line number Diff line number Diff line change
Expand Up @@ -108,36 +108,66 @@ func (l deviceLib) alwaysShutdown() {
}
}

func (l deviceLib) enumerateAllPossibleDevices() (AllocatableDevices, error) {
func (l deviceLib) enumerateAllPossibleDevices(config *Config) (AllocatableDevices, error) {
alldevices := make(AllocatableDevices)
deviceClasses := config.flags.deviceClasses

if deviceClasses.Has(GpuDeviceType) || deviceClasses.Has(MigDeviceType) {
gms, err := l.enumerateGpusAndMigDevices(config)
if err != nil {
return nil, fmt.Errorf("error enumerating IMEX devices: %w", err)
}
for k, v := range gms {
alldevices[k] = v
}
}

if deviceClasses.Has(ImexChannelType) {
imex, err := l.enumerateImexChannels(config)
if err != nil {
return nil, fmt.Errorf("error enumerating IMEX devices: %w", err)
}
for k, v := range imex {
alldevices[k] = v
}
}

return alldevices, nil
}

func (l deviceLib) enumerateGpusAndMigDevices(config *Config) (AllocatableDevices, error) {
if err := l.Init(); err != nil {
return nil, err
}
defer l.alwaysShutdown()

alldevices := make(AllocatableDevices)
devices := make(AllocatableDevices)
deviceClasses := config.flags.deviceClasses
err := l.VisitDevices(func(i int, d nvdev.Device) error {
gpuInfo, err := l.getGpuInfo(i, d)
if err != nil {
return fmt.Errorf("error getting info for GPU %d: %w", i, err)
}

migs, err := l.getMigDevices(gpuInfo)
if err != nil {
return fmt.Errorf("error getting MIG devices for GPU %d: %w", i, err)
}

for _, migDeviceInfo := range migs {
if deviceClasses.Has(GpuDeviceType) && !gpuInfo.migEnabled {
deviceInfo := &AllocatableDevice{
Mig: migDeviceInfo,
Gpu: gpuInfo,
}
alldevices[migDeviceInfo.CanonicalName()] = deviceInfo
devices[gpuInfo.CanonicalName()] = deviceInfo
}

if !gpuInfo.migEnabled && len(migs) == 0 {
deviceInfo := &AllocatableDevice{
Gpu: gpuInfo,
if deviceClasses.Has(MigDeviceType) {
migs, err := l.getMigDevices(gpuInfo)
if err != nil {
return fmt.Errorf("error getting MIG devices for GPU %d: %w", i, err)
}

for _, migDeviceInfo := range migs {
deviceInfo := &AllocatableDevice{
Mig: migDeviceInfo,
}
devices[migDeviceInfo.CanonicalName()] = deviceInfo
}
alldevices[gpuInfo.CanonicalName()] = deviceInfo
}

return nil
Expand All @@ -146,6 +176,12 @@ func (l deviceLib) enumerateAllPossibleDevices() (AllocatableDevices, error) {
return nil, fmt.Errorf("error visiting devices: %w", err)
}

return devices, nil
}

func (l deviceLib) enumerateImexChannels(config *Config) (AllocatableDevices, error) {
devices := make(AllocatableDevices)

imexChannelCount, err := l.getImexChannelCount()
if err != nil {
return nil, fmt.Errorf("error getting IMEX channel count: %w", err)
Expand All @@ -157,10 +193,10 @@ func (l deviceLib) enumerateAllPossibleDevices() (AllocatableDevices, error) {
deviceInfo := &AllocatableDevice{
ImexChannel: imexChannelInfo,
}
alldevices[imexChannelInfo.CanonicalName()] = deviceInfo
devices[imexChannelInfo.CanonicalName()] = deviceInfo
}

return alldevices, nil
return devices, nil
}

func (l deviceLib) getGpuInfo(index int, device nvdev.Device) (*GpuInfo, error) {
Expand Down
2 changes: 1 addition & 1 deletion cmd/nvidia-dra-plugin/types.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ package main
const (
GpuDeviceType = "gpu"
MigDeviceType = "mig"
ImexChannelType = "imex-channel"
ImexChannelType = "imex"
UnknownDeviceType = "unknown"
)

Expand Down
9 changes: 4 additions & 5 deletions demo/clusters/kind/install-dra-driver.sh
Original file line number Diff line number Diff line change
Expand Up @@ -22,15 +22,14 @@ set -o pipefail

source "${CURRENT_DIR}/scripts/common.sh"

kubectl label node -l node-role.x-k8s.io/worker --overwrite nvidia.com/dra.kubelet-plugin=true
kubectl label node -l node-role.x-k8s.io/control-plane --overwrite nvidia.com/dra.controller=true

helm upgrade -i --create-namespace --namespace nvidia-dra-driver nvidia ${PROJECT_DIR}/deployments/helm/k8s-dra-driver \
deviceClasses=${1:-"gpu,mig,imex"}
helm upgrade -i --create-namespace --namespace nvidia nvidia-dra-driver ${PROJECT_DIR}/deployments/helm/k8s-dra-driver \
--set deviceClasses="{${deviceClasses}}" \
${NVIDIA_DRIVER_ROOT:+--set nvidiaDriverRoot=${NVIDIA_DRIVER_ROOT}} \
--wait

set +x
printf '\033[0;32m'
echo "Driver installation complete:"
kubectl get pod -n nvidia-dra-driver
kubectl get pod -n nvidia
printf '\033[0m'
32 changes: 32 additions & 0 deletions deployments/helm/k8s-dra-driver/templates/_helpers.tpl
Original file line number Diff line number Diff line change
Expand Up @@ -95,3 +95,35 @@ Create the name of the service account to use
{{- default "default" .Values.serviceAccount.name }}
{{- end }}
{{- end }}

{{/*
Check for the existence of an element in a list
*/}}
{{- define "k8s-dra-driver.listHas" -}}
{{- $listToCheck := index . 0 }}
{{- $valueToCheck := index . 1 }}

{{- $found := "" -}}
{{- range $listToCheck}}
{{- if eq . $valueToCheck }}
{{- $found = "true" -}}
{{- end }}
{{- end }}
{{- $found -}}
{{- end }}

{{/*
Filter a list by a set of valid values
*/}}
{{- define "k8s-dra-driver.filterList" -}}
{{- $listToFilter := index . 0 }}
{{- $validValues := index . 1 }}

{{- $result := list -}}
{{- range $validValues}}
{{- if include "k8s-dra-driver.listHas" (list $listToFilter .) }}
{{- $result = append $result . }}
{{- end }}
{{- end }}
{{- $result -}}
{{- end -}}
Loading

0 comments on commit 322a3a7

Please sign in to comment.