Skip to content

Commit

Permalink
Merge pull request #189 from klueska/refactor-imex-controller
Browse files Browse the repository at this point in the history
Refactor the IMEX controller code
  • Loading branch information
klueska authored Oct 28, 2024
2 parents b80afe7 + b231043 commit 86cbafc
Showing 1 changed file with 150 additions and 97 deletions.
247 changes: 150 additions & 97 deletions cmd/nvidia-dra-controller/imex.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ package main

import (
"context"
"errors"
"fmt"
"strings"
"sync"
Expand All @@ -41,17 +42,28 @@ const (
ImexDomainLabel = "nvidia.com/gpu.imex-domain"
ResourceSliceImexChannelLimit = 128
DriverImexChannelLimit = 2048
RetryTimeout = 1 * time.Minute
)

type ImexManager struct {
waitGroup sync.WaitGroup
clientset kubernetes.Interface
}
// transientError defines an error indicating that it is transient.
type transientError struct{ error }

// imexDomainOffsets represents the offset for assigning IMEX channels
// to ResourceSlices for each <imex-domain, cliqueid> combination.
type imexDomainOffsets map[string]map[string]int

type ImexManager struct {
driverName string
resourceSliceImexChannelLimit int
driverImexChannelLimit int
retryTimeout time.Duration
waitGroup sync.WaitGroup
clientset kubernetes.Interface
imexDomainOffsets imexDomainOffsets
owner resourceslice.Owner
driverResources *resourceslice.DriverResources
}

func StartIMEXManager(ctx context.Context, config *Config) (*ImexManager, error) {
// Build a client set config
csconfig, err := config.flags.kubeClientConfig.NewClientSetConfig()
Expand Down Expand Up @@ -79,21 +91,25 @@ func StartIMEXManager(ctx context.Context, config *Config) (*ImexManager, error)
UID: pod.UID,
}

// Create the manager itself
m := &ImexManager{
clientset: clientset,
// Create a new set of DriverResources
driverResources := &resourceslice.DriverResources{
Pools: make(map[string]resourceslice.Pool),
}

// Stream added/removed IMEX domains from nodes over time
klog.Info("Start streaming IMEX domains from nodes...")
addedDomainsCh, removedDomainsCh, err := m.streamImexDomains(ctx)
if err != nil {
return nil, fmt.Errorf("error streaming IMEX domains: %w", err)
// Create the manager itself
m := &ImexManager{
driverName: DriverName,
resourceSliceImexChannelLimit: ResourceSliceImexChannelLimit,
driverImexChannelLimit: DriverImexChannelLimit,
retryTimeout: RetryTimeout,
clientset: clientset,
owner: owner,
driverResources: driverResources,
imexDomainOffsets: make(imexDomainOffsets),
}

// Add/Remove resource slices from IMEX domains as they come and go
klog.Info("Start publishing IMEX channels to ResourceSlices...")
err = m.manageResourceSlices(ctx, owner, addedDomainsCh, removedDomainsCh)
err = m.manageResourceSlices(ctx)
if err != nil {
return nil, fmt.Errorf("error managing resource slices: %w", err)
}
Expand All @@ -102,35 +118,50 @@ func StartIMEXManager(ctx context.Context, config *Config) (*ImexManager, error)
}

// manageResourceSlices reacts to added and removed IMEX domains and triggers the creation / removal of resource slices accordingly.
func (m *ImexManager) manageResourceSlices(ctx context.Context, owner resourceslice.Owner, addedDomainsCh <-chan string, removedDomainsCh <-chan string) error {
driverResources := &resourceslice.DriverResources{}
controller, err := resourceslice.StartController(ctx, m.clientset, DriverName, owner, driverResources)
func (m *ImexManager) manageResourceSlices(ctx context.Context) error {
klog.Info("Start streaming IMEX domains from nodes...")
addedDomainsCh, removedDomainsCh, err := m.streamImexDomains(ctx)
if err != nil {
return fmt.Errorf("error streaming IMEX domains: %w", err)
}

klog.Info("Start publishing IMEX channels to ResourceSlices...")
controller, err := resourceslice.StartController(ctx, m.clientset, m.driverName, m.owner, m.driverResources)
if err != nil {
return fmt.Errorf("error starting resource slice controller: %w", err)
}

imexDomainOffsets := new(imexDomainOffsets)
m.waitGroup.Add(1)
go func() {
defer m.waitGroup.Done()
for {
select {
case addedDomain := <-addedDomainsCh:
offset, err := imexDomainOffsets.add(addedDomain, ResourceSliceImexChannelLimit, DriverImexChannelLimit)
if err != nil {
klog.Errorf("Error calculating channel offset for IMEX domain %s: %v", addedDomain, err)
return
}
klog.Infof("Adding channels for new IMEX domain: %v", addedDomain)
driverResources := driverResources.DeepCopy()
driverResources.Pools[addedDomain] = generateImexChannelPool(addedDomain, offset, ResourceSliceImexChannelLimit)
controller.Update(driverResources)
if err := m.addImexDomain(addedDomain); err != nil {
klog.Errorf("Error adding channels for IMEX domain %s: %v", addedDomain, err)
if errors.As(err, &transientError{}) {
klog.Infof("Retrying adding channels for IMEX domain %s after %v", addedDomain, m.retryTimeout)
go func() {
time.Sleep(m.retryTimeout)
addedDomainsCh <- addedDomain
}()
}
}
controller.Update(m.driverResources)
case removedDomain := <-removedDomainsCh:
klog.Infof("Removing channels for removed IMEX domain: %v", removedDomain)
driverResources := driverResources.DeepCopy()
delete(driverResources.Pools, removedDomain)
imexDomainOffsets.remove(removedDomain)
controller.Update(driverResources)
if err := m.removeImexDomain(removedDomain); err != nil {
klog.Errorf("Error removing channels for IMEX domain %s: %v", removedDomain, err)
if errors.As(err, &transientError{}) {
klog.Infof("Retrying removing channels for IMEX domain %s after %v", removedDomain, m.retryTimeout)
go func() {
time.Sleep(m.retryTimeout)
removedDomainsCh <- removedDomain
}()
}
}
controller.Update(m.driverResources)
case <-ctx.Done():
return
}
Expand All @@ -155,8 +186,35 @@ func (m *ImexManager) Stop() error {
return nil
}

// addImexDomain adds an IMEX domain to be managed by the ImexManager.
func (m *ImexManager) addImexDomain(imexDomain string) error {
imexDomainID, cliqueID, err := splitImexDomain(imexDomain)
if err != nil {
return fmt.Errorf("error splitting IMEX domain '%s': %v", imexDomain, err)
}
offset, err := m.imexDomainOffsets.add(imexDomainID, cliqueID, m.resourceSliceImexChannelLimit, m.driverImexChannelLimit)
if err != nil {
return fmt.Errorf("error setting offset for IMEX channels: %w", err)
}
m.driverResources = m.driverResources.DeepCopy()
m.driverResources.Pools[imexDomain] = generateImexChannelPool(imexDomain, offset, m.resourceSliceImexChannelLimit)
return nil
}

// removeImexDomain removes an IMEX domain from being managed by the ImexManager.
func (m *ImexManager) removeImexDomain(imexDomain string) error {
imexDomainID, cliqueID, err := splitImexDomain(imexDomain)
if err != nil {
return fmt.Errorf("error splitting IMEX domain '%s': %v", imexDomain, err)
}
m.imexDomainOffsets.remove(imexDomainID, cliqueID)
m.driverResources = m.driverResources.DeepCopy()
delete(m.driverResources.Pools, imexDomain)
return nil
}

// streamImexDomains returns two channels that streams imexDomans that are added and removed from nodes over time.
func (m *ImexManager) streamImexDomains(ctx context.Context) (<-chan string, <-chan string, error) {
func (m *ImexManager) streamImexDomains(ctx context.Context) (chan string, chan string, error) {
// Create channels to stream IMEX domain ids that are added / removed
addedDomainCh := make(chan string)
removedDomainCh := make(chan string)
Expand Down Expand Up @@ -246,50 +304,6 @@ func (m *ImexManager) streamImexDomains(ctx context.Context) (<-chan string, <-c
return addedDomainCh, removedDomainCh, nil
}

// generateImexChannelPool generates the contents of a ResourceSlice pool for a given IMEX domain.
func generateImexChannelPool(imexDomain string, startChannel int, numChannels int) resourceslice.Pool {
// Generate channels from startChannel to offset+numChannels
var devices []resourceapi.Device
for i := startChannel; i < (startChannel + numChannels); i++ {
d := resourceapi.Device{
Name: fmt.Sprintf("imex-channel-%d", i),
Basic: &resourceapi.BasicDevice{
Attributes: map[resourceapi.QualifiedName]resourceapi.DeviceAttribute{
"type": {
StringValue: ptr.To("imex-channel"),
},
"channel": {
IntValue: ptr.To(int64(i)),
},
},
},
}
devices = append(devices, d)
}

// Put them in a pool named after the IMEX domain with the IMEX domain label as a node selector
pool := resourceslice.Pool{
NodeSelector: &v1.NodeSelector{
NodeSelectorTerms: []v1.NodeSelectorTerm{
{
MatchExpressions: []v1.NodeSelectorRequirement{
{
Key: ImexDomainLabel,
Operator: v1.NodeSelectorOpIn,
Values: []string{
imexDomain,
},
},
},
},
},
},
Devices: devices,
}

return pool
}

// cleanupResourceSlices removes all resource slices created by the IMEX manager.
func (m *ImexManager) cleanupResourceSlices() error {
// Delete all resource slices created by the IMEX manager
Expand All @@ -312,28 +326,20 @@ func (m *ImexManager) cleanupResourceSlices() error {
}

// add sets the offset where an IMEX domain's channels should start counting from.
func (offsets imexDomainOffsets) add(imexDomain string, resourceSliceImexChannelLimit, driverImexChannelLimit int) (int, error) {
// Split the incoming imexDomain to split off its cliqueID
id := strings.SplitN(imexDomain, ".", 2)
if len(id) != 2 {
return -1, fmt.Errorf("error adding IMEX domain %s: invalid format", imexDomain)
}
imexDomain = id[0]
cliqueID := id[1]

func (offsets imexDomainOffsets) add(imexDomainID string, cliqueID string, resourceSliceImexChannelLimit, driverImexChannelLimit int) (int, error) {
// Check if the IMEX domain is already in the map
if _, ok := offsets[imexDomain]; !ok {
offsets[imexDomain] = make(map[string]int)
if _, ok := offsets[imexDomainID]; !ok {
offsets[imexDomainID] = make(map[string]int)
}

// Return early if the clique is already in the map
if offset, exists := offsets[imexDomain][cliqueID]; exists {
if offset, exists := offsets[imexDomainID][cliqueID]; exists {
return offset, nil
}

// Track used offsets for the current imexDomain
usedOffsets := make(map[int]struct{})
for _, v := range offsets[imexDomain] {
for _, v := range offsets[imexDomainID] {
usedOffsets[v] = struct{}{}
}

Expand All @@ -347,23 +353,70 @@ func (offsets imexDomainOffsets) add(imexDomain string, resourceSliceImexChannel

// If we reach the limit, return an error
if offset == driverImexChannelLimit {
return -1, fmt.Errorf("error adding IMEX domain %s: channel limit reached", imexDomain)
return -1, transientError{fmt.Errorf("channel limit reached")}
}
offsets[imexDomain][cliqueID] = offset
offsets[imexDomainID][cliqueID] = offset

return offset, nil
}

func (offsets imexDomainOffsets) remove(imexDomain string) {
// remove removes the offset where an IMEX domain's channels should start counting from.
func (offsets imexDomainOffsets) remove(imexDomainID string, cliqueID string) {
delete(offsets[imexDomainID], cliqueID)
if len(offsets[imexDomainID]) == 0 {
delete(offsets, imexDomainID)
}
}

// splitImexDomain splits an imexDomain into its IMEX domain ID and its clique ID.
func splitImexDomain(imexDomain string) (string, string, error) {
id := strings.SplitN(imexDomain, ".", 2)
if len(id) != 2 {
return
return "", "", fmt.Errorf("splitting by '.' not equal to exactly 2 elements")
}
return id[0], id[1], nil
}

// generateImexChannelPool generates the contents of a ResourceSlice pool for a given IMEX domain.
func generateImexChannelPool(imexDomain string, startChannel int, numChannels int) resourceslice.Pool {
// Generate channels from startChannel to startChannel+numChannels
var devices []resourceapi.Device
for i := startChannel; i < (startChannel + numChannels); i++ {
d := resourceapi.Device{
Name: fmt.Sprintf("imex-channel-%d", i),
Basic: &resourceapi.BasicDevice{
Attributes: map[resourceapi.QualifiedName]resourceapi.DeviceAttribute{
"type": {
StringValue: ptr.To("imex-channel"),
},
"channel": {
IntValue: ptr.To(int64(i)),
},
},
},
}
devices = append(devices, d)
}
imexDomain = id[0]
cliqueID := id[1]

delete(offsets[imexDomain], cliqueID)
if len(offsets[imexDomain]) == 0 {
delete(offsets, imexDomain)
// Put them in a pool named after the IMEX domain with the IMEX domain label as a node selector
pool := resourceslice.Pool{
NodeSelector: &v1.NodeSelector{
NodeSelectorTerms: []v1.NodeSelectorTerm{
{
MatchExpressions: []v1.NodeSelectorRequirement{
{
Key: ImexDomainLabel,
Operator: v1.NodeSelectorOpIn,
Values: []string{
imexDomain,
},
},
},
},
},
},
Devices: devices,
}

return pool
}

0 comments on commit 86cbafc

Please sign in to comment.