Skip to content

Commit

Permalink
Add more synchronization around the shutdown of the IMEX manager
Browse files Browse the repository at this point in the history
Signed-off-by: Kevin Klues <kklues@nvidia.com>
  • Loading branch information
klueska committed Oct 23, 2024
1 parent c50aa21 commit 9679ccf
Show file tree
Hide file tree
Showing 2 changed files with 63 additions and 33 deletions.
79 changes: 49 additions & 30 deletions cmd/nvidia-dra-controller/imex.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,9 +19,7 @@ package main
import (
"context"
"fmt"
"os"
"os/signal"
"syscall"
"sync"
"time"

v1 "k8s.io/api/core/v1"
Expand All @@ -43,25 +41,30 @@ const (
ImexChannelLimit = 128
)

// ImexManager holds the state shared by the IMEX manager's background
// goroutines so that they can be waited on and cleaned up via Stop.
type ImexManager struct {
// waitGroup counts the goroutines started by the manager; Stop waits on it
// before removing the published resource slices.
waitGroup sync.WaitGroup
// clientset is the Kubernetes client used for the node informer, the
// resource slice controller, and listing/deleting ResourceSlices.
clientset kubernetes.Interface
}

type DriverResources resourceslice.DriverResources

func StartIMEXManager(ctx context.Context, config *Config) error {
func StartIMEXManager(ctx context.Context, config *Config) (*ImexManager, error) {
// Build a client set config
csconfig, err := config.flags.kubeClientConfig.NewClientSetConfig()
if err != nil {
return fmt.Errorf("error creating client set config: %w", err)
return nil, fmt.Errorf("error creating client set config: %w", err)
}

// Create a new clientset
clientset, err := kubernetes.NewForConfig(csconfig)
if err != nil {
return fmt.Errorf("error creating dynamic client: %w", err)
return nil, fmt.Errorf("error creating dynamic client: %w", err)
}

// Fetch the current Pod object
pod, err := clientset.CoreV1().Pods(config.flags.namespace).Get(ctx, config.flags.podName, metav1.GetOptions{})
if err != nil {
return fmt.Errorf("error fetching pod: %w", err)
return nil, fmt.Errorf("error fetching pod: %w", err)
}

// Set the owner of the ResourceSlices we will create
Expand All @@ -72,36 +75,39 @@ func StartIMEXManager(ctx context.Context, config *Config) error {
UID: pod.UID,
}

// Create the manager itself
m := &ImexManager{
clientset: clientset,
}

// Stream added/removed IMEX domains from nodes over time
klog.Info("Start streaming IMEX domains from nodes...")
addedDomainsCh, removedDomainsCh, err := streamImexDomains(ctx, clientset)
addedDomainsCh, removedDomainsCh, err := m.streamImexDomains(ctx)
if err != nil {
return fmt.Errorf("error streaming IMEX domains: %w", err)
return nil, fmt.Errorf("error streaming IMEX domains: %w", err)
}

// Add/Remove resource slices from IMEX domains as they come and go
klog.Info("Start publishing IMEX channels to ResourceSlices...")
err = manageResourceSlices(ctx, clientset, owner, addedDomainsCh, removedDomainsCh)
err = m.manageResourceSlices(ctx, owner, addedDomainsCh, removedDomainsCh)
if err != nil {
return fmt.Errorf("error managing resource slices: %w", err)
return nil, fmt.Errorf("error managing resource slices: %w", err)
}

return nil
return m, nil
}

// manageResourceSlices reacts to added and removed IMEX domains and triggers the creation / removal of resource slices accordingly.
func manageResourceSlices(ctx context.Context, clientset kubernetes.Interface, owner resourceslice.Owner, addedDomainsCh <-chan string, removedDomainsCh <-chan string) error {
func (m *ImexManager) manageResourceSlices(ctx context.Context, owner resourceslice.Owner, addedDomainsCh <-chan string, removedDomainsCh <-chan string) error {
driverResources := resourceslice.DriverResources{}
controller, err := resourceslice.StartController(ctx, clientset, DriverName, owner, &driverResources)
controller, err := resourceslice.StartController(ctx, m.clientset, DriverName, owner, &driverResources)
if err != nil {
return fmt.Errorf("error starting resource slice controller: %w", err)
}

// Setup signal catching
sigs := make(chan os.Signal, 1)
signal.Notify(sigs, syscall.SIGTERM, syscall.SIGINT)

m.waitGroup.Add(1)
go func() {
defer m.waitGroup.Done()
for {
select {
case addedDomain := <-addedDomainsCh:
Expand All @@ -116,12 +122,7 @@ func manageResourceSlices(ctx context.Context, clientset kubernetes.Interface, o
delete(newDriverResources.Pools, removedDomain)
controller.Update(&newDriverResources)
driverResources = newDriverResources
case <-sigs:
controller.Stop()
err = cleanupImexChannels(ctx, clientset)
if err != nil {
klog.Errorf("error cleaning up resource slices: %v", err)
}
case <-ctx.Done():
return
}
}
Expand All @@ -130,6 +131,20 @@ func manageResourceSlices(ctx context.Context, clientset kubernetes.Interface, o
return nil
}

// Stop waits for the manager's background goroutines to finish and then
// deletes the resource slices created by the IMEX manager.
//
// Calling Stop on a nil *ImexManager is a no-op and returns nil, so callers
// may invoke it unconditionally even when the manager was never started.
// Stop does not cancel anything itself: it blocks on the manager's wait
// group, so the context handed to StartIMEXManager should be cancelled
// before calling it.
//
// It returns a wrapped error if cleaning up the resource slices fails.
func (m *ImexManager) Stop() error {
	if m == nil {
		// Nothing was started, so there is nothing to wait for or clean up.
		return nil
	}

	// Let every tracked goroutine exit before touching the resource slices.
	m.waitGroup.Wait()

	cleanupErr := m.cleanupImexChannels()
	if cleanupErr == nil {
		return nil
	}
	return fmt.Errorf("error cleaning up resource slices: %w", cleanupErr)
}

// DeepCopy will perform a deep copy of the provided DriverResources.
func (d DriverResources) DeepCopy() resourceslice.DriverResources {
driverResources := resourceslice.DriverResources{
Expand All @@ -142,7 +157,7 @@ func (d DriverResources) DeepCopy() resourceslice.DriverResources {
}

// streamImexDomains returns two channels that stream the IMEX domains that are added to and removed from nodes over time.
func streamImexDomains(ctx context.Context, clientset kubernetes.Interface) (<-chan string, <-chan string, error) {
func (m *ImexManager) streamImexDomains(ctx context.Context) (<-chan string, <-chan string, error) {
// Create channels to stream IMEX domain ids that are added / removed
addedDomainCh := make(chan string)
removedDomainCh := make(chan string)
Expand All @@ -159,7 +174,7 @@ func streamImexDomains(ctx context.Context, clientset kubernetes.Interface) (<-c

// Create a shared informer factory for nodes
informerFactory := informers.NewSharedInformerFactoryWithOptions(
clientset,
m.clientset,
time.Minute*10, // Resync period
informers.WithTweakListOptions(func(options *metav1.ListOptions) {
options.LabelSelector = labelSelector
Expand Down Expand Up @@ -218,7 +233,11 @@ func streamImexDomains(ctx context.Context, clientset kubernetes.Interface) (<-c
}

// Start the informer and wait for it to sync
go informerFactory.Start(ctx.Done())
m.waitGroup.Add(1)
go func() {
defer m.waitGroup.Done()
informerFactory.Start(ctx.Done())
}()

// Wait for the informer caches to sync
if !cache.WaitForCacheSync(ctx.Done(), nodeInformer.HasSynced) {
Expand Down Expand Up @@ -273,19 +292,19 @@ func generateImexChannelPool(imexDomain string, numChannels int) resourceslice.P
}

// cleanupImexChannels removes all resource slices created by the IMEX manager.
func cleanupImexChannels(ctx context.Context, clientset kubernetes.Interface) error {
func (m *ImexManager) cleanupImexChannels() error {
// Delete all resource slices created by the IMEX manager
ops := metav1.ListOptions{
FieldSelector: fmt.Sprintf("%s=%s", resourceapi.ResourceSliceSelectorDriver, DriverName),
}
l, err := clientset.ResourceV1alpha3().ResourceSlices().List(ctx, ops)
l, err := m.clientset.ResourceV1alpha3().ResourceSlices().List(context.Background(), ops)
if err != nil {
return fmt.Errorf("error listing resource slices: %w", err)
}

for _, rs := range l.Items {
klog.Info("Deleting resource slice: ", rs.Name)
err := clientset.ResourceV1alpha3().ResourceSlices().Delete(ctx, rs.Name, metav1.DeleteOptions{})
err := m.clientset.ResourceV1alpha3().ResourceSlices().Delete(context.Background(), rs.Name, metav1.DeleteOptions{})
if err != nil {
return fmt.Errorf("error deleting resource slice %s: %w", rs.Name, err)
}
Expand Down
17 changes: 14 additions & 3 deletions cmd/nvidia-dra-controller/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,12 +17,15 @@
package main

import (
"context"
"fmt"
"net"
"net/http"
"net/http/pprof"
"os"
"os/signal"
"path"
"syscall"

"github.com/prometheus/client_golang/prometheus"
"github.com/prometheus/client_golang/prometheus/promhttp"
Expand Down Expand Up @@ -132,7 +135,7 @@ func newApp() *cli.App {
return flags.loggingConfig.Apply()
},
Action: func(c *cli.Context) error {
ctx := c.Context
ctx, cancel := context.WithCancel(c.Context)

Check failure on line 138 in cmd/nvidia-dra-controller/main.go

View workflow job for this annotation

GitHub Actions / check

lostcancel: the cancel function is not used on all paths (possible context leak) (govet)
mux := http.NewServeMux()
flags.deviceClasses = sets.New[string](c.StringSlice("device-classes")...)

Expand All @@ -154,14 +157,22 @@ func newApp() *cli.App {
}
}

sigs := make(chan os.Signal, 1)
signal.Notify(sigs, syscall.SIGTERM, syscall.SIGINT)

var imexManager *ImexManager
if flags.deviceClasses.Has(ImexChannelType) {
err = StartIMEXManager(ctx, config)
imexManager, err = StartIMEXManager(ctx, config)
if err != nil {
return fmt.Errorf("start IMEX manager: %w", err)
}
}

<-ctx.Done()
<-sigs
cancel()
if err := imexManager.Stop(); err != nil {
return fmt.Errorf("stop IMEX manager: %w", err)
}

return nil
},
Expand Down

0 comments on commit 9679ccf

Please sign in to comment.