diff --git a/pkg/controller/usbdevice/register.go b/pkg/controller/usbdevice/register.go index f632844..abc4ddc 100644 --- a/pkg/controller/usbdevice/register.go +++ b/pkg/controller/usbdevice/register.go @@ -9,12 +9,6 @@ import ( "github.com/harvester/pcidevices/pkg/config" ) -const ( - KubeVirtNamespace = "harvester-system" - KubeVirtResource = "kubevirt" - KubeVirtResourcePrefix = "kubevirt.io/" -) - func Register(ctx context.Context, management *config.FactoryManager) error { usbDeviceCtrl := management.DeviceFactory.Devices().V1beta1().USBDevice() usbDeviceClaimCtrl := management.DeviceFactory.Devices().V1beta1().USBDeviceClaim() diff --git a/pkg/controller/usbdevice/usbdevice_claim_controller.go b/pkg/controller/usbdevice/usbdevice_claim_controller.go index b4b4789..c21df24 100644 --- a/pkg/controller/usbdevice/usbdevice_claim_controller.go +++ b/pkg/controller/usbdevice/usbdevice_claim_controller.go @@ -14,6 +14,7 @@ import ( "github.com/harvester/pcidevices/pkg/deviceplugins" ctldevicerv1beta1 "github.com/harvester/pcidevices/pkg/generated/controllers/devices.harvesterhci.io/v1beta1" ctlkubevirtv1 "github.com/harvester/pcidevices/pkg/generated/controllers/kubevirt.io/v1" + "github.com/harvester/pcidevices/pkg/util" ) type DevClaimHandler struct { @@ -69,7 +70,7 @@ func (h *DevClaimHandler) OnUSBDeviceClaimChanged(_ string, usbDeviceClaim *v1be h.lock.Lock() defer h.lock.Unlock() - virt, err := h.virtClient.Get(KubeVirtNamespace, KubeVirtResource, metav1.GetOptions{}) + virt, err := h.virtClient.Get(util.KubeVirtNamespace, util.KubeVirtResource, metav1.GetOptions{}) if err != nil { logrus.Errorf("failed to get kubevirt: %v", err) return usbDeviceClaim, err @@ -138,7 +139,7 @@ func (h *DevClaimHandler) OnRemove(_ string, claim *v1beta1.USBDeviceClaim) (*v1 h.lock.Lock() defer h.lock.Unlock() - virt, err := h.virtClient.Get(KubeVirtNamespace, KubeVirtResource, metav1.GetOptions{}) + virt, err := h.virtClient.Get(util.KubeVirtNamespace, util.KubeVirtResource, metav1.GetOptions{}) if err != nil { fmt.Println(err) return nil, err diff --git a/pkg/controller/usbdevice/usbdevice_claim_controller_test.go b/pkg/controller/usbdevice/usbdevice_claim_controller_test.go index 570b80f..14e6713 100644 --- a/pkg/controller/usbdevice/usbdevice_claim_controller_test.go +++ b/pkg/controller/usbdevice/usbdevice_claim_controller_test.go @@ -11,6 +11,7 @@ import ( "github.com/harvester/pcidevices/pkg/apis/devices.harvesterhci.io/v1beta1" "github.com/harvester/pcidevices/pkg/generated/clientset/versioned/fake" + "github.com/harvester/pcidevices/pkg/util" "github.com/harvester/pcidevices/pkg/util/fakeclients" ) @@ -35,7 +36,7 @@ var ( mockKubeVirt = &kubevirtv1.KubeVirt{ ObjectMeta: metav1.ObjectMeta{ Name: "kubevirt", - Namespace: KubeVirtNamespace, + Namespace: util.KubeVirtNamespace, }, Spec: kubevirtv1.KubeVirtSpec{}, } diff --git a/pkg/controller/usbdevice/usbdevice_controller.go b/pkg/controller/usbdevice/usbdevice_controller.go index 9fb37e5..8868770 100644 --- a/pkg/controller/usbdevice/usbdevice_controller.go +++ b/pkg/controller/usbdevice/usbdevice_controller.go @@ -17,6 +17,7 @@ import ( "github.com/harvester/pcidevices/pkg/apis/devices.harvesterhci.io/v1beta1" "github.com/harvester/pcidevices/pkg/deviceplugins" ctldevicerv1vbeta1 "github.com/harvester/pcidevices/pkg/generated/controllers/devices.harvesterhci.io/v1beta1" + "github.com/harvester/pcidevices/pkg/util" "github.com/harvester/pcidevices/pkg/util/gousb" "github.com/harvester/pcidevices/pkg/util/gousb/usbid" ) @@ -266,5 +267,5 @@ func isStatusChanged(existed *v1beta1.USBDevice, localUSBDevice *deviceplugins.U } func resourceName(name string) string { - return fmt.Sprintf("%s%s", KubeVirtResourcePrefix, name) + return fmt.Sprintf("%s%s", util.KubeVirtResourcePrefix, name) } diff --git a/pkg/util/constant.go b/pkg/util/constant.go new file mode 100644 index 0000000..a532825 --- /dev/null +++ b/pkg/util/constant.go @@ -0,0 +1,7 @@ +package util + +const ( + KubeVirtNamespace = "harvester-system" + KubeVirtResource = "kubevirt" + KubeVirtResourcePrefix = "kubevirt.io/" +) diff --git a/pkg/webhook/vm_validatation.go b/pkg/webhook/vm_validatation.go index 21609bd..c98b3a6 100644 --- a/pkg/webhook/vm_validatation.go +++ b/pkg/webhook/vm_validatation.go @@ -70,41 +70,52 @@ func (vmValidator *vmDeviceHostValidator) Update(_ *types.Request, _ runtime.Obj func (vmValidator *vmDeviceHostValidator) validateDevicesFromSameNodes(vmObj *kubevirtv1.VirtualMachine) error { var nodeName string - errorMsgFormat := "device %s/%s is not on the same node in VirtualMachine.Spec.Template.Spec.Domain.Devices.HostDevices %s" - - for _, device := range vmObj.Spec.Template.Spec.Domain.Devices.HostDevices { - usb, err := vmValidator.usbCache.Get(device.Name) - if err != nil { - if !apierrors.IsNotFound(err) { - return err - } - } - if nodeName == "" && usb != nil { - nodeName = usb.Status.NodeName - continue + for number, device := range vmObj.Spec.Template.Spec.Domain.Devices.HostDevices { + if err := vmValidator.validateDevice(device, &nodeName, number, vmObj.Name); err != nil { + return err } + } - pci, err := vmValidator.pciCache.Get(device.Name) - if err != nil { - if !apierrors.IsNotFound(err) { - return err - } - } + return nil +} - if nodeName == "" && pci != nil { - nodeName = pci.Status.NodeName - continue - } +func (vmValidator *vmDeviceHostValidator) validateDevice(device kubevirtv1.HostDevice, nodeName *string, number int, vmName string) error { + errorMsgFormat := "hostDevices[].name %s/%s is not on the same node in VirtualMachine.Spec.Template.Spec.Domain.Devices.HostDevices %s" - if pci != nil && pci.Status.NodeName != nodeName { - return fmt.Errorf(errorMsgFormat, "pcidevice", pci.Name, vmObj.Name) + usb, err := vmValidator.usbCache.Get(device.Name) + if err != nil && !apierrors.IsNotFound(err) { + return err + } + + if *nodeName == "" && usb != nil { + if usb.Status.ResourceName != device.DeviceName { + return fmt.Errorf("hostDevices[%d].DeviceName %s is not found in USBDevice", number, device.DeviceName) } + *nodeName = usb.Status.NodeName + return nil + } - if usb != nil && usb.Status.NodeName != nodeName { - return fmt.Errorf(errorMsgFormat, "usbdevice", usb.Name, vmObj.Name) + pci, err := vmValidator.pciCache.Get(device.Name) + if err != nil && !apierrors.IsNotFound(err) { + return err + } + + if *nodeName == "" && pci != nil { + if pci.Status.ResourceName != device.DeviceName { + return fmt.Errorf("hostDevices[%d].DeviceName %s is not found in PCIDevice", number, device.DeviceName) } + *nodeName = pci.Status.NodeName + return nil } - return nil + if pci != nil && pci.Status.NodeName != *nodeName { + return fmt.Errorf(errorMsgFormat, "pcidevice", pci.Name, vmName) + } + + if usb != nil && usb.Status.NodeName != *nodeName { + return fmt.Errorf(errorMsgFormat, "usbdevice", usb.Name, vmName) + } + + return fmt.Errorf("hostDevices[%d].name %s is not found in USBDevice or PCIDevice", number, device.Name) }