Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Refactor the IMEX controller code #189

Merged
merged 2 commits into from
Oct 28, 2024
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
247 changes: 150 additions & 97 deletions cmd/nvidia-dra-controller/imex.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ package main

import (
"context"
"errors"
"fmt"
"strings"
"sync"
Expand All @@ -41,17 +42,28 @@ const (
ImexDomainLabel = "nvidia.com/gpu.imex-domain"
ResourceSliceImexChannelLimit = 128
DriverImexChannelLimit = 2048
RetryTimeout = 1 * time.Minute
)

// ImexManager holds the Kubernetes clientset and the wait group that tracks
// the manager's background goroutine.
type ImexManager struct {
waitGroup sync.WaitGroup
clientset kubernetes.Interface
}
// transientError marks an error as transient, i.e. the operation that
// produced it may succeed if it is retried later. Callers detect it with
// errors.As and a *transientError target.
type transientError struct{ error }

// Unwrap exposes the wrapped error so that errors.Is and errors.As can keep
// walking the chain past the transientError marker.
func (e transientError) Unwrap() error { return e.error }
ArangoGutierrez marked this conversation as resolved.
Show resolved Hide resolved

// imexDomainOffsets represents the offset for assigning IMEX channels
// to ResourceSlices for each <imex-domain, cliqueid> combination.
//
// The outer key is the IMEX domain ID, the inner key is the clique ID, and
// the value is the channel offset assigned to that combination.
type imexDomainOffsets map[string]map[string]int

// ImexManager publishes IMEX channels as ResourceSlice pools, one pool per
// IMEX domain, and keeps the state needed to add and remove those pools.
type ImexManager struct {
// driverName is the driver name handed to the resource slice controller.
driverName string
// resourceSliceImexChannelLimit is the number of channels generated for the
// pool of a single IMEX domain.
resourceSliceImexChannelLimit int
// driverImexChannelLimit is the upper bound used when assigning channel
// offsets; reaching it is reported as a transient error.
driverImexChannelLimit int
// retryTimeout is how long to wait before re-queueing a domain whose
// add/remove failed with a transient error.
retryTimeout time.Duration
// waitGroup tracks the goroutine started by manageResourceSlices.
waitGroup sync.WaitGroup
// clientset is the Kubernetes client passed to the resource slice controller.
clientset kubernetes.Interface
// imexDomainOffsets records the channel offset assigned to each
// <imex-domain, cliqueid> combination.
imexDomainOffsets imexDomainOffsets
// owner is the owner passed to the resource slice controller.
owner resourceslice.Owner
// driverResources holds the pools (keyed by IMEX domain) handed to the
// controller; it is replaced by a deep copy on every add/remove.
driverResources *resourceslice.DriverResources
}

func StartIMEXManager(ctx context.Context, config *Config) (*ImexManager, error) {
// Build a client set config
csconfig, err := config.flags.kubeClientConfig.NewClientSetConfig()
Expand Down Expand Up @@ -79,21 +91,25 @@ func StartIMEXManager(ctx context.Context, config *Config) (*ImexManager, error)
UID: pod.UID,
}

// Create the manager itself
m := &ImexManager{
clientset: clientset,
// Create a new set of DriverResources
driverResources := &resourceslice.DriverResources{
Pools: make(map[string]resourceslice.Pool),
ArangoGutierrez marked this conversation as resolved.
Show resolved Hide resolved
}

// Stream added/removed IMEX domains from nodes over time
klog.Info("Start streaming IMEX domains from nodes...")
addedDomainsCh, removedDomainsCh, err := m.streamImexDomains(ctx)
if err != nil {
return nil, fmt.Errorf("error streaming IMEX domains: %w", err)
// Create the manager itself
m := &ImexManager{
driverName: DriverName,
resourceSliceImexChannelLimit: ResourceSliceImexChannelLimit,
driverImexChannelLimit: DriverImexChannelLimit,
retryTimeout: RetryTimeout,
clientset: clientset,
owner: owner,
driverResources: driverResources,
imexDomainOffsets: make(imexDomainOffsets),
}

// Add/Remove resource slices from IMEX domains as they come and go
klog.Info("Start publishing IMEX channels to ResourceSlices...")
err = m.manageResourceSlices(ctx, owner, addedDomainsCh, removedDomainsCh)
err = m.manageResourceSlices(ctx)
if err != nil {
return nil, fmt.Errorf("error managing resource slices: %w", err)
}
Expand All @@ -102,35 +118,50 @@ func StartIMEXManager(ctx context.Context, config *Config) (*ImexManager, error)
}

// manageResourceSlices reacts to added and removed IMEX domains and triggers the creation / removal of resource slices accordingly.
func (m *ImexManager) manageResourceSlices(ctx context.Context, owner resourceslice.Owner, addedDomainsCh <-chan string, removedDomainsCh <-chan string) error {
driverResources := &resourceslice.DriverResources{}
controller, err := resourceslice.StartController(ctx, m.clientset, DriverName, owner, driverResources)
func (m *ImexManager) manageResourceSlices(ctx context.Context) error {
klog.Info("Start streaming IMEX domains from nodes...")
addedDomainsCh, removedDomainsCh, err := m.streamImexDomains(ctx)
if err != nil {
return fmt.Errorf("error streaming IMEX domains: %w", err)
}

klog.Info("Start publishing IMEX channels to ResourceSlices...")
controller, err := resourceslice.StartController(ctx, m.clientset, m.driverName, m.owner, m.driverResources)
if err != nil {
return fmt.Errorf("error starting resource slice controller: %w", err)
}

imexDomainOffsets := new(imexDomainOffsets)
m.waitGroup.Add(1)
go func() {
defer m.waitGroup.Done()
for {
select {
case addedDomain := <-addedDomainsCh:
offset, err := imexDomainOffsets.add(addedDomain, ResourceSliceImexChannelLimit, DriverImexChannelLimit)
if err != nil {
klog.Errorf("Error calculating channel offset for IMEX domain %s: %v", addedDomain, err)
return
}
klog.Infof("Adding channels for new IMEX domain: %v", addedDomain)
driverResources := driverResources.DeepCopy()
driverResources.Pools[addedDomain] = generateImexChannelPool(addedDomain, offset, ResourceSliceImexChannelLimit)
controller.Update(driverResources)
if err := m.addImexDomain(addedDomain); err != nil {
klog.Errorf("Error adding channels for IMEX domain %s: %v", addedDomain, err)
if errors.As(err, &transientError{}) {
ArangoGutierrez marked this conversation as resolved.
Show resolved Hide resolved
klog.Infof("Retrying adding channels for IMEX domain %s after %v", addedDomain, m.retryTimeout)
go func() {
time.Sleep(m.retryTimeout)
addedDomainsCh <- addedDomain
}()
}
}
controller.Update(m.driverResources)
case removedDomain := <-removedDomainsCh:
klog.Infof("Removing channels for removed IMEX domain: %v", removedDomain)
driverResources := driverResources.DeepCopy()
delete(driverResources.Pools, removedDomain)
imexDomainOffsets.remove(removedDomain)
controller.Update(driverResources)
if err := m.removeImexDomain(removedDomain); err != nil {
klog.Errorf("Error removing channels for IMEX domain %s: %v", removedDomain, err)
if errors.As(err, &transientError{}) {
klog.Infof("Retrying removing channels for IMEX domain %s after %v", removedDomain, m.retryTimeout)
go func() {
time.Sleep(m.retryTimeout)
removedDomainsCh <- removedDomain
}()
}
}
controller.Update(m.driverResources)
case <-ctx.Done():
return
}
Expand All @@ -155,8 +186,35 @@ func (m *ImexManager) Stop() error {
return nil
}

// addImexDomain adds an IMEX domain to be managed by the ImexManager.
func (m *ImexManager) addImexDomain(imexDomain string) error {
imexDomainID, cliqueID, err := splitImexDomain(imexDomain)
if err != nil {
return fmt.Errorf("error splitting IMEX domain '%s': %v", imexDomain, err)
}
ArangoGutierrez marked this conversation as resolved.
Show resolved Hide resolved
offset, err := m.imexDomainOffsets.add(imexDomainID, cliqueID, m.resourceSliceImexChannelLimit, m.driverImexChannelLimit)
if err != nil {
return fmt.Errorf("error setting offset for IMEX channels: %w", err)
}
m.driverResources = m.driverResources.DeepCopy()
m.driverResources.Pools[imexDomain] = generateImexChannelPool(imexDomain, offset, m.resourceSliceImexChannelLimit)
return nil
}

// removeImexDomain removes an IMEX domain from being managed by the ImexManager.
//
// imexDomain is split into its IMEX domain ID and clique ID, the channel
// offset recorded for that combination is released, and the pool stored under
// the full imexDomain string is deleted from m.driverResources.
//
// It returns an error (wrapping the cause with %w) if imexDomain cannot be
// split.
func (m *ImexManager) removeImexDomain(imexDomain string) error {
	imexDomainID, cliqueID, err := splitImexDomain(imexDomain)
	if err != nil {
		return fmt.Errorf("error splitting IMEX domain '%s': %w", imexDomain, err)
	}

	m.imexDomainOffsets.remove(imexDomainID, cliqueID)

	// Work on a fresh copy rather than mutating the current object in place.
	// NOTE(review): presumably because the previous object has already been
	// handed to the resource slice controller — confirm.
	m.driverResources = m.driverResources.DeepCopy()
	delete(m.driverResources.Pools, imexDomain)
	return nil
}

// streamImexDomains returns two channels that stream the IMEX domains added to and removed from nodes over time.
func (m *ImexManager) streamImexDomains(ctx context.Context) (<-chan string, <-chan string, error) {
func (m *ImexManager) streamImexDomains(ctx context.Context) (chan string, chan string, error) {
// Create channels to stream IMEX domain ids that are added / removed
addedDomainCh := make(chan string)
removedDomainCh := make(chan string)
Expand Down Expand Up @@ -246,50 +304,6 @@ func (m *ImexManager) streamImexDomains(ctx context.Context) (<-chan string, <-c
return addedDomainCh, removedDomainCh, nil
}

// generateImexChannelPool builds the ResourceSlice pool contents for one IMEX
// domain: numChannels "imex-channel" devices numbered from startChannel
// upwards, selectable only on nodes whose IMEX domain label equals imexDomain.
func generateImexChannelPool(imexDomain string, startChannel int, numChannels int) resourceslice.Pool {
	// One device per channel in [startChannel, startChannel+numChannels).
	var channels []resourceapi.Device
	end := startChannel + numChannels
	for ch := startChannel; ch < end; ch++ {
		attrs := map[resourceapi.QualifiedName]resourceapi.DeviceAttribute{
			"type":    {StringValue: ptr.To("imex-channel")},
			"channel": {IntValue: ptr.To(int64(ch))},
		}
		channels = append(channels, resourceapi.Device{
			Name:  fmt.Sprintf("imex-channel-%d", ch),
			Basic: &resourceapi.BasicDevice{Attributes: attrs},
		})
	}

	// Restrict the pool to nodes carrying this IMEX domain's label value.
	domainRequirement := v1.NodeSelectorRequirement{
		Key:      ImexDomainLabel,
		Operator: v1.NodeSelectorOpIn,
		Values:   []string{imexDomain},
	}
	selector := &v1.NodeSelector{
		NodeSelectorTerms: []v1.NodeSelectorTerm{
			{MatchExpressions: []v1.NodeSelectorRequirement{domainRequirement}},
		},
	}

	return resourceslice.Pool{
		NodeSelector: selector,
		Devices:      channels,
	}
}

// cleanupResourceSlices removes all resource slices created by the IMEX manager.
func (m *ImexManager) cleanupResourceSlices() error {
// Delete all resource slices created by the IMEX manager
Expand All @@ -312,28 +326,20 @@ func (m *ImexManager) cleanupResourceSlices() error {
}

// add sets the offset where an IMEX domain's channels should start counting from.
func (offsets imexDomainOffsets) add(imexDomain string, resourceSliceImexChannelLimit, driverImexChannelLimit int) (int, error) {
// Split the incoming imexDomain to split off its cliqueID
id := strings.SplitN(imexDomain, ".", 2)
if len(id) != 2 {
return -1, fmt.Errorf("error adding IMEX domain %s: invalid format", imexDomain)
}
imexDomain = id[0]
cliqueID := id[1]

func (offsets imexDomainOffsets) add(imexDomainID string, cliqueID string, resourceSliceImexChannelLimit, driverImexChannelLimit int) (int, error) {
// Check if the IMEX domain is already in the map
if _, ok := offsets[imexDomain]; !ok {
offsets[imexDomain] = make(map[string]int)
if _, ok := offsets[imexDomainID]; !ok {
offsets[imexDomainID] = make(map[string]int)
}

// Return early if the clique is already in the map
if offset, exists := offsets[imexDomain][cliqueID]; exists {
if offset, exists := offsets[imexDomainID][cliqueID]; exists {
return offset, nil
}

// Track used offsets for the current imexDomain
usedOffsets := make(map[int]struct{})
for _, v := range offsets[imexDomain] {
for _, v := range offsets[imexDomainID] {
usedOffsets[v] = struct{}{}
}

Expand All @@ -347,23 +353,70 @@ func (offsets imexDomainOffsets) add(imexDomain string, resourceSliceImexChannel

// If we reach the limit, return an error
if offset == driverImexChannelLimit {
return -1, fmt.Errorf("error adding IMEX domain %s: channel limit reached", imexDomain)
return -1, transientError{fmt.Errorf("channel limit reached")}
ArangoGutierrez marked this conversation as resolved.
Show resolved Hide resolved
}
offsets[imexDomain][cliqueID] = offset
offsets[imexDomainID][cliqueID] = offset

return offset, nil
}

func (offsets imexDomainOffsets) remove(imexDomain string) {
// remove removes the offset where an IMEX domain's channels should start counting from.
func (offsets imexDomainOffsets) remove(imexDomainID string, cliqueID string) {
delete(offsets[imexDomainID], cliqueID)
if len(offsets[imexDomainID]) == 0 {
delete(offsets, imexDomainID)
}
}

// splitImexDomain splits an imexDomain into its IMEX domain ID and its clique ID.
//
// The split happens at the first '.', so the clique ID keeps any further dots
// (e.g. "a.b.c" yields "a" and "b.c"). An error is returned when imexDomain
// contains no '.' at all.
func splitImexDomain(imexDomain string) (string, string, error) {
	imexDomainID, cliqueID, found := strings.Cut(imexDomain, ".")
	if !found {
		return "", "", fmt.Errorf("splitting by '.' not equal to exactly 2 elements")
	}
	return imexDomainID, cliqueID, nil
}

// generateImexChannelPool generates the contents of a ResourceSlice pool for a given IMEX domain.
func generateImexChannelPool(imexDomain string, startChannel int, numChannels int) resourceslice.Pool {
// Generate channels from startChannel to startChannel+numChannels
var devices []resourceapi.Device
for i := startChannel; i < (startChannel + numChannels); i++ {
d := resourceapi.Device{
Name: fmt.Sprintf("imex-channel-%d", i),
Basic: &resourceapi.BasicDevice{
Attributes: map[resourceapi.QualifiedName]resourceapi.DeviceAttribute{
"type": {
StringValue: ptr.To("imex-channel"),
},
"channel": {
IntValue: ptr.To(int64(i)),
},
},
},
}
devices = append(devices, d)
}
imexDomain = id[0]
cliqueID := id[1]

delete(offsets[imexDomain], cliqueID)
if len(offsets[imexDomain]) == 0 {
delete(offsets, imexDomain)
// Put them in a pool named after the IMEX domain with the IMEX domain label as a node selector
pool := resourceslice.Pool{
NodeSelector: &v1.NodeSelector{
NodeSelectorTerms: []v1.NodeSelectorTerm{
{
MatchExpressions: []v1.NodeSelectorRequirement{
{
Key: ImexDomainLabel,
Operator: v1.NodeSelectorOpIn,
Values: []string{
imexDomain,
},
},
},
},
},
},
Devices: devices,
}

return pool
}