diff --git a/pkg/controller/usbdevice/usbdevice_claim_controller_test.go b/pkg/controller/usbdevice/usbdevice_claim_controller_test.go index 30e21684..bedd7842 100644 --- a/pkg/controller/usbdevice/usbdevice_claim_controller_test.go +++ b/pkg/controller/usbdevice/usbdevice_claim_controller_test.go @@ -3,6 +3,7 @@ package usbdevice import ( "context" "testing" + "time" "github.com/stretchr/testify/assert" metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" @@ -14,10 +15,19 @@ import ( "github.com/harvester/pcidevices/pkg/util/fakeclients" ) -type mockUSBDevicePlugin struct{} +type mockUSBDevicePlugin struct { + startTimes int // used to test how many startTimes the Start is called + stopTimes int // used to test how many stopTimes the Stop is called +} -func (m *mockUSBDevicePlugin) Start(_ <-chan struct{}) error { return nil } -func (m *mockUSBDevicePlugin) StopDevicePlugin() error { return nil } +func (m *mockUSBDevicePlugin) Start(_ <-chan struct{}) error { + m.startTimes += 1 + return nil +} +func (m *mockUSBDevicePlugin) StopDevicePlugin() error { + m.stopTimes += 1 + return nil +} var ( mockUsbDevice1 = &v1beta1.USBDevice{ @@ -67,6 +77,120 @@ var ( ) func Test_OnUSBDeviceClaimChanged(t *testing.T) { + testcases := []struct { + fun func() + description string + }{ + { + fun: func() { + client := generateClient() + mockObj := &mockUSBDevicePlugin{} + handler := NewClaimHandler( + fakeclients.USBDeviceCache(client.DevicesV1beta1().USBDevices), + fakeclients.USBDeviceClaimsClient(client.DevicesV1beta1().USBDeviceClaims), + fakeclients.USBDevicesClient(client.DevicesV1beta1().USBDevices), + fakeclients.KubeVirtClient(client.KubevirtV1().KubeVirts), + func(_ string, _ []*deviceplugins.PluginDevices) deviceplugins.USBDevicePluginInterface { + return mockObj + }, + ) + + // Test claim created + _, err := handler.OnUSBDeviceClaimChanged("", mockUsbDeviceClaim1) + assert.NoError(t, err) + time.Sleep(1 * time.Second) + assert.Equal(t, 1, mockObj.startTimes) + + kubevirt, err := client.KubevirtV1().KubeVirts(mockKubeVirt.Namespace).Get(context.Background(), mockKubeVirt.Name, metav1.GetOptions{}) + assert.NoError(t, err) + assert.Equal(t, kubevirtv1.KubeVirtSpec{ + Configuration: kubevirtv1.KubeVirtConfiguration{ + PermittedHostDevices: &kubevirtv1.PermittedHostDevices{ + USB: []kubevirtv1.USBHostDevice{ + { + ResourceName: "kubevirt.io/test-node-0951-1666-001002", + ExternalResourceProvider: true, + Selectors: []kubevirtv1.USBSelector{ + { + Vendor: "0951", + Product: "1666", + }, + }, + }, + }, + }, + }, + }, kubevirt.Spec) + usbDevice, err := client.DevicesV1beta1().USBDevices().Get(context.Background(), mockUsbDevice1.Name, metav1.GetOptions{}) + assert.NoError(t, err) + assert.Equal(t, true, usbDevice.Status.Enabled) + + // Test claim removed + _, err = handler.OnRemove("", mockUsbDeviceClaim1) + assert.NoError(t, err) + usbDevice, err = client.DevicesV1beta1().USBDevices().Get(context.Background(), mockUsbDevice1.Name, metav1.GetOptions{}) + assert.NoError(t, err) + assert.Equal(t, false, usbDevice.Status.Enabled) + kubeVirt, err := client.KubevirtV1().KubeVirts(mockKubeVirt.Namespace).Get(context.Background(), mockKubeVirt.Name, metav1.GetOptions{}) + assert.NoError(t, err) + assert.Equal(t, 0, len(kubeVirt.Spec.Configuration.PermittedHostDevices.USB)) + time.Sleep(1 * time.Second) + assert.Equal(t, 1, mockObj.stopTimes) + }, + description: "General case to create claim and remove claim", + }, + { + fun: func() { + client := generateClient() + mockObj := &mockUSBDevicePlugin{startTimes: 0} + handler := NewClaimHandler( + fakeclients.USBDeviceCache(client.DevicesV1beta1().USBDevices), + fakeclients.USBDeviceClaimsClient(client.DevicesV1beta1().USBDeviceClaims), + fakeclients.USBDevicesClient(client.DevicesV1beta1().USBDevices), + fakeclients.KubeVirtClient(client.KubevirtV1().KubeVirts), + func(_ string, _ []*deviceplugins.PluginDevices) deviceplugins.USBDevicePluginInterface { + return mockObj + }, + ) + + // Test claim created + _, err := handler.OnUSBDeviceClaimChanged("", mockUsbDeviceClaim1) + assert.NoError(t, err) + _, err = handler.OnUSBDeviceClaimChanged("", mockUsbDeviceClaim1) + assert.NoError(t, err) + time.Sleep(1 * time.Second) + assert.Equal(t, 1, mockObj.startTimes) + }, + description: "Case to create two identical claims", + }, + { + fun: func() { + client := generateClient() + handler := NewClaimHandler( + fakeclients.USBDeviceCache(client.DevicesV1beta1().USBDevices), + fakeclients.USBDeviceClaimsClient(client.DevicesV1beta1().USBDeviceClaims), + fakeclients.USBDevicesClient(client.DevicesV1beta1().USBDevices), + fakeclients.KubeVirtClient(client.KubevirtV1().KubeVirts), + mockUSBDevicePluginHelper, + ) + + // Test claim created + mockUsbDeviceClaim1.Name = "non-exist" + _, err := handler.OnUSBDeviceClaimChanged("", mockUsbDeviceClaim1) + assert.NoError(t, err) + }, + description: "Case to remove non-exist claim", + }, + } + + for _, tc := range testcases { + t.Run(tc.description, func(t *testing.T) { + tc.fun() + }) + } +} + +func generateClient() *fake.Clientset { client := fake.NewSimpleClientset(mockUsbDevice1, mockUsbDeviceClaim1, mockKubeVirt) discoverAllowedUSBDevices = func(_ []kubevirtv1.USBHostDevice) map[string][]*deviceplugins.PluginDevices { m := map[string][]*deviceplugins.PluginDevices{} @@ -85,49 +209,5 @@ func Test_OnUSBDeviceClaimChanged(t *testing.T) { return m } - handler := NewClaimHandler( - fakeclients.USBDeviceCache(client.DevicesV1beta1().USBDevices), - fakeclients.USBDeviceClaimsClient(client.DevicesV1beta1().USBDeviceClaims), - fakeclients.USBDevicesClient(client.DevicesV1beta1().USBDevices), - fakeclients.KubeVirtClient(client.KubevirtV1().KubeVirts), - mockUSBDevicePluginHelper, - ) - - // Test claim created - _, err := handler.OnUSBDeviceClaimChanged("", mockUsbDeviceClaim1) - assert.NoError(t, err) - - kubevirt, err := client.KubevirtV1().KubeVirts(mockKubeVirt.Namespace).Get(context.Background(), mockKubeVirt.Name, metav1.GetOptions{}) - assert.NoError(t, err) - assert.Equal(t, kubevirtv1.KubeVirtSpec{ - Configuration: kubevirtv1.KubeVirtConfiguration{ - PermittedHostDevices: &kubevirtv1.PermittedHostDevices{ - USB: []kubevirtv1.USBHostDevice{ - { - ResourceName: "kubevirt.io/test-node-0951-1666-001002", - ExternalResourceProvider: true, - Selectors: []kubevirtv1.USBSelector{ - { - Vendor: "0951", - Product: "1666", - }, - }, - }, - }, - }, - }, - }, kubevirt.Spec) - usbDevice, err := client.DevicesV1beta1().USBDevices().Get(context.Background(), mockUsbDevice1.Name, metav1.GetOptions{}) - assert.NoError(t, err) - assert.Equal(t, true, usbDevice.Status.Enabled) - - // Test claim removed - _, err = handler.OnRemove("", mockUsbDeviceClaim1) - assert.NoError(t, err) - usbDevice, err = client.DevicesV1beta1().USBDevices().Get(context.Background(), mockUsbDevice1.Name, metav1.GetOptions{}) - assert.NoError(t, err) - assert.Equal(t, false, usbDevice.Status.Enabled) - kubeVirt, err := client.KubevirtV1().KubeVirts(mockKubeVirt.Namespace).Get(context.Background(), mockKubeVirt.Name, metav1.GetOptions{}) - assert.NoError(t, err) - assert.Equal(t, 0, len(kubeVirt.Spec.Configuration.PermittedHostDevices.USB)) + return client }