diff --git a/lib/services/local/users_test.go b/lib/services/local/users_test.go index e0f97d17d5110..e55b6d6d7932f 100644 --- a/lib/services/local/users_test.go +++ b/lib/services/local/users_test.go @@ -595,33 +595,40 @@ func TestIdentityService_GetMFADevices_SSO(t *testing.T) { tests := []struct { name string connectorRef *types.ConnectorRef - expectSSODevice bool + expectSSODevice *types.SSOMFADevice }{ { name: "non-sso user", connectorRef: nil, - expectSSODevice: false, + expectSSODevice: nil, }, { name: "sso user mfa disabled", connectorRef: &types.ConnectorRef{ Type: "saml", ID: samlConnectorNoMFA.GetName(), }, - expectSSODevice: false, }, { name: "saml user", connectorRef: &types.ConnectorRef{ Type: "saml", ID: samlConnector.GetName(), }, - expectSSODevice: true, + expectSSODevice: &types.SSOMFADevice{ + ConnectorType: "saml", + ConnectorId: samlConnector.GetName(), + DisplayName: samlConnector.GetDisplay(), + }, }, { name: "oidc user", connectorRef: &types.ConnectorRef{ Type: "oidc", ID: oidcConnector.GetName(), }, - expectSSODevice: true, + expectSSODevice: &types.SSOMFADevice{ + ConnectorType: "oidc", + ConnectorId: oidcConnector.GetName(), + DisplayName: oidcConnector.GetDisplay(), + }, }, } for _, test := range tests { @@ -638,15 +645,13 @@ func TestIdentityService_GetMFADevices_SSO(t *testing.T) { devs, err := identity.GetMFADevices(ctx, "alice", true /* withSecrets */) require.NoError(t, err) - if !test.expectSSODevice { + if test.expectSSODevice == nil { assert.Empty(t, devs) return } - expectSSODevice, err := types.NewMFADevice(test.connectorRef.ID, test.connectorRef.ID, clock.Now().UTC(), &types.MFADevice_Sso{ - Sso: &types.SSOMFADevice{ - ConnectorId: test.connectorRef.ID, - ConnectorType: test.connectorRef.Type, - }, + + expectSSODevice, err := types.NewMFADevice(test.expectSSODevice.ConnectorId, test.expectSSODevice.DisplayName, clock.Now().UTC(), &types.MFADevice_Sso{ + Sso: test.expectSSODevice, }) require.NoError(t, err) assert.Equal(t, []*types.MFADevice{expectSSODevice}, devs)