diff --git a/cmd/nvidia-dra-controller/imex.go b/cmd/nvidia-dra-controller/imex.go index 63218f0a..16249cd2 100644 --- a/cmd/nvidia-dra-controller/imex.go +++ b/cmd/nvidia-dra-controller/imex.go @@ -19,9 +19,7 @@ package main import ( "context" "fmt" - "os" - "os/signal" - "syscall" + "sync" "time" v1 "k8s.io/api/core/v1" @@ -43,25 +41,30 @@ const ( ImexChannelLimit = 128 ) +type ImexManager struct { + waitGroup sync.WaitGroup + clientset kubernetes.Interface +} + type DriverResources resourceslice.DriverResources -func StartIMEXManager(ctx context.Context, config *Config) error { +func StartIMEXManager(ctx context.Context, config *Config) (*ImexManager, error) { // Build a client set config csconfig, err := config.flags.kubeClientConfig.NewClientSetConfig() if err != nil { - return fmt.Errorf("error creating client set config: %w", err) + return nil, fmt.Errorf("error creating client set config: %w", err) } // Create a new clientset clientset, err := kubernetes.NewForConfig(csconfig) if err != nil { - return fmt.Errorf("error creating dynamic client: %w", err) + return nil, fmt.Errorf("error creating dynamic client: %w", err) } // Fetch the current Pod object pod, err := clientset.CoreV1().Pods(config.flags.namespace).Get(ctx, config.flags.podName, metav1.GetOptions{}) if err != nil { - return fmt.Errorf("error fetching pod: %w", err) + return nil, fmt.Errorf("error fetching pod: %w", err) } // Set the owner of the ResourceSlices we will create @@ -72,36 +75,39 @@ func StartIMEXManager(ctx context.Context, config *Config) error { UID: pod.UID, } + // Create the manager itself + m := &ImexManager{ + clientset: clientset, + } + // Stream added/removed IMEX domains from nodes over time klog.Info("Start streaming IMEX domains from nodes...") - addedDomainsCh, removedDomainsCh, err := streamImexDomains(ctx, clientset) + addedDomainsCh, removedDomainsCh, err := m.streamImexDomains(ctx) if err != nil { - return fmt.Errorf("error streaming IMEX domains: %w", err) + return nil, fmt.Errorf("error streaming IMEX domains: %w", err) } // Add/Remove resource slices from IMEX domains as they come and go klog.Info("Start publishing IMEX channels to ResourceSlices...") - err = manageResourceSlices(ctx, clientset, owner, addedDomainsCh, removedDomainsCh) + err = m.manageResourceSlices(ctx, owner, addedDomainsCh, removedDomainsCh) if err != nil { - return fmt.Errorf("error managing resource slices: %w", err) + return nil, fmt.Errorf("error managing resource slices: %w", err) } - return nil + return m, nil } // manageResourceSlices reacts to added and removed IMEX domains and triggers the creation / removal of resource slices accordingly. -func manageResourceSlices(ctx context.Context, clientset kubernetes.Interface, owner resourceslice.Owner, addedDomainsCh <-chan string, removedDomainsCh <-chan string) error { +func (m *ImexManager) manageResourceSlices(ctx context.Context, owner resourceslice.Owner, addedDomainsCh <-chan string, removedDomainsCh <-chan string) error { driverResources := resourceslice.DriverResources{} - controller, err := resourceslice.StartController(ctx, clientset, DriverName, owner, &driverResources) + controller, err := resourceslice.StartController(ctx, m.clientset, DriverName, owner, &driverResources) if err != nil { return fmt.Errorf("error starting resource slice controller: %w", err) } - // Setup signal catching - sigs := make(chan os.Signal, 1) - signal.Notify(sigs, syscall.SIGTERM, syscall.SIGINT) - + m.waitGroup.Add(1) go func() { + defer m.waitGroup.Done() for { select { case addedDomain := <-addedDomainsCh: @@ -116,12 +122,7 @@ func manageResourceSlices(ctx context.Context, clientset kubernetes.Interface, o delete(newDriverResources.Pools, removedDomain) controller.Update(&newDriverResources) driverResources = newDriverResources - case <-sigs: - controller.Stop() - err = cleanupImexChannels(ctx, clientset) - if err != nil { - klog.Errorf("error cleaning up resource slices: %v", err) - } + case <-ctx.Done(): return } } @@ -130,6 +131,20 @@ func manageResourceSlices(ctx context.Context, clientset kubernetes.Interface, o return nil } +// Stop stops a running ImexManager. +func (m *ImexManager) Stop() error { + if m == nil { + return nil + } + + m.waitGroup.Wait() + if err := m.cleanupImexChannels(); err != nil { + return fmt.Errorf("error cleaning up resource slices: %w", err) + } + + return nil +} + // DeepCopy will perform a deep copy of the provided DriverResources. func (d DriverResources) DeepCopy() resourceslice.DriverResources { driverResources := resourceslice.DriverResources{ @@ -142,7 +157,7 @@ func (d DriverResources) DeepCopy() resourceslice.DriverResources { } // streamImexDomains returns two channels that streams imexDomans that are added and removed from nodes over time. -func streamImexDomains(ctx context.Context, clientset kubernetes.Interface) (<-chan string, <-chan string, error) { +func (m *ImexManager) streamImexDomains(ctx context.Context) (<-chan string, <-chan string, error) { // Create channels to stream IMEX domain ids that are added / removed addedDomainCh := make(chan string) removedDomainCh := make(chan string) @@ -159,7 +174,7 @@ func streamImexDomains(ctx context.Context, clientset kubernetes.Interface) (<-c // Create a shared informer factory for nodes informerFactory := informers.NewSharedInformerFactoryWithOptions( - clientset, + m.clientset, time.Minute*10, // Resync period informers.WithTweakListOptions(func(options *metav1.ListOptions) { options.LabelSelector = labelSelector @@ -218,7 +233,11 @@ func streamImexDomains(ctx context.Context, clientset kubernetes.Interface) (<-c } // Start the informer and wait for it to sync - go informerFactory.Start(ctx.Done()) + m.waitGroup.Add(1) + go func() { + defer m.waitGroup.Done() + informerFactory.Start(ctx.Done()) + }() // Wait for the informer caches to sync if !cache.WaitForCacheSync(ctx.Done(), nodeInformer.HasSynced) { @@ -273,19 +292,19 @@ func generateImexChannelPool(imexDomain string, numChannels int) resourceslice.P } // cleanupImexChannels removes all resource slices created by the IMEX manager. -func cleanupImexChannels(ctx context.Context, clientset kubernetes.Interface) error { +func (m *ImexManager) cleanupImexChannels() error { // Delete all resource slices created by the IMEX manager ops := metav1.ListOptions{ FieldSelector: fmt.Sprintf("%s=%s", resourceapi.ResourceSliceSelectorDriver, DriverName), } - l, err := clientset.ResourceV1alpha3().ResourceSlices().List(ctx, ops) + l, err := m.clientset.ResourceV1alpha3().ResourceSlices().List(context.Background(), ops) if err != nil { return fmt.Errorf("error listing resource slices: %w", err) } for _, rs := range l.Items { klog.Info("Deleting resource slice: ", rs.Name) - err := clientset.ResourceV1alpha3().ResourceSlices().Delete(ctx, rs.Name, metav1.DeleteOptions{}) + err := m.clientset.ResourceV1alpha3().ResourceSlices().Delete(context.Background(), rs.Name, metav1.DeleteOptions{}) if err != nil { return fmt.Errorf("error deleting resource slice %s: %w", rs.Name, err) } diff --git a/cmd/nvidia-dra-controller/main.go b/cmd/nvidia-dra-controller/main.go index d4c78e8c..09750420 100644 --- a/cmd/nvidia-dra-controller/main.go +++ b/cmd/nvidia-dra-controller/main.go @@ -17,12 +17,15 @@ package main import ( + "context" "fmt" "net" "net/http" "net/http/pprof" "os" + "os/signal" "path" + "syscall" "github.com/prometheus/client_golang/prometheus" "github.com/prometheus/client_golang/prometheus/promhttp" @@ -132,7 +135,7 @@ func newApp() *cli.App { return flags.loggingConfig.Apply() }, Action: func(c *cli.Context) error { - ctx := c.Context + ctx, cancel := context.WithCancel(c.Context) mux := http.NewServeMux() flags.deviceClasses = sets.New[string](c.StringSlice("device-classes")...) @@ -154,14 +157,22 @@ func newApp() *cli.App { } } + sigs := make(chan os.Signal, 1) + signal.Notify(sigs, syscall.SIGTERM, syscall.SIGINT) + + var imexManager *ImexManager if flags.deviceClasses.Has(ImexChannelType) { - err = StartIMEXManager(ctx, config) + imexManager, err = StartIMEXManager(ctx, config) if err != nil { return fmt.Errorf("start IMEX manager: %w", err) } } - <-ctx.Done() + <-sigs + cancel() + if err := imexManager.Stop(); err != nil { + return fmt.Errorf("stop IMEX manager: %w", err) + } return nil },