Skip to content

Commit

Permalink
Improvements for select next Apple MDM command query. (#24128)
Browse files Browse the repository at this point in the history
#23832 

[Loadtest
report](https://docs.google.com/document/d/1HafECokrZ3jnzRskxMtJwp4k1E2uBTbO9vfKEUtyykI/edit?tab=t.0)

# Checklist for submitter
- [x] Changes file added for user-visible changes in `changes/`,
`orbit/changes/` or `ee/fleetd-chrome/changes`.
See [Changes
files](https://github.com/fleetdm/fleet/blob/main/docs/Contributing/Committing-Changes.md#changes-files)
for more information.
- [x] Input data is properly validated, `SELECT *` is avoided, SQL
injection is prevented (using placeholders for values in statements)
- [x] Added/updated tests
- [x] Manual QA for all new/changed functionality
  • Loading branch information
getvictor authored Dec 5, 2024
1 parent 614446e commit afebfde
Show file tree
Hide file tree
Showing 11 changed files with 203 additions and 29 deletions.
1 change: 1 addition & 0 deletions changes/23832-select-nano_enrollment_queue
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Improved the performance of the MDM `SELECT FROM nano_enrollment_queue` MySQL query, including running it on the DB reader much of the time.
17 changes: 17 additions & 0 deletions pkg/mdm/mdmtest/apple.go
Original file line number Diff line number Diff line change
Expand Up @@ -685,6 +685,23 @@ func (c *TestAppleMDMClient) Acknowledge(cmdUUID string) (*mdm.Command, error) {
return c.sendAndDecodeCommandResponse(payload)
}

// NotNow reports a NotNow status to the MDM server for the command
// identified by cmdUUID.
//
// The server replies either with the next command to run, or with an empty
// response body (returned as nil, nil) when there are no commands to run.
func (c *TestAppleMDMClient) NotNow(cmdUUID string) (*mdm.Command, error) {
	deviceUUID := c.UUID
	return c.sendAndDecodeCommandResponse(map[string]any{
		"CommandUUID":  cmdUUID,
		"EnrollmentID": "testenrollmentid-" + deviceUUID,
		"UDID":         deviceUUID,
		"Topic":        "com.apple.mgmt.External." + deviceUUID,
		"Status":       "NotNow",
	})
}

func (c *TestAppleMDMClient) AcknowledgeDeviceInformation(udid, cmdUUID, deviceName, productName string) (*mdm.Command, error) {
payload := map[string]any{
"Status": "Acknowledged",
Expand Down
10 changes: 9 additions & 1 deletion server/datastore/mysql/apple_mdm.go
Original file line number Diff line number Diff line change
Expand Up @@ -2748,7 +2748,15 @@ func (ds *Datastore) BulkUpsertMDMAppleHostProfiles(ctx context.Context, payload
strings.TrimSuffix(valuePart, ","),
)

_, err := ds.writer(ctx).ExecContext(ctx, stmt, args...)
// We need to run with retry due to deadlocks.
// The INSERT/ON DUPLICATE KEY UPDATE pattern is prone to deadlocks when multiple
// threads are modifying nearby rows. That's because this statement uses gap locks.
// When two transactions acquire the same gap lock, they may deadlock.
// Two simultaneous transactions may happen when cron job runs and the user is updating via the UI at the same time.
err := ds.withRetryTxx(ctx, func(tx sqlx.ExtContext) error {
_, err := tx.ExecContext(ctx, stmt, args...)
return err
})
return err
}

Expand Down
3 changes: 2 additions & 1 deletion server/datastore/mysql/locks_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ import (
"time"

"github.com/fleetdm/fleet/v4/server"
"github.com/fleetdm/fleet/v4/server/fleet"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
Expand Down Expand Up @@ -89,7 +90,7 @@ const (
)

//nolint:unused // used in skipped tests
func getMySQLServer(t *testing.T, r dbReader) mysqlServer {
func getMySQLServer(t *testing.T, r fleet.DBReader) mysqlServer {
row := r.QueryRowxContext(context.Background(), "SELECT VERSION()")
var version string
require.NoError(t, row.Scan(&version))
Expand Down
15 changes: 3 additions & 12 deletions server/datastore/mysql/mysql.go
Original file line number Diff line number Diff line change
Expand Up @@ -45,19 +45,10 @@ const (
// Matches every character that is not a word character, '-' or '.', for replacement
var columnCharsRegexp = regexp.MustCompile(`[^\w-.]`)

// dbReader is the set of methods needed from a database handle that is only
// used for reads: context-aware querying and statement preparation from sqlx,
// plus Close and Rebind.
type dbReader interface {
	Rebind(string) string
	Close() error

	sqlx.PreparerContext
	sqlx.QueryerContext
}

// Datastore is an implementation of fleet.Datastore interface backed by
// MySQL
type Datastore struct {
replica dbReader // so it cannot be used to perform writes
replica fleet.DBReader // so it cannot be used to perform writes
primary *sqlx.DB

logger log.Logger
Expand Down Expand Up @@ -115,7 +106,7 @@ type Datastore struct {
// reader returns the DB instance to use for read-only statements, which is the
// replica unless the primary has been explicitly required via
// ctxdb.RequirePrimary.
func (ds *Datastore) reader(ctx context.Context) dbReader {
func (ds *Datastore) reader(ctx context.Context) fleet.DBReader {
if ctxdb.IsPrimaryRequired(ctx) {
return ds.primary
}
Expand Down Expand Up @@ -518,7 +509,7 @@ func (ds *Datastore) MigrateData(ctx context.Context) error {
func (ds *Datastore) loadMigrations(
ctx context.Context,
writer *sql.DB,
reader dbReader,
reader fleet.DBReader,
) (tableRecs []int64, dataRecs []int64, err error) {
// We need to run the following to trigger the creation of the migration status tables.
_, err = tables.MigrationClient.GetDBVersion(writer)
Expand Down
2 changes: 2 additions & 0 deletions server/datastore/mysql/nanomdm_storage.go
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,7 @@ func (ds *Datastore) NewMDMAppleMDMStorage() (*NanoMDMStorage, error) {
s, err := nanomdm_mysql.New(
nanomdm_mysql.WithDB(ds.primary.DB),
nanomdm_mysql.WithLogger(nanoMDMLogAdapter{logger: ds.logger}),
nanomdm_mysql.WithReaderFunc(ds.reader),
)
if err != nil {
return nil, err
Expand All @@ -73,6 +74,7 @@ func (ds *Datastore) NewTestMDMAppleMDMStorage(asyncCap int, asyncInterval time.
s, err := nanomdm_mysql.New(
nanomdm_mysql.WithDB(ds.primary.DB),
nanomdm_mysql.WithLogger(nanoMDMLogAdapter{logger: ds.logger}),
nanomdm_mysql.WithReaderFunc(ds.reader),
nanomdm_mysql.WithAsyncLastSeen(asyncCap, asyncInterval),
)
if err != nil {
Expand Down
11 changes: 11 additions & 0 deletions server/fleet/db.go
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
package fleet

import "github.com/jmoiron/sqlx"

// DBLock represents a database transaction lock information as returned
// by datastore.DBLocks.
type DBLock struct {
Expand All @@ -10,3 +12,12 @@ type DBLock struct {
BlockingThread uint64 `db:"blocking_thread" json:"blocking_thread"`
BlockingQuery *string `db:"blocking_query" json:"blocking_query,omitempty"`
}

// DBReader is the set of methods needed from a database handle that is only
// used for reads: context-aware querying and statement preparation from sqlx,
// plus Close and Rebind.
type DBReader interface {
	Rebind(string) string
	Close() error

	sqlx.PreparerContext
	sqlx.QueryerContext
}
5 changes: 5 additions & 0 deletions server/mdm/nanomdm/service/nanomdm/service.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ import (
"errors"
"fmt"

"github.com/fleetdm/fleet/v4/server/contexts/ctxdb"
"github.com/fleetdm/fleet/v4/server/mdm/nanomdm/mdm"
"github.com/fleetdm/fleet/v4/server/mdm/nanomdm/service"
"github.com/fleetdm/fleet/v4/server/mdm/nanomdm/storage"
Expand Down Expand Up @@ -244,6 +245,10 @@ func (s *Service) CommandAndReportResults(r *mdm.Request, results *mdm.CommandRe
"error_chain", results.ErrorChain,
)
}
if results.Status != "Idle" {
// If the host is not idle, we use primary DB since we just wrote results of previous command.
ctxdb.RequirePrimary(r.Context, true)
}
cmd, err := s.store.RetrieveNextCommand(r, results.Status == "NotNow")
if err != nil {
return nil, fmt.Errorf("retrieving next command: %w", err)
Expand Down
16 changes: 16 additions & 0 deletions server/mdm/nanomdm/storage/mysql/mysql.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ import (
"os"
"time"

"github.com/fleetdm/fleet/v4/server/fleet"
"github.com/fleetdm/fleet/v4/server/mdm/nanomdm/cryptoutil"
"github.com/fleetdm/fleet/v4/server/mdm/nanomdm/mdm"
"github.com/jmoiron/sqlx"
Expand All @@ -29,6 +30,7 @@ type MySQLStorage struct {
db *sql.DB
rm bool
asyncLastSeen *asyncLastSeen
reader func(ctx context.Context) fleet.DBReader
}

type config struct {
Expand All @@ -39,10 +41,17 @@ type config struct {
rm bool
asyncCap int
asyncInterval time.Duration
reader func(ctx context.Context) fleet.DBReader
}

// Option is a function that sets fields on the config that New uses to build
// a MySQLStorage.
type Option func(*config)

// WithReaderFunc returns an Option that sets the function used to obtain the
// database handle for reads, given a context. When no reader function is
// configured, New installs a default that wraps the storage's own db.
func WithReaderFunc(readerFunc func(ctx context.Context) fleet.DBReader) Option {
	apply := func(cfg *config) {
		cfg.reader = readerFunc
	}
	return apply
}

func WithLogger(logger log.Logger) Option {
return func(c *config) {
c.logger = logger
Expand Down Expand Up @@ -102,6 +111,13 @@ func New(opts ...Option) (*MySQLStorage, error) {
}

mysqlStore := &MySQLStorage{db: cfg.db, logger: cfg.logger, rm: cfg.rm}
if cfg.reader == nil {
mysqlStore.reader = func(ctx context.Context) fleet.DBReader {
return sqlx.NewDb(mysqlStore.db, "mysql")
}
} else {
mysqlStore.reader = cfg.reader
}

if v := os.Getenv("FLEET_DISABLE_ASYNC_NANO_LAST_SEEN"); v != "1" {
asyncLastSeen := newAsyncLastSeen(cfg.asyncInterval, cfg.asyncCap, mysqlStore.updateLastSeenBatch)
Expand Down
63 changes: 48 additions & 15 deletions server/mdm/nanomdm/storage/mysql/queue.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,9 @@ import (
"fmt"
"strings"

"github.com/fleetdm/fleet/v4/server/contexts/ctxerr"
"github.com/fleetdm/fleet/v4/server/mdm/nanomdm/mdm"
"github.com/google/uuid"
)

func enqueue(ctx context.Context, tx *sql.Tx, ids []string, cmd *mdm.Command) error {
Expand All @@ -22,15 +24,30 @@ func enqueue(ctx context.Context, tx *sql.Tx, ids []string, cmd *mdm.Command) er
if err != nil {
return err
}
query := `INSERT INTO nano_enrollment_queue (id, command_uuid) VALUES (?, ?)`
query += strings.Repeat(", (?, ?)", len(ids)-1)
args := make([]interface{}, len(ids)*2)
for i, id := range ids {
args[i*2] = id
args[i*2+1] = cmd.CommandUUID
const mySQLPlaceholderLimit = 65536 - 1
const placeholdersPerInsert = 2
const batchSize = mySQLPlaceholderLimit / placeholdersPerInsert
for i := 0; i < len(ids); i += batchSize {
end := i + batchSize
if end > len(ids) {
end = len(ids)
}
idsBatch := ids[i:end]

// Process batch
query := `INSERT INTO nano_enrollment_queue (id, command_uuid) VALUES (?, ?)`
query += strings.Repeat(", (?, ?)", len(idsBatch)-1)
args := make([]interface{}, len(idsBatch)*placeholdersPerInsert)
for i, id := range idsBatch {
args[i*2] = id
args[i*2+1] = cmd.CommandUUID
}
_, err = tx.ExecContext(ctx, query+";", args...)
if err != nil {
return err
}
}
_, err = tx.ExecContext(ctx, query+";", args...)
return err
return nil
}

func (m *MySQLStorage) EnqueueCommand(ctx context.Context, ids []string, cmd *mdm.Command) (map[string]error, error) {
Expand Down Expand Up @@ -168,22 +185,38 @@ UPDATE

func (m *MySQLStorage) RetrieveNextCommand(r *mdm.Request, skipNotNow bool) (*mdm.Command, error) {
command := new(mdm.Command)
err := m.db.QueryRowContext(
r.Context, `
id := "?"
var args []interface{}
// Validate the ID to avoid SQL injection.
// This performance optimization eliminates the prepare statement for this frequent query.
// Eventually, we should use binary storage for id (UUID).
if err := uuid.Validate(r.ID); err == nil {
id = "'" + r.ID + "'"
} else {
err = ctxerr.Wrap(r.Context, err, "device ID is not a valid UUID: %s", r.ID)
m.logger.Info("msg", "device ID is not a UUID", "device_id", r.ID, "err", err)
// Handle the error by sending it to Redis to be included in aggregated statistics.
// Before switching UUID to use binary storage, we should ensure that this error rate is low/none.
ctxerr.Handle(r.Context, err)
args = append(args, r.ID)
}
err := m.reader(r.Context).QueryRowxContext(
r.Context, fmt.Sprintf(
// The query should use the ANTIJOIN (NOT EXISTS) optimization on the nano_command_results table.
`
SELECT c.command_uuid, c.request_type, c.command
FROM nano_enrollment_queue AS q
INNER JOIN nano_commands AS c
ON q.command_uuid = c.command_uuid
LEFT JOIN nano_command_results r
ON r.command_uuid = q.command_uuid AND r.id = q.id
WHERE q.id = ?
ON r.command_uuid = q.command_uuid AND r.id = q.id AND (r.status != 'NotNow' OR %t)
WHERE q.id = %s
AND q.active = 1
AND (r.status IS NULL OR (r.status = 'NotNow' AND NOT ?))
AND r.status IS NULL
ORDER BY
q.priority DESC,
q.created_at
LIMIT 1;`,
r.ID, skipNotNow,
LIMIT 1;`, skipNotNow, id), args...,
).Scan(&command.CommandUUID, &command.Command.RequestType, &command.Raw)
if errors.Is(err, sql.ErrNoRows) {
return nil, nil
Expand Down
89 changes: 89 additions & 0 deletions server/service/integration_mdm_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -10035,6 +10035,95 @@ func (s *integrationMDMTestSuite) TestAPNsPushCron() {
require.Len(t, recordedPushes, 0)
}

// TestAPNsPushWithNotNow checks how many APNs pushes are recorded at each
// step when a macOS device answers a pending command with NotNow and later
// acknowledges it.
func (s *integrationMDMTestSuite) TestAPNsPushWithNotNow() {
	t := s.T()
	ctx := context.Background()

	// Hosts: macOS and Windows with MDM on; Linux and macOS with MDM off.
	_, macDevice := createHostThenEnrollMDM(s.ds, s.server.URL, t)
	createWindowsHostThenEnrollMDM(s.ds, s.server.URL, t)
	createOrbitEnrolledHost(t, "linux", "linux_host", s.ds)
	createOrbitEnrolledHost(t, "darwin", "mac_not_enrolled", s.ds)

	// Swap in a recording push mock and put the previous one back on exit.
	prevPushFunc := s.pushProvider.PushFunc
	defer func() { s.pushProvider.PushFunc = prevPushFunc }()

	var (
		pushMu         sync.Mutex
		recordedPushes []*mdm.Push
	)
	s.pushProvider.PushFunc = func(ctx context.Context, pushes []*mdm.Push) (map[string]*push.Response, error) {
		pushMu.Lock()
		defer pushMu.Unlock()
		recordedPushes = pushes
		return mockSuccessfulPush(ctx, pushes)
	}

	// First reconciliation run: exactly one push is expected.
	require.NoError(t, ReconcileAppleProfiles(ctx, s.ds, s.mdmCommander, s.logger))
	require.Len(t, recordedPushes, 1)
	recordedPushes = nil

	// Drain whatever is already queued for the device by acknowledging each
	// command until the server has nothing more to send.
	cmd, err := macDevice.Idle()
	require.NoError(t, err)
	for cmd != nil {
		t.Logf("Received: %s %s", cmd.CommandUUID, cmd.Command.RequestType)
		cmd, err = macDevice.Acknowledge(cmd.CommandUUID)
		require.NoError(t, err)
	}

	// Upload a fresh batch of profiles.
	newProfiles := batchSetMDMProfilesRequest{Profiles: []fleet.MDMProfileBatchPayload{
		{Name: "N1", Contents: mobileconfigForTest("N1", "I1")},
		{Name: "N2", Contents: syncMLForTest("./Foo/Bar")},
	}}
	s.Do("POST", "/api/v1/fleet/mdm/profiles/batch", newProfiles, http.StatusNoContent)

	// Reconciling again produces one push.
	require.NoError(t, ReconcileAppleProfiles(ctx, s.ds, s.mdmCommander, s.logger))
	require.Len(t, recordedPushes, 1)
	recordedPushes = nil

	// The pending-devices cron sends a new push request every time it runs.
	require.NoError(t, SendPushesToPendingDevices(ctx, s.ds, s.mdmCommander, s.logger))
	require.Len(t, recordedPushes, 1)
	recordedPushes = nil

	// The device fetches the command and answers it with NotNow; the server
	// then has no further command for it.
	cmd, err = macDevice.Idle()
	require.NoError(t, err)
	require.NotNil(t, cmd)
	cmd, err = macDevice.NotNow(cmd.CommandUUID)
	require.NoError(t, err)
	assert.Nil(t, cmd)

	// A NotNow answer must not cause another push: the device is expected to
	// check in again on its own when conditions change.
	require.NoError(t, ReconcileAppleProfiles(ctx, s.ds, s.mdmCommander, s.logger))
	require.Len(t, recordedPushes, 0)
	recordedPushes = nil

	// On the next check-in the device receives a command and acknowledges it.
	cmd, err = macDevice.Idle()
	require.NoError(t, err)
	require.NotNil(t, cmd)
	cmd, err = macDevice.Acknowledge(cmd.CommandUUID)
	require.NoError(t, err)
	assert.Nil(t, cmd)

	// With nothing pending, the cron records no pushes.
	require.NoError(t, SendPushesToPendingDevices(ctx, s.ds, s.mdmCommander, s.logger))
	assert.Zero(t, recordedPushes)
}

func (s *integrationMDMTestSuite) TestMDMRequestWithoutCerts() {
t := s.T()
res := s.DoRawNoAuth("PUT", "/mdm/apple/mdm", nil, http.StatusBadRequest)
Expand Down

0 comments on commit afebfde

Please sign in to comment.