diff --git a/server/datastore/mysql/disk_encryption.go b/server/datastore/mysql/disk_encryption.go index 710771776c92..96f278a68ab9 100644 --- a/server/datastore/mysql/disk_encryption.go +++ b/server/datastore/mysql/disk_encryption.go @@ -31,12 +31,8 @@ func (ds *Datastore) SetOrUpdateHostDiskEncryptionKey(ctx context.Context, host return ctxerr.Wrap(ctx, err, "getting existing key, if present") } - // TODO: add salt here - var keySlot *uint - var encryptedBase64Salt string // We use the same timestamp for base and archive tables so that it can be used as an additional debug tool if needed. - createdAt := time.Now().UTC() - var incomingKey = encryptionKey{Base: encryptedBase64Key, Salt: encryptedBase64Salt, KeySlot: keySlot, CreatedAt: createdAt} + var incomingKey = encryptionKey{Base: encryptedBase64Key, CreatedAt: time.Now().UTC()} err = ds.archiveHostDiskEncryptionKey(ctx, host, incomingKey, existingKey) if err != nil { return ctxerr.Wrap(ctx, err, "archiving key") @@ -47,7 +43,7 @@ func (ds *Datastore) SetOrUpdateHostDiskEncryptionKey(ctx context.Context, host INSERT INTO host_disk_encryption_keys (host_id, base64_encrypted, client_error, decryptable, created_at) VALUES - (?, ?, ?, ?, ?)`, host.ID, encryptedBase64Key, clientError, decryptable, createdAt) + (?, ?, ?, ?, ?)`, host.ID, incomingKey.Base, clientError, decryptable, incomingKey.CreatedAt) if err == nil { return nil } @@ -73,9 +69,9 @@ UPDATE host_disk_encryption_keys SET base64_encrypted = ?, client_error = ? WHERE host_id = ? -`, encryptedBase64Key, decryptable, encryptedBase64Key, clientError, host.ID) +`, incomingKey.Base, decryptable, incomingKey.Base, clientError, host.ID) if err != nil { - return ctxerr.Wrap(ctx, err, "inserting key") + return ctxerr.Wrap(ctx, err, "updating key") } return nil } @@ -101,7 +97,7 @@ func (ds *Datastore) archiveHostDiskEncryptionKey(ctx context.Context, host *fle (incomingKey.Salt != "" && existingKey.Salt != incomingKey.Salt) { const insertKeyIntoArchiveStmt = ` INSERT INTO host_disk_encryption_keys_archive (host_id, hardware_serial, base64_encrypted, base64_encrypted_salt, key_slot, created_at) -VALUES (?, ?, ?, ?, ?)` +VALUES (?, ?, ?, ?, ?, ?)` _, err := ds.writer(ctx).ExecContext(ctx, insertKeyIntoArchiveStmt, host.ID, host.HardwareSerial, incomingKey.Base, incomingKey.Salt, incomingKey.KeySlot, incomingKey.CreatedAt) @@ -118,19 +114,55 @@ func (ds *Datastore) SaveLUKSData(ctx context.Context, host *fleet.Host, encrypt return errors.New("passphrase and salt must be set") } - _, err := ds.writer(ctx).ExecContext(ctx, ` + existingKey, err := ds.getExistingHostDiskEncryptionKey(ctx, host) + if err != nil { + return ctxerr.Wrap(ctx, err, "getting existing LUKS key, if present") + } + + // We use the same timestamp for base and archive tables so that it can be used as an additional debug tool if needed. + var incomingKey = encryptionKey{Base: encryptedBase64Passphrase, Salt: encryptedBase64Salt, KeySlot: &keySlot, + CreatedAt: time.Now().UTC()} + err = ds.archiveHostDiskEncryptionKey(ctx, host, incomingKey, existingKey) + if err != nil { + return ctxerr.Wrap(ctx, err, "archiving LUKS key") + } + + if existingKey.NotFound { + _, err = ds.writer(ctx).ExecContext(ctx, ` INSERT INTO host_disk_encryption_keys - (host_id, base64_encrypted, base64_encrypted_salt, key_slot, client_error, decryptable) + (host_id, base64_encrypted, base64_encrypted_salt, key_slot, decryptable, created_at) VALUES - (?, ?, ?, ?, '', TRUE) -ON DUPLICATE KEY UPDATE + (?, ?, ?, ?, TRUE, ?)`, host.ID, incomingKey.Base, incomingKey.Salt, incomingKey.KeySlot, incomingKey.CreatedAt) + if err == nil { + return nil + } + var mysqlErr *mysql.MySQLError + switch { + case errors.As(err, &mysqlErr) && mysqlErr.Number == 1062: + level.Error(ds.logger).Log("msg", "Primary key already exists in LUKS host_disk_encryption_keys. Falling back to update", + "host_id", + host) + // This should never happen unless there is a bug in the code or an infra issue (like huge replication lag). + default: + return ctxerr.Wrap(ctx, err, "inserting LUKS key") + } + } + + _, err = ds.writer(ctx).ExecContext(ctx, ` +UPDATE host_disk_encryption_keys SET + /* if the key has changed, set decrypted to its initial value so it can be calculated again if necessary (if null) */ decryptable = TRUE, - base64_encrypted = VALUES(base64_encrypted), - base64_encrypted_salt = VALUES(base64_encrypted_salt), - key_slot = VALUES(key_slot), + base64_encrypted = ?, + base64_encrypted_salt = ?, + key_slot = ?, client_error = '' -`, host.ID, encryptedBase64Passphrase, encryptedBase64Salt, keySlot) - return err +WHERE host_id = ? +`, incomingKey.Base, incomingKey.Salt, incomingKey.KeySlot, host.ID) + if err != nil { + return ctxerr.Wrap(ctx, err, "updating LUKS key") + } + return nil + } func (ds *Datastore) IsHostPendingEscrow(ctx context.Context, hostID uint) bool {