Skip to content

Commit 4bb94d3

Browse files
committed
fix: implement specific logic for the Deployment
1 parent bf57300 commit 4bb94d3

File tree

2 files changed

+40
-4
lines changed

2 files changed

+40
-4
lines changed

internal/webhook/v1/pod_webhook.go

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,6 @@ import (
2929
corev1 "k8s.io/api/core/v1"
3030
"k8s.io/apimachinery/pkg/api/equality"
3131
"k8s.io/apimachinery/pkg/api/errors"
32-
"k8s.io/apimachinery/pkg/runtime"
3332
"k8s.io/apimachinery/pkg/util/strategicpatch"
3433
ctrl "sigs.k8s.io/controller-runtime"
3534
"sigs.k8s.io/controller-runtime/pkg/client"
@@ -53,7 +52,7 @@ func SetupPodWebhookWithManager(mgr ctrl.Manager, portAllocator *portallocator.P
5352
webhookServer.Register("/mutate-v1-pod",
5453
&admission.Webhook{
5554
Handler: &TensorFusionPodMutator{
56-
decoder: admission.NewDecoder(runtime.NewScheme()),
55+
decoder: admission.NewDecoder(mgr.GetScheme()),
5756
Client: mgr.GetClient(),
5857
portAllocator: portAllocator,
5958
},

internal/webhook/v1/tf_parser.go

Lines changed: 39 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -10,8 +10,11 @@ import (
1010
"github.com/NexusGPU/tensor-fusion/internal/constants"
1111
"github.com/NexusGPU/tensor-fusion/internal/gpuallocator"
1212
"github.com/NexusGPU/tensor-fusion/internal/utils"
13+
appsv1 "k8s.io/api/apps/v1"
1314
corev1 "k8s.io/api/core/v1"
15+
"k8s.io/apimachinery/pkg/api/errors"
1416
"k8s.io/apimachinery/pkg/api/resource"
17+
metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"
1518
"sigs.k8s.io/controller-runtime/pkg/client"
1619
)
1720

@@ -47,8 +50,11 @@ func ParseTensorFusionInfo(
4750
info.EnabledReplicas = &val32
4851
}
4952

50-
// Generate the workload name: if the Pod has no controller, use the Pod's name; otherwise, use the root controller's name.
51-
if controllerRef, err := utils.FindRootControllerRef(ctx, k8sClient, pod); err == nil {
53+
// Generate the workload name:
54+
// If the Pod has no controller, use the Pod's name;
55+
// if it is controlled by a Deployment, return the Deployment's name;
56+
// otherwise, return the name of the first-level controller.
57+
if controllerRef, err := getPodControllerRef(ctx, k8sClient, pod); err == nil {
5258
if controllerRef != nil {
5359
info.WorkloadName = controllerRef.Name
5460
} else {
@@ -254,3 +260,34 @@ func handleDedicatedGPU(pod *corev1.Pod, workloadProfile *tfv1.WorkloadProfile)
254260
workloadProfile.Spec.Resources.Limits.Vram = resource.Vram
255261
return nil
256262
}
263+
264+
func getPodControllerRef(ctx context.Context, c client.Client, pod *corev1.Pod) (*metav1.OwnerReference, error) {
265+
podControllerRef := metav1.GetControllerOf(pod)
266+
if podControllerRef == nil {
267+
return nil, nil
268+
}
269+
270+
switch podControllerRef.Kind {
271+
case "ReplicaSet":
272+
{
273+
// Special handling for Deployment resources
274+
rs := &appsv1.ReplicaSet{}
275+
if err := c.Get(ctx, client.ObjectKey{
276+
Namespace: pod.Namespace,
277+
Name: podControllerRef.Name,
278+
}, rs); err != nil {
279+
if errors.IsNotFound(err) {
280+
return podControllerRef, nil
281+
}
282+
return nil, fmt.Errorf("failed to get ReplicaSet: %w", err)
283+
}
284+
rsContollerRef := metav1.GetControllerOf(rs)
285+
if rsContollerRef != nil && rsContollerRef.Kind == "Deployment" {
286+
// If controlled by a Deployment, return the controllerRef of rs
287+
return rsContollerRef, nil
288+
}
289+
}
290+
}
291+
292+
return podControllerRef, nil
293+
}

0 commit comments

Comments
 (0)