Skip to content

Commit

Permalink
Screen DEP serial from assign profile operations
Browse files Browse the repository at this point in the history
  • Loading branch information
gillespi314 committed Nov 28, 2023
1 parent 8dbe690 commit 629740c
Show file tree
Hide file tree
Showing 4 changed files with 262 additions and 9 deletions.
42 changes: 41 additions & 1 deletion server/mdm/apple/apple_mdm.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ import (
"encoding/json"
"encoding/xml"
"fmt"
"os"
"strings"
"text/template"
"time"
Expand Down Expand Up @@ -503,7 +504,10 @@ func (d *DEPService) processDeviceResponse(ctx context.Context, depClient *godep
for profUUID, serials := range profileToSerials {
logger := kitlog.With(d.logger, "profile_uuid", profUUID)
level.Info(logger).Log("msg", "calling DEP client to assign profile", "profile_uuid", profUUID)
apiResp, err := depClient.AssignProfile(ctx, DEPName, profUUID, serials...)

screenedSerials := ApplyDEPScreenToSerials(ctx, logger, profUUID, serials...)

apiResp, err := depClient.AssignProfile(ctx, DEPName, profUUID, screenedSerials...)
if err != nil {
level.Info(logger).Log(
"msg", "assign profile",
Expand All @@ -524,6 +528,42 @@ func (d *DEPService) processDeviceResponse(ctx context.Context, depClient *godep
return nil
}

func ApplyDEPScreenToSerials(ctx context.Context, logger kitlog.Logger, profUUID string, serials ...string) []string {
if len(serials) == 0 {
return serials
}

envExpiry := os.Getenv("FLEET_DEP_SCREEN_EXPIRY")
if envExpiry == "" {
return serials
}
t, err := time.Parse(time.RFC3339, envExpiry)
if err != nil {
level.Error(logger).Log("msg", "parsing dep screen expiry", "err", err)
return serials
}
if time.Now().After(t) {
level.Debug(logger).Log("msg", "dep screen expired", "expiry", t)
return serials
}

envSerial := os.Getenv("FLEET_DEP_SCREEN_SERIAL")
if envSerial == "" {
return serials
}

var filteredSerials []string
for _, serial := range serials {
if serial == envSerial {
level.Info(logger).Log("msg", "applying dep screen", "serial", serial, "profile_uuid", profUUID)
continue
}
filteredSerials = append(filteredSerials, serial)
}

return filteredSerials
}

func (d *DEPService) getProfileUUIDForTeam(ctx context.Context, tmID *uint) (string, error) {
var appleBMTeam *fleet.Team
if tmID != nil {
Expand Down
70 changes: 70 additions & 0 deletions server/mdm/apple/apple_mdm_external_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,8 @@ import (
"io"
"net/http"
"net/http/httptest"
"os"
"sync/atomic"
"testing"
"time"

Expand Down Expand Up @@ -290,4 +292,72 @@ func TestDEPService_RunAssigner(t *testing.T) {
}
require.ElementsMatch(t, []string{"a", "c"}, serials)
})

t.Run("screened device", func(t *testing.T) {
devices := []godep.Device{
{SerialNumber: "screened-serial", OpType: "added"},
{SerialNumber: "unscreened-serial", OpType: "added"},
}

var assignCalled bool
var shouldScreen atomic.Bool
svc := setupTest(t, func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusOK)
encoder := json.NewEncoder(w)
switch r.URL.Path {
case "/session":
_, _ = w.Write([]byte(`{"auth_session_token": "session123"}`))
case "/account":
_, _ = w.Write([]byte(`{"admin_id": "admin123", "org_name": "test_org"}`))
case "/profile":
err := encoder.Encode(godep.ProfileResponse{ProfileUUID: "profile123"})
require.NoError(t, err)
case "/server/devices":
err := encoder.Encode(godep.DeviceResponse{Devices: devices})
require.NoError(t, err)
case "/devices/sync":
err := encoder.Encode(godep.DeviceResponse{Devices: devices})
require.NoError(t, err)
case "/profile/devices":
assignCalled = true

reqBody, err := io.ReadAll(r.Body)
require.NoError(t, err)

var assignReq godep.Profile
err = json.Unmarshal(reqBody, &assignReq)
require.NoError(t, err)
require.Equal(t, assignReq.ProfileUUID, "profile123")
if shouldScreen.Load() {
require.ElementsMatch(t, []string{"unscreened-serial"}, assignReq.Devices)
} else {
require.ElementsMatch(t, []string{"screened-serial", "unscreened-serial"}, assignReq.Devices)
}

_, _ = w.Write([]byte(`{}`))
default:
t.Errorf("unexpected request to %s", r.URL.Path)
}
})

// no screening
os.Unsetenv("FLEET_DEP_SCREEN_SERIAL")
os.Unsetenv("FLEET_DEP_SCREEN_EXPIRY")
shouldScreen.Store(false)
err := svc.RunAssigner(ctx)
require.NoError(t, err)
require.True(t, assignCalled)

// screening
os.Setenv("FLEET_DEP_SCREEN_SERIAL", "screened-serial")
os.Setenv("FLEET_DEP_SCREEN_EXPIRY", time.Now().Add(time.Hour).Format(time.RFC3339))
defer func() {
os.Unsetenv("FLEET_DEP_SCREEN_SERIAL")
os.Unsetenv("FLEET_DEP_SCREEN_EXPIRY")
}()
shouldScreen.Store(true)
err = svc.RunAssigner(ctx)
require.NoError(t, err)
require.True(t, assignCalled)
})
}
18 changes: 11 additions & 7 deletions server/worker/macos_setup_assistant.go
Original file line number Diff line number Diff line change
Expand Up @@ -112,8 +112,9 @@ func (m *MacosSetupAssistant) runProfileChanged(ctx context.Context, args macosS
if err != nil {
return ctxerr.Wrap(ctx, err, "list mdm dep serials in team")
}
if len(serials) > 0 {
if _, err := m.DEPClient.AssignProfile(ctx, apple_mdm.DEPName, profUUID, serials...); err != nil {
screenedSerials := apple_mdm.ApplyDEPScreenToSerials(ctx, m.Log, profUUID, serials...)
if len(screenedSerials) > 0 {
if _, err := m.DEPClient.AssignProfile(ctx, apple_mdm.DEPName, profUUID, screenedSerials...); err != nil {
return ctxerr.Wrap(ctx, err, "assign profile")
}
}
Expand Down Expand Up @@ -162,8 +163,9 @@ func (m *MacosSetupAssistant) runProfileDeleted(ctx context.Context, args macosS
if err != nil {
return ctxerr.Wrap(ctx, err, "list mdm dep serials in team")
}
if len(serials) > 0 {
if _, err := m.DEPClient.AssignProfile(ctx, apple_mdm.DEPName, profUUID, serials...); err != nil {
screenedSerials := apple_mdm.ApplyDEPScreenToSerials(ctx, m.Log, profUUID, serials...)
if len(screenedSerials) > 0 {
if _, err := m.DEPClient.AssignProfile(ctx, apple_mdm.DEPName, profUUID, screenedSerials...); err != nil {
return ctxerr.Wrap(ctx, err, "assign profile")
}
}
Expand Down Expand Up @@ -205,9 +207,11 @@ func (m *MacosSetupAssistant) runHostsTransferred(ctx context.Context, args maco
}
}

_, err = m.DEPClient.AssignProfile(ctx, apple_mdm.DEPName, profUUID, args.HostSerialNumbers...)
if err != nil {
return ctxerr.Wrap(ctx, err, "assign profile")
screenedSerials := apple_mdm.ApplyDEPScreenToSerials(ctx, m.Log, profUUID, args.HostSerialNumbers...)
if len(screenedSerials) > 0 {
if _, err := m.DEPClient.AssignProfile(ctx, apple_mdm.DEPName, profUUID, screenedSerials...); err != nil {
return ctxerr.Wrap(ctx, err, "assign profile")
}
}
return nil
}
Expand Down
141 changes: 140 additions & 1 deletion server/worker/macos_setup_assistant_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,9 @@ import (
"io"
"net/http"
"net/http/httptest"
"os"
"strings"
"sync/atomic"
"testing"
"time"

Expand All @@ -22,6 +24,143 @@ import (
"github.com/stretchr/testify/require"
)

func TestDEPScreenSerial(t *testing.T) {
ctx := context.Background()
ds := mysql.CreateMySQLDS(t)

hosts := make([]*fleet.Host, 6)
for i := 0; i < len(hosts); i++ {
h, err := ds.NewHost(ctx, &fleet.Host{
Hostname: fmt.Sprintf("test-host%d-name", i),
OsqueryHostID: ptr.String(fmt.Sprintf("osquery-%d", i)),
NodeKey: ptr.String(fmt.Sprintf("nodekey-%d", i)),
UUID: fmt.Sprintf("test-uuid-%d", i),
Platform: "darwin",
HardwareSerial: fmt.Sprintf("serial-%d", i),
})
require.NoError(t, err)
err = ds.UpsertMDMAppleHostDEPAssignments(ctx, []fleet.Host{*h})
require.NoError(t, err)
hosts[i] = h
}

testBMToken := nanodep_client.OAuth1Tokens{
ConsumerKey: "test_consumer",
ConsumerSecret: "test_secret",
AccessToken: "test_access_token",
AccessSecret: "test_access_secret",
AccessTokenExpiry: time.Date(2999, 1, 1, 0, 0, 0, 0, time.UTC),
}

logger := kitlog.NewNopLogger()
depStorage, err := ds.NewMDMAppleDEPStorage(testBMToken)
require.NoError(t, err)
depService := apple_mdm.NewDEPService(ds, depStorage, logger)
macosJob := &MacosSetupAssistant{
Datastore: ds,
Log: logger,
DEPService: depService,
DEPClient: apple_mdm.NewDEPClient(depStorage, ds, logger),
}

var shouldScreen atomic.Bool
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusOK)
encoder := json.NewEncoder(w)
switch r.URL.Path {
case "/session":
err := encoder.Encode(map[string]string{"auth_session_token": "auth123"})
require.NoError(t, err)

case "/profile":
var reqProf godep.Profile
b, err := io.ReadAll(r.Body)
require.NoError(t, err)
err = json.Unmarshal(b, &reqProf)
require.NoError(t, err)

// use the profile name as profile uuid, and append "+sso" if it was
// registered with the sso url (end-user auth enabled).
profUUID := reqProf.ProfileName
if strings.HasSuffix(reqProf.ConfigurationWebURL, "/mdm/sso") {
profUUID += "+sso"
}
err = encoder.Encode(godep.ProfileResponse{ProfileUUID: profUUID})
require.NoError(t, err)

case "/profile/devices":
var reqProf godep.Profile
b, err := io.ReadAll(r.Body)
require.NoError(t, err)
err = json.Unmarshal(b, &reqProf)
require.NoError(t, err)

for _, d := range reqProf.Devices {
if shouldScreen.Load() {
require.NotEqual(t, "serial-0", d)
}
}
_, _ = w.Write([]byte(`{}`))

default:
t.Errorf("unexpected request: %s", r.URL.Path)
}
}))
defer srv.Close()
err = depStorage.StoreConfig(ctx, apple_mdm.DEPName, &nanodep_client.Config{BaseURL: srv.URL})
require.NoError(t, err)

w := NewWorker(ds, logger)
w.Register(macosJob)

checkJob := func() {
// enqueue a regenerate all and process the jobs
err = QueueMacosSetupAssistantJob(ctx, ds, logger, MacosSetupAssistantUpdateAllProfiles, nil)
require.NoError(t, err)
err = w.ProcessJobs(ctx)
require.NoError(t, err)
// no remaining jobs to process
pending, err := ds.GetQueuedJobs(ctx, 10)
for _, p := range pending {
t.Logf("pending job: %s", p.Name)
var args macosSetupAssistantArgs
err := json.Unmarshal(*p.Args, &args)
require.NoError(t, err)
t.Logf(" args: %v", args)
}

require.NoError(t, err)
require.Empty(t, pending)
}

for _, tc := range []struct {
name string
serial string
expiry string
shouldScreen bool
}{
{"no env vars", "", "", false},
{"valid serial and future expiry ", "serial-0", time.Now().Add(time.Hour).Format(time.RFC3339), true},
{"valid serial and past expiry", "serial-0", time.Now().Add(-time.Hour).Format(time.RFC3339), false},
{"valid serial and empty expiry", "serial-0", "", false},
{"valid serial and invalid expiry", "serial-0", "invalid", false},
{"empty serial and future expiry", "", time.Now().Add(time.Hour).Format(time.RFC3339), false},
{"unknown serial and future expiry", "unknown", time.Now().Add(time.Hour).Format(time.RFC3339), false},
} {
t.Run(tc.name, func(t *testing.T) {
os.Setenv("FLEET_DEP_SCREEN_SERIAL", tc.serial)
os.Setenv("FLEET_DEP_SCREEN_EXPIRY", tc.expiry)
defer func() {
os.Unsetenv("FLEET_DEP_SCREEN_SERIAL")
os.Unsetenv("FLEET_DEP_SCREEN_EXPIRY")
}()

shouldScreen.Store(tc.shouldScreen)
checkJob()
})
}
}

func TestMacosSetupAssistant(t *testing.T) {
ctx := context.Background()
ds := mysql.CreateMySQLDS(t)
Expand Down Expand Up @@ -58,7 +197,7 @@ func TestMacosSetupAssistant(t *testing.T) {
err = ds.AddHostsToTeam(ctx, &tm2.ID, []uint{hosts[4].ID, hosts[5].ID})
require.NoError(t, err)

var testBMToken = nanodep_client.OAuth1Tokens{
testBMToken := nanodep_client.OAuth1Tokens{
ConsumerKey: "test_consumer",
ConsumerSecret: "test_secret",
AccessToken: "test_access_token",
Expand Down

0 comments on commit 629740c

Please sign in to comment.