diff --git a/changes/23832-select-nano_enrollment_queue b/changes/23832-select-nano_enrollment_queue new file mode 100644 index 000000000000..5ae116326e56 --- /dev/null +++ b/changes/23832-select-nano_enrollment_queue @@ -0,0 +1 @@ +Improved MDM `SELECT FROM nano_enrollment_queue` MySQL query performance, including calling it on DB reader much of the time. diff --git a/pkg/mdm/mdmtest/apple.go b/pkg/mdm/mdmtest/apple.go index 46314d0354a1..e6cbbb0e1a06 100644 --- a/pkg/mdm/mdmtest/apple.go +++ b/pkg/mdm/mdmtest/apple.go @@ -685,6 +685,23 @@ func (c *TestAppleMDMClient) Acknowledge(cmdUUID string) (*mdm.Command, error) { return c.sendAndDecodeCommandResponse(payload) } +// NotNow sends a NotNow message to the MDM server. +// The cmdUUID is the UUID of the command to reference. +// +// The server can signal back with either a command to run +// or an empty (nil, nil) response body to end the communication +// (i.e. no commands to run). +func (c *TestAppleMDMClient) NotNow(cmdUUID string) (*mdm.Command, error) { + payload := map[string]any{ + "Status": "NotNow", + "Topic": "com.apple.mgmt.External." + c.UUID, + "UDID": c.UUID, + "EnrollmentID": "testenrollmentid-" + c.UUID, + "CommandUUID": cmdUUID, + } + return c.sendAndDecodeCommandResponse(payload) +} + func (c *TestAppleMDMClient) AcknowledgeDeviceInformation(udid, cmdUUID, deviceName, productName string) (*mdm.Command, error) { payload := map[string]any{ "Status": "Acknowledged", diff --git a/server/datastore/mysql/apple_mdm.go b/server/datastore/mysql/apple_mdm.go index 9c8e522988a5..1ae6e3fef8df 100644 --- a/server/datastore/mysql/apple_mdm.go +++ b/server/datastore/mysql/apple_mdm.go @@ -2748,7 +2748,15 @@ func (ds *Datastore) BulkUpsertMDMAppleHostProfiles(ctx context.Context, payload strings.TrimSuffix(valuePart, ","), ) - _, err := ds.writer(ctx).ExecContext(ctx, stmt, args...) + // We need to run with retry due to deadlocks. 
+ // The INSERT/ON DUPLICATE KEY UPDATE pattern is prone to deadlocks when multiple + // threads are modifying nearby rows. That's because this statement uses gap locks. + // When two transactions acquire the same gap lock, they may deadlock. + // Two simultaneous transactions may happen when cron job runs and the user is updating via the UI at the same time. + err := ds.withRetryTxx(ctx, func(tx sqlx.ExtContext) error { + _, err := tx.ExecContext(ctx, stmt, args...) + return err + }) return err } diff --git a/server/datastore/mysql/locks_test.go b/server/datastore/mysql/locks_test.go index e0476ea63a5d..d45ed46cee47 100644 --- a/server/datastore/mysql/locks_test.go +++ b/server/datastore/mysql/locks_test.go @@ -8,6 +8,7 @@ import ( "time" "github.com/fleetdm/fleet/v4/server" + "github.com/fleetdm/fleet/v4/server/fleet" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) @@ -89,7 +90,7 @@ const ( ) //nolint:unused // used in skipped tests -func getMySQLServer(t *testing.T, r dbReader) mysqlServer { +func getMySQLServer(t *testing.T, r fleet.DBReader) mysqlServer { row := r.QueryRowxContext(context.Background(), "SELECT VERSION()") var version string require.NoError(t, row.Scan(&version)) diff --git a/server/datastore/mysql/mysql.go b/server/datastore/mysql/mysql.go index 31486e87d11a..356dabb9c643 100644 --- a/server/datastore/mysql/mysql.go +++ b/server/datastore/mysql/mysql.go @@ -45,19 +45,10 @@ const ( // Matches all non-word and '-' characters for replacement var columnCharsRegexp = regexp.MustCompile(`[^\w-.]`) -// dbReader is an interface that defines the methods required for reads. 
-type dbReader interface { - sqlx.QueryerContext - sqlx.PreparerContext - - Close() error - Rebind(string) string -} - // Datastore is an implementation of fleet.Datastore interface backed by // MySQL type Datastore struct { - replica dbReader // so it cannot be used to perform writes + replica fleet.DBReader // so it cannot be used to perform writes primary *sqlx.DB logger log.Logger @@ -115,7 +106,7 @@ type Datastore struct { // reader returns the DB instance to use for read-only statements, which is the // replica unless the primary has been explicitly required via // ctxdb.RequirePrimary. -func (ds *Datastore) reader(ctx context.Context) dbReader { +func (ds *Datastore) reader(ctx context.Context) fleet.DBReader { if ctxdb.IsPrimaryRequired(ctx) { return ds.primary } @@ -518,7 +509,7 @@ func (ds *Datastore) MigrateData(ctx context.Context) error { func (ds *Datastore) loadMigrations( ctx context.Context, writer *sql.DB, - reader dbReader, + reader fleet.DBReader, ) (tableRecs []int64, dataRecs []int64, err error) { // We need to run the following to trigger the creation of the migration status tables. _, err = tables.MigrationClient.GetDBVersion(writer) diff --git a/server/datastore/mysql/nanomdm_storage.go b/server/datastore/mysql/nanomdm_storage.go index 2f2ea2848238..585bc68949dd 100644 --- a/server/datastore/mysql/nanomdm_storage.go +++ b/server/datastore/mysql/nanomdm_storage.go @@ -54,6 +54,7 @@ func (ds *Datastore) NewMDMAppleMDMStorage() (*NanoMDMStorage, error) { s, err := nanomdm_mysql.New( nanomdm_mysql.WithDB(ds.primary.DB), nanomdm_mysql.WithLogger(nanoMDMLogAdapter{logger: ds.logger}), + nanomdm_mysql.WithReaderFunc(ds.reader), ) if err != nil { return nil, err @@ -73,6 +74,7 @@ func (ds *Datastore) NewTestMDMAppleMDMStorage(asyncCap int, asyncInterval time. 
s, err := nanomdm_mysql.New( nanomdm_mysql.WithDB(ds.primary.DB), nanomdm_mysql.WithLogger(nanoMDMLogAdapter{logger: ds.logger}), + nanomdm_mysql.WithReaderFunc(ds.reader), nanomdm_mysql.WithAsyncLastSeen(asyncCap, asyncInterval), ) if err != nil { diff --git a/server/fleet/db.go b/server/fleet/db.go index 39737fdf6340..1df8253a2f0a 100644 --- a/server/fleet/db.go +++ b/server/fleet/db.go @@ -1,5 +1,7 @@ package fleet +import "github.com/jmoiron/sqlx" + // DBLock represents a database transaction lock information as returned // by datastore.DBLocks. type DBLock struct { @@ -10,3 +12,12 @@ type DBLock struct { BlockingThread uint64 `db:"blocking_thread" json:"blocking_thread"` BlockingQuery *string `db:"blocking_query" json:"blocking_query,omitempty"` } + +// DBReader is an interface that defines the methods required for reads. +type DBReader interface { + sqlx.QueryerContext + sqlx.PreparerContext + + Close() error + Rebind(string) string +} diff --git a/server/mdm/nanomdm/service/nanomdm/service.go b/server/mdm/nanomdm/service/nanomdm/service.go index 3685a483b255..a0c52c1dde18 100644 --- a/server/mdm/nanomdm/service/nanomdm/service.go +++ b/server/mdm/nanomdm/service/nanomdm/service.go @@ -5,6 +5,7 @@ import ( "errors" "fmt" + "github.com/fleetdm/fleet/v4/server/contexts/ctxdb" "github.com/fleetdm/fleet/v4/server/mdm/nanomdm/mdm" "github.com/fleetdm/fleet/v4/server/mdm/nanomdm/service" "github.com/fleetdm/fleet/v4/server/mdm/nanomdm/storage" @@ -244,6 +245,10 @@ func (s *Service) CommandAndReportResults(r *mdm.Request, results *mdm.CommandRe "error_chain", results.ErrorChain, ) } + if results.Status != "Idle" { + // If the host is not idle, we use primary DB since we just wrote results of previous command. 
+ r.Context = ctxdb.RequirePrimary(r.Context, true) + } cmd, err := s.store.RetrieveNextCommand(r, results.Status == "NotNow") if err != nil { return nil, fmt.Errorf("retrieving next command: %w", err) diff --git a/server/mdm/nanomdm/storage/mysql/mysql.go b/server/mdm/nanomdm/storage/mysql/mysql.go index a6ec35250705..628ab42c965f 100644 --- a/server/mdm/nanomdm/storage/mysql/mysql.go +++ b/server/mdm/nanomdm/storage/mysql/mysql.go @@ -10,6 +10,7 @@ import ( "os" "time" + "github.com/fleetdm/fleet/v4/server/fleet" "github.com/fleetdm/fleet/v4/server/mdm/nanomdm/cryptoutil" "github.com/fleetdm/fleet/v4/server/mdm/nanomdm/mdm" "github.com/jmoiron/sqlx" @@ -29,6 +30,7 @@ type MySQLStorage struct { db *sql.DB rm bool asyncLastSeen *asyncLastSeen + reader func(ctx context.Context) fleet.DBReader } type config struct { @@ -39,10 +41,17 @@ type config struct { rm bool asyncCap int asyncInterval time.Duration + reader func(ctx context.Context) fleet.DBReader } type Option func(*config) +func WithReaderFunc(readerFunc func(ctx context.Context) fleet.DBReader) Option { + return func(c *config) { + c.reader = readerFunc + } +} + func WithLogger(logger log.Logger) Option { return func(c *config) { c.logger = logger @@ -102,6 +111,13 @@ func New(opts ...Option) (*MySQLStorage, error) { } mysqlStore := &MySQLStorage{db: cfg.db, logger: cfg.logger, rm: cfg.rm} + if cfg.reader == nil { + mysqlStore.reader = func(ctx context.Context) fleet.DBReader { + return sqlx.NewDb(mysqlStore.db, "mysql") + } + } else { + mysqlStore.reader = cfg.reader + } if v := os.Getenv("FLEET_DISABLE_ASYNC_NANO_LAST_SEEN"); v != "1" { asyncLastSeen := newAsyncLastSeen(cfg.asyncInterval, cfg.asyncCap, mysqlStore.updateLastSeenBatch) diff --git a/server/mdm/nanomdm/storage/mysql/queue.go b/server/mdm/nanomdm/storage/mysql/queue.go index 28a2c058b44a..bce893253d2d 100644 --- a/server/mdm/nanomdm/storage/mysql/queue.go +++ b/server/mdm/nanomdm/storage/mysql/queue.go @@ -7,7 +7,9 @@ import ( "fmt" "strings" + 
"github.com/fleetdm/fleet/v4/server/contexts/ctxerr" "github.com/fleetdm/fleet/v4/server/mdm/nanomdm/mdm" + "github.com/google/uuid" ) func enqueue(ctx context.Context, tx *sql.Tx, ids []string, cmd *mdm.Command) error { @@ -22,15 +24,30 @@ func enqueue(ctx context.Context, tx *sql.Tx, ids []string, cmd *mdm.Command) er if err != nil { return err } - query := `INSERT INTO nano_enrollment_queue (id, command_uuid) VALUES (?, ?)` - query += strings.Repeat(", (?, ?)", len(ids)-1) - args := make([]interface{}, len(ids)*2) - for i, id := range ids { - args[i*2] = id - args[i*2+1] = cmd.CommandUUID + const mySQLPlaceholderLimit = 65536 - 1 + const placeholdersPerInsert = 2 + const batchSize = mySQLPlaceholderLimit / placeholdersPerInsert + for i := 0; i < len(ids); i += batchSize { + end := i + batchSize + if end > len(ids) { + end = len(ids) + } + idsBatch := ids[i:end] + + // Process batch + query := `INSERT INTO nano_enrollment_queue (id, command_uuid) VALUES (?, ?)` + query += strings.Repeat(", (?, ?)", len(idsBatch)-1) + args := make([]interface{}, len(idsBatch)*placeholdersPerInsert) + for i, id := range idsBatch { + args[i*2] = id + args[i*2+1] = cmd.CommandUUID + } + _, err = tx.ExecContext(ctx, query+";", args...) + if err != nil { + return err + } } - _, err = tx.ExecContext(ctx, query+";", args...) - return err + return nil } func (m *MySQLStorage) EnqueueCommand(ctx context.Context, ids []string, cmd *mdm.Command) (map[string]error, error) { @@ -168,22 +185,38 @@ UPDATE func (m *MySQLStorage) RetrieveNextCommand(r *mdm.Request, skipNotNow bool) (*mdm.Command, error) { command := new(mdm.Command) - err := m.db.QueryRowContext( - r.Context, ` + id := "?" + var args []interface{} + // Validate the ID to avoid SQL injection. + // This performance optimization eliminates the prepare statement for this frequent query. + // Eventually, we should use binary storage for id (UUID). 
+ if err := uuid.Validate(r.ID); err == nil { + id = "'" + r.ID + "'" + } else { + err = ctxerr.Wrap(r.Context, err, "device ID is not a valid UUID: %s", r.ID) + m.logger.Info("msg", "device ID is not a UUID", "device_id", r.ID, "err", err) + // Handle the error by sending it to Redis to be included in aggregated statistics. + // Before switching UUID to use binary storage, we should ensure that this error rate is low/none. + ctxerr.Handle(r.Context, err) + args = append(args, r.ID) + } + err := m.reader(r.Context).QueryRowxContext( + r.Context, fmt.Sprintf( + // The query should use the ANTIJOIN (NOT EXISTS) optimization on the nano_command_results table. + ` SELECT c.command_uuid, c.request_type, c.command FROM nano_enrollment_queue AS q INNER JOIN nano_commands AS c ON q.command_uuid = c.command_uuid LEFT JOIN nano_command_results r - ON r.command_uuid = q.command_uuid AND r.id = q.id -WHERE q.id = ? + ON r.command_uuid = q.command_uuid AND r.id = q.id AND (r.status != 'NotNow' OR %t) +WHERE q.id = %s AND q.active = 1 - AND (r.status IS NULL OR (r.status = 'NotNow' AND NOT ?)) + AND r.status IS NULL ORDER BY q.priority DESC, q.created_at -LIMIT 1;`, - r.ID, skipNotNow, +LIMIT 1;`, skipNotNow, id), args..., ).Scan(&command.CommandUUID, &command.Command.RequestType, &command.Raw) if errors.Is(err, sql.ErrNoRows) { return nil, nil diff --git a/server/service/integration_mdm_test.go b/server/service/integration_mdm_test.go index ea1d37ecb951..e87fa4ec8d87 100644 --- a/server/service/integration_mdm_test.go +++ b/server/service/integration_mdm_test.go @@ -10035,6 +10035,95 @@ func (s *integrationMDMTestSuite) TestAPNsPushCron() { require.Len(t, recordedPushes, 0) } +func (s *integrationMDMTestSuite) TestAPNsPushWithNotNow() { + t := s.T() + ctx := context.Background() + + // macOS host, MDM on + _, macDevice := createHostThenEnrollMDM(s.ds, s.server.URL, t) + // windows host, MDM on + createWindowsHostThenEnrollMDM(s.ds, s.server.URL, t) + // linux and darwin, MDM 
off + createOrbitEnrolledHost(t, "linux", "linux_host", s.ds) + createOrbitEnrolledHost(t, "darwin", "mac_not_enrolled", s.ds) + + // we're going to modify this mock, make sure we restore its default + originalPushMock := s.pushProvider.PushFunc + defer func() { s.pushProvider.PushFunc = originalPushMock }() + + var recordedPushes []*mdm.Push + var mu sync.Mutex + s.pushProvider.PushFunc = func(ctx context.Context, pushes []*mdm.Push) (map[string]*push.Response, error) { + mu.Lock() + defer mu.Unlock() + recordedPushes = pushes + return mockSuccessfulPush(ctx, pushes) + } + + // trigger the reconciliation schedule + err := ReconcileAppleProfiles(ctx, s.ds, s.mdmCommander, s.logger) + require.NoError(t, err) + require.Len(t, recordedPushes, 1) + recordedPushes = nil + + // Flush any existing profiles. + cmd, err := macDevice.Idle() + require.NoError(t, err) + for { + if cmd == nil { + break + } + t.Logf("Received: %s %s", cmd.CommandUUID, cmd.Command.RequestType) + cmd, err = macDevice.Acknowledge(cmd.CommandUUID) + require.NoError(t, err) + } + + // Load new profiles + s.Do("POST", "/api/v1/fleet/mdm/profiles/batch", batchSetMDMProfilesRequest{Profiles: []fleet.MDMProfileBatchPayload{ + {Name: "N1", Contents: mobileconfigForTest("N1", "I1")}, + {Name: "N2", Contents: syncMLForTest("./Foo/Bar")}, + }}, http.StatusNoContent) + + // trigger the reconciliation schedule + err = ReconcileAppleProfiles(ctx, s.ds, s.mdmCommander, s.logger) + require.NoError(t, err) + require.Len(t, recordedPushes, 1) + recordedPushes = nil + + // The cron to trigger pushes sends a new push request each time it runs. 
+ err = SendPushesToPendingDevices(ctx, s.ds, s.mdmCommander, s.logger) + require.NoError(t, err) + require.Len(t, recordedPushes, 1) + recordedPushes = nil + + // device sends 'NotNow' + cmd, err = macDevice.Idle() + require.NoError(t, err) + require.NotNil(t, cmd) + cmd, err = macDevice.NotNow(cmd.CommandUUID) + require.NoError(t, err) + assert.Nil(t, cmd) + + // A 'NotNow' command will not trigger a new push. Device is expected to check in again when conditions change. + err = ReconcileAppleProfiles(ctx, s.ds, s.mdmCommander, s.logger) + require.NoError(t, err) + require.Len(t, recordedPushes, 0) + recordedPushes = nil + + // device acknowledges the commands + cmd, err = macDevice.Idle() + require.NoError(t, err) + require.NotNil(t, cmd) + cmd, err = macDevice.Acknowledge(cmd.CommandUUID) + require.NoError(t, err) + assert.Nil(t, cmd) + + // no more pushes are enqueued + err = SendPushesToPendingDevices(ctx, s.ds, s.mdmCommander, s.logger) + require.NoError(t, err) + assert.Zero(t, recordedPushes) +} + func (s *integrationMDMTestSuite) TestMDMRequestWithoutCerts() { t := s.T() res := s.DoRawNoAuth("PUT", "/mdm/apple/mdm", nil, http.StatusBadRequest)