Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion cmd/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -109,7 +109,7 @@ func init() {
}
karpenterScheme.Register(&karpv1.NodeClaim{}, &karpv1.NodeClaimList{})
karpenterScheme.Register(&karpv1.NodePool{}, &karpv1.NodePoolList{})
karpenterScheme.AddToScheme(scheme)
utilruntime.Must(karpenterScheme.AddToScheme(scheme))
}

//nolint:gocyclo
Expand Down
17 changes: 7 additions & 10 deletions internal/utils/compose.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ import (
"github.com/samber/lo"
v1 "k8s.io/api/core/v1"
"k8s.io/apimachinery/pkg/api/resource"
metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"
"k8s.io/apimachinery/pkg/util/intstr"
"k8s.io/utils/ptr"
)
Expand Down Expand Up @@ -73,16 +74,12 @@ var featureShortcutMap = map[string]struct {
}

type TensorFusionInfo struct {
Profile *tfv1.WorkloadProfileSpec
DynamicReplicas bool
EnabledReplicas *int32
WorkloadName string
ContainerNames []string
GenWorkload bool

// Pod mutating webhook can not get Pod UID sometimes,
// thus need pod controller to set the owner reference
PendingSetPodAsOwner bool
Profile *tfv1.WorkloadProfileSpec
DynamicReplicas bool
EnabledReplicas *int32
WorkloadName string
PodControllerRef *metav1.OwnerReference
ContainerNames []string
}

func AddOrOverrideTFClientMissingAnnotationsBeforePatch(pod *v1.Pod, tfInfo TensorFusionInfo) {
Expand Down
85 changes: 85 additions & 0 deletions internal/utils/owner_ref_utils.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,9 @@ import (
"context"
"fmt"

appsv1 "k8s.io/api/apps/v1"
batchv1 "k8s.io/api/batch/v1"
corev1 "k8s.io/api/core/v1"
"k8s.io/apimachinery/pkg/api/errors"
metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"
"k8s.io/apimachinery/pkg/apis/meta/v1/unstructured"
Expand Down Expand Up @@ -96,3 +99,85 @@ func FindFirstLevelOwnerReference(obj metav1.Object) *metav1.OwnerReference {
}
return &ownerRef
}

// FindRootControllerRef recursively finds the root controller reference for a given object (e.g. Pod).
func FindRootControllerRef(ctx context.Context, c client.Client, obj metav1.Object) (*metav1.OwnerReference, error) {
if metav1.GetControllerOfNoCopy(obj) == nil {
return nil, nil
}

namespace := obj.GetNamespace()
current := obj
for {
controllerRef := metav1.GetControllerOf(current)
if controllerRef == nil {
if rObj, ok := current.(runtime.Object); ok {
gvk := rObj.GetObjectKind().GroupVersionKind()
return metav1.NewControllerRef(current, gvk), nil
} else {
return nil, fmt.Errorf("not a runtime.Object")
}
}

unObj := &unstructured.Unstructured{}
unObj.SetAPIVersion(controllerRef.APIVersion)
unObj.SetKind(controllerRef.Kind)
err := c.Get(ctx, client.ObjectKey{Name: controllerRef.Name, Namespace: namespace}, unObj)
if err != nil {
// if not found, return controllerRef as root
if errors.IsNotFound(err) {
return controllerRef, nil
}
return nil, fmt.Errorf("get controller object: %w", err)
}

// Cast back to metav1.Object if possible
if metaObj, ok := any(unObj).(metav1.Object); ok {
current = metaObj
} else {
return nil, fmt.Errorf("unexpected type for controller object %s/%s", controllerRef.Kind, controllerRef.Name)
}
}
}

// GetPodControllerRef returns the controller reference for a Pod.
// For Pods that are indirectly controlled (e.g., by a Deployment or CronJob), return the indirect controller.
// For other cases, it returns the direct controller reference of the Pod.
// If the Pod has no controller reference, it returns nil.
func GetPodControllerRef(ctx context.Context, c client.Client, pod *corev1.Pod) (*metav1.OwnerReference, error) {
podControllerRef := metav1.GetControllerOf(pod)
if podControllerRef == nil {
return nil, nil
}

getControllerRef := func(obj client.Object) (*metav1.OwnerReference, error) {
if err := c.Get(ctx, client.ObjectKey{
Namespace: pod.Namespace,
Name: podControllerRef.Name,
}, obj); err != nil {
if errors.IsNotFound(err) {
return podControllerRef, nil
}
return nil, fmt.Errorf("failed to get %T: %w", obj, err)
}
return metav1.GetControllerOf(obj), nil
}

switch podControllerRef.Kind {
case "ReplicaSet":
if parentRef, err := getControllerRef(&appsv1.ReplicaSet{}); err != nil {
return nil, err
} else if parentRef != nil && parentRef.Kind == "Deployment" {
return parentRef, nil
}

case "Job":
if parentRef, err := getControllerRef(&batchv1.Job{}); err != nil {
return nil, err
} else if parentRef != nil && parentRef.Kind == "CronJob" {
return parentRef, nil
}
}

return podControllerRef, nil
}
252 changes: 252 additions & 0 deletions internal/utils/owner_ref_utils_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ import (
"testing"

appsv1 "k8s.io/api/apps/v1"
batchv1 "k8s.io/api/batch/v1"
corev1 "k8s.io/api/core/v1"
metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"
"k8s.io/apimachinery/pkg/runtime"
Expand Down Expand Up @@ -140,3 +141,254 @@ func TestFindRootOwnerReference(t *testing.T) {
require.Equal(t, "ReplicaSet", rootRef.Kind)
})
}

func TestFindRootControllerRef(t *testing.T) {
// Prepare the scheme
sch := runtime.NewScheme()
require.NoError(t, corev1.AddToScheme(sch))
require.NoError(t, appsv1.AddToScheme(sch))

t.Run("no controller returns nil", func(t *testing.T) {
pod := &corev1.Pod{
TypeMeta: metav1.TypeMeta{
APIVersion: "v1",
Kind: "Pod",
},
ObjectMeta: metav1.ObjectMeta{
Name: "mypod",
Namespace: "default",
UID: "uid-pod",
},
}

c := fake.NewClientBuilder().WithScheme(sch).WithObjects(pod).Build()

rootRef, err := utils.FindRootControllerRef(context.TODO(), c, pod)
require.NoError(t, err)
require.Nil(t, rootRef)
})

t.Run("hierarchy returns deployment", func(t *testing.T) {
controller := true
deployment := &appsv1.Deployment{
TypeMeta: metav1.TypeMeta{
APIVersion: "apps/v1",
Kind: "Deployment",
},
ObjectMeta: metav1.ObjectMeta{
Name: "mydeploy",
Namespace: "default",
UID: "uid-deploy",
},
}

rs := &appsv1.ReplicaSet{
TypeMeta: metav1.TypeMeta{
APIVersion: "apps/v1",
Kind: "ReplicaSet",
},
ObjectMeta: metav1.ObjectMeta{
Name: "myrs",
Namespace: "default",
UID: "uid-rs",
OwnerReferences: []metav1.OwnerReference{
{
APIVersion: "apps/v1",
Kind: "Deployment",
Name: "mydeploy",
UID: deployment.UID,
Controller: &controller,
},
},
},
}

pod := &corev1.Pod{
TypeMeta: metav1.TypeMeta{
APIVersion: "v1",
Kind: "Pod",
},
ObjectMeta: metav1.ObjectMeta{
Name: "mypod",
Namespace: "default",
UID: "uid-pod",
OwnerReferences: []metav1.OwnerReference{
{
APIVersion: "apps/v1",
Kind: "ReplicaSet",
Name: "myrs",
UID: rs.UID,
Controller: &controller,
},
},
},
}

c := fake.NewClientBuilder().WithScheme(sch).WithObjects(pod, rs, deployment).Build()

rootRef, err := utils.FindRootControllerRef(context.TODO(), c, pod)
require.NoError(t, err)
require.NotNil(t, rootRef)
require.Equal(t, "mydeploy", rootRef.Name)
require.Equal(t, "Deployment", rootRef.Kind)
})

t.Run("missing controller returns last found ref", func(t *testing.T) {
controller := true
pod := &corev1.Pod{
TypeMeta: metav1.TypeMeta{
APIVersion: "v1",
Kind: "Pod",
},
ObjectMeta: metav1.ObjectMeta{
Name: "mypod",
Namespace: "default",
UID: "uid-pod",
OwnerReferences: []metav1.OwnerReference{
{
APIVersion: "apps/v1",
Kind: "ReplicaSet",
Name: "missing-rs",
UID: "uid-missing",
Controller: &controller,
},
},
},
}

c := fake.NewClientBuilder().WithScheme(sch).WithObjects(pod).Build()

rootRef, err := utils.FindRootControllerRef(context.TODO(), c, pod)
require.NoError(t, err)
require.NotNil(t, rootRef)
require.Equal(t, "missing-rs", rootRef.Name)
require.Equal(t, "ReplicaSet", rootRef.Kind)
})
}

func TestGetPodControllerRef(t *testing.T) {
// Prepare the scheme
sch := runtime.NewScheme()
require.NoError(t, corev1.AddToScheme(sch))
require.NoError(t, appsv1.AddToScheme(sch))
require.NoError(t, batchv1.AddToScheme(sch))

t.Run("pod with no controller returns nil", func(t *testing.T) {
pod := &corev1.Pod{
ObjectMeta: metav1.ObjectMeta{
Name: "mypod",
Namespace: "default",
},
}

c := fake.NewClientBuilder().WithScheme(sch).WithObjects(pod).Build()

ref, err := utils.GetPodControllerRef(context.TODO(), c, pod)
require.NoError(t, err)
require.Nil(t, ref)
})

t.Run("pod owned by replicaset owned by deployment returns deployment ref", func(t *testing.T) {
controller := true
deployment := &appsv1.Deployment{
ObjectMeta: metav1.ObjectMeta{
Name: "mydeploy",
Namespace: "default",
UID: "uid-deploy",
},
}

rs := &appsv1.ReplicaSet{
ObjectMeta: metav1.ObjectMeta{
Name: "myrs",
Namespace: "default",
UID: "uid-rs",
OwnerReferences: []metav1.OwnerReference{
{
APIVersion: "apps/v1",
Kind: "Deployment",
Name: "mydeploy",
UID: deployment.UID,
Controller: &controller,
},
},
},
}

pod := &corev1.Pod{
ObjectMeta: metav1.ObjectMeta{
Name: "mypod",
Namespace: "default",
OwnerReferences: []metav1.OwnerReference{
{
APIVersion: "apps/v1",
Kind: "ReplicaSet",
Name: "myrs",
UID: rs.UID,
Controller: &controller,
},
},
},
}

c := fake.NewClientBuilder().WithScheme(sch).WithObjects(pod, rs, deployment).Build()

ref, err := utils.GetPodControllerRef(context.TODO(), c, pod)
require.NoError(t, err)
require.NotNil(t, ref)
require.Equal(t, "mydeploy", ref.Name)
require.Equal(t, "Deployment", ref.Kind)
})

t.Run("pod owned by job owned by cronjob returns cronjob ref", func(t *testing.T) {
controller := true
cronjob := &batchv1.CronJob{
ObjectMeta: metav1.ObjectMeta{
Name: "mycronjob",
Namespace: "default",
UID: "uid-cronjob",
},
}

job := &batchv1.Job{
ObjectMeta: metav1.ObjectMeta{
Name: "myjob",
Namespace: "default",
UID: "uid-job",
OwnerReferences: []metav1.OwnerReference{
{
APIVersion: "batch/v1",
Kind: "CronJob",
Name: "mycronjob",
UID: cronjob.UID,
Controller: &controller,
},
},
},
}

pod := &corev1.Pod{
ObjectMeta: metav1.ObjectMeta{
Name: "mypod",
Namespace: "default",
OwnerReferences: []metav1.OwnerReference{
{
APIVersion: "batch/v1",
Kind: "Job",
Name: "myjob",
UID: job.UID,
Controller: &controller,
},
},
},
}

c := fake.NewClientBuilder().WithScheme(sch).WithObjects(pod, job, cronjob).Build()

ref, err := utils.GetPodControllerRef(context.TODO(), c, pod)
require.NoError(t, err)
require.NotNil(t, ref)
require.Equal(t, "mycronjob", ref.Name)
require.Equal(t, "CronJob", ref.Kind)
})
}
Loading