Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: retrigger automatic installations after label scope changes (#25163) #25172

Merged
merged 1 commit into from
Jan 6, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
34 changes: 33 additions & 1 deletion ee/server/service/software_installers.go
Original file line number Diff line number Diff line change
Expand Up @@ -438,10 +438,41 @@ func (svc *Service) UpdateSoftwareInstaller(ctx context.Context, payload *fleet.
payload.SelfService = &existingInstaller.SelfService
}

// Get the hosts that are NOT in label scope currently (before the update happens)
var hostsNotInScope map[uint]struct{}
if dirty["Labels"] {
hostsNotInScope, err = svc.ds.GetExcludedHostIDMapForSoftwareInstaller(ctx, payload.InstallerID)
if err != nil {
return nil, ctxerr.Wrap(ctx, err, "getting hosts not in scope for installer")
}
}

if err := svc.ds.SaveInstallerUpdates(ctx, payload); err != nil {
return nil, ctxerr.Wrap(ctx, err, "saving installer updates")
}

if dirty["Labels"] {
// Get the hosts that are now IN label scope (after the update)
hostsInScope, err := svc.ds.GetIncludedHostIDMapForSoftwareInstaller(ctx, payload.InstallerID)
if err != nil {
return nil, ctxerr.Wrap(ctx, err, "getting hosts in scope for installer")
}

var hostsToClear []uint
for id := range hostsInScope {
if _, ok := hostsNotInScope[id]; ok {
// it was not in scope but now it is, so we should clear policy status
hostsToClear = append(hostsToClear, id)
}
}

// We clear the policy status here because otherwise the policy automation machinery
// won't pick this up and the software won't install.
if err := svc.ds.ClearAutoInstallPolicyStatusForHosts(ctx, payload.InstallerID, hostsToClear); err != nil {
return nil, ctxerr.Wrap(ctx, err, "failed to clear auto install policy status for host")
}
}

// if we're updating anything other than self-service, we cancel pending installs/uninstalls,
// and if we're updating the package we reset counts. This is run in its own transaction internally
// for consistency, but independent of the installer update query as the main update should stick
Expand Down Expand Up @@ -484,7 +515,8 @@ func (svc *Service) UpdateSoftwareInstaller(ctx context.Context, payload *fleet.
}

func (svc *Service) validateEmbeddedSecretsOnScript(ctx context.Context, scriptName string, script *string,
argErr *fleet.InvalidArgumentError) *fleet.InvalidArgumentError {
argErr *fleet.InvalidArgumentError,
) *fleet.InvalidArgumentError {
if script != nil {
if errScript := svc.ds.ValidateEmbeddedSecrets(ctx, []string{*script}); errScript != nil {
if argErr != nil {
Expand Down
28 changes: 28 additions & 0 deletions server/datastore/mysql/policies.go
Original file line number Diff line number Diff line change
Expand Up @@ -416,6 +416,34 @@ func (ds *Datastore) RecordPolicyQueryExecutions(ctx context.Context, host *flee
return nil
}

func (ds *Datastore) ClearAutoInstallPolicyStatusForHosts(ctx context.Context, installerID uint, hostIDs []uint) error {
if len(hostIDs) == 0 {
return nil
}

stmt := `
UPDATE
policies p
JOIN policy_membership pm ON pm.policy_id = p.id
SET
passes = NULL
WHERE
p.software_installer_id = ?
AND pm.host_id IN (?)
`

stmt, args, err := sqlx.In(stmt, installerID, hostIDs)
if err != nil {
return ctxerr.Wrap(ctx, err, "building in statement for clearing auto install policy status")
}

if _, err := ds.writer(ctx).ExecContext(ctx, stmt, args...); err != nil {
return ctxerr.Wrap(ctx, err, "clearing auto install policy status")
}

return nil
}

func (ds *Datastore) ListGlobalPolicies(ctx context.Context, opts fleet.ListOptions) ([]*fleet.Policy, error) {
return listPoliciesDB(ctx, ds.reader(ctx), nil, opts)
}
Expand Down
81 changes: 81 additions & 0 deletions server/datastore/mysql/policies_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,7 @@ func TestPolicies(t *testing.T) {
{"TestPoliciesTeamPoliciesWithScript", testTeamPoliciesWithScript},
{"TeamPoliciesNoTeam", testTeamPoliciesNoTeam},
{"TestPoliciesBySoftwareTitleID", testPoliciesBySoftwareTitleID},
{"TestClearAutoInstallPolicyStatusForHost", testClearAutoInstallPolicyStatusForHost},
}
for _, c := range cases {
t.Run(c.name, func(t *testing.T) {
Expand Down Expand Up @@ -5371,3 +5372,83 @@ func testPoliciesBySoftwareTitleID(t *testing.T, ds *Datastore) {
require.NoError(t, err)
require.Len(t, policies, 0)
}

func testClearAutoInstallPolicyStatusForHost(t *testing.T, ds *Datastore) {
ctx := context.Background()

user1 := test.NewUser(t, ds, "Alice", "alice@example.com", true)
team1, err := ds.NewTeam(ctx, &fleet.Team{Name: "team1" + t.Name()})
require.NoError(t, err)

// create a regular policy
policy1 := newTestPolicy(t, ds, user1, "policy 1"+t.Name(), "darwin", &team1.ID)

// create an automatic install policy
policy2 := newTestPolicy(t, ds, user1, "policy 2"+t.Name(), "darwin", &team1.ID)
installer, err := fleet.NewTempFileReader(strings.NewReader("hello"), t.TempDir)
require.NoError(t, err)

installer1ID, _, err := ds.MatchOrCreateSoftwareInstaller(context.Background(), &fleet.UploadSoftwareInstallerPayload{
InstallScript: "hello",
PreInstallQuery: "SELECT 1",
PostInstallScript: "world",
InstallerFile: installer,
StorageID: "storage1",
Filename: "file1",
Title: "file1",
Version: "1.0",
Source: "apps",
UserID: user1.ID,
TeamID: &team1.ID,
ValidatedLabels: &fleet.LabelIdentsWithScope{},
})
require.NoError(t, err)
policy2.SoftwareInstallerID = ptr.Uint(installer1ID)
err = ds.SavePolicy(context.Background(), policy2, false, false)
require.NoError(t, err)

// create a host
host, err := ds.NewHost(ctx, &fleet.Host{
OsqueryHostID: ptr.String(uuid.New().String()),
DetailUpdatedAt: time.Now(),
LabelUpdatedAt: time.Now(),
PolicyUpdatedAt: time.Now(),
SeenTime: time.Now(),
NodeKey: ptr.String(uuid.New().String()),
UUID: uuid.New().String(),
Hostname: "host" + t.Name(),
TeamID: &team1.ID,
Platform: "darwin",
})
require.NoError(t, err)

// record a policy run for both policies
err = ds.RecordPolicyQueryExecutions(ctx, host, map[uint]*bool{
policy1.ID: ptr.Bool(true),
policy2.ID: ptr.Bool(false), // software isn't installed on host, so Fleet should install it
}, time.Now(), false)
require.NoError(t, err)

hostPolicies, err := ds.ListPoliciesForHost(ctx, host)
require.NoError(t, err)
require.Len(t, hostPolicies, 2)
sort.Slice(hostPolicies, func(i, j int) bool {
return hostPolicies[i].ID < hostPolicies[j].ID
})
require.Equal(t, hostPolicies[0].Response, "pass")
require.Equal(t, hostPolicies[1].Response, "fail")

// clear status for automatic install policy
err = ds.ClearAutoInstallPolicyStatusForHosts(ctx, installer1ID, []uint{host.ID})
require.NoError(t, err)

// the status should be NULL for the automatic install policy but not the "regular" one
hostPolicies, err = ds.ListPoliciesForHost(ctx, host)
require.NoError(t, err)
require.Len(t, hostPolicies, 2)
sort.Slice(hostPolicies, func(i, j int) bool {
return hostPolicies[i].ID < hostPolicies[j].ID
})
require.Equal(t, hostPolicies[0].Response, "pass")
require.Empty(t, hostPolicies[1].Response)
}
106 changes: 106 additions & 0 deletions server/datastore/mysql/software_installers.go
Original file line number Diff line number Diff line change
Expand Up @@ -1755,3 +1755,109 @@ func (ds *Datastore) IsSoftwareInstallerLabelScoped(ctx context.Context, install

return res, nil
}

const labelScopedFilter = `
SELECT
1
FROM (
-- no labels
SELECT
0 AS count_installer_labels,
0 AS count_host_labels,
0 AS count_host_updated_after_labels
WHERE NOT EXISTS ( SELECT 1 FROM software_installer_labels sil WHERE sil.software_installer_id = ?)

UNION

-- include any
SELECT
COUNT(*) AS count_installer_labels,
COUNT(lm.label_id) AS count_host_labels,
0 AS count_host_updated_after_labels
FROM
software_installer_labels sil
LEFT OUTER JOIN label_membership lm ON lm.label_id = sil.label_id
AND lm.host_id = h.id
WHERE
sil.software_installer_id = ?
AND sil.exclude = 0
HAVING
count_installer_labels > 0
AND count_host_labels > 0

UNION

-- exclude any, ignore software that depends on labels created
-- _after_ the label_updated_at timestamp of the host (because
-- we don't have results for that label yet, the host may or may
-- not be a member).
SELECT
COUNT(*) AS count_installer_labels,
COUNT(lm.label_id) AS count_host_labels,
SUM(
CASE WHEN lbl.created_at IS NOT NULL
AND(
SELECT
label_updated_at FROM hosts
WHERE
id = 1) >= lbl.created_at THEN
1
ELSE
0
END) AS count_host_updated_after_labels
FROM
software_installer_labels sil
LEFT OUTER JOIN labels lbl ON lbl.id = sil.label_id
LEFT OUTER JOIN label_membership lm ON lm.label_id = sil.label_id
AND lm.host_id = h.id
WHERE
sil.software_installer_id = ?
AND sil.exclude = 1
HAVING
count_installer_labels > 0
AND count_installer_labels = count_host_updated_after_labels
AND count_host_labels = 0) t`

func (ds *Datastore) GetIncludedHostIDMapForSoftwareInstaller(ctx context.Context, installerID uint) (map[uint]struct{}, error) {
stmt := fmt.Sprintf(`SELECT
h.id
FROM
hosts h
WHERE
EXISTS (%s)
`, labelScopedFilter)

var hostIDs []uint
if err := sqlx.SelectContext(ctx, ds.reader(ctx), &hostIDs, stmt, installerID, installerID, installerID); err != nil {
return nil, ctxerr.Wrap(ctx, err, "listing hosts included in software installer scope")
}

res := make(map[uint]struct{}, len(hostIDs))
for _, id := range hostIDs {
res[id] = struct{}{}
}

return res, nil
}

func (ds *Datastore) GetExcludedHostIDMapForSoftwareInstaller(ctx context.Context, installerID uint) (map[uint]struct{}, error) {
stmt := fmt.Sprintf(`SELECT
h.id
FROM
hosts h
WHERE
NOT EXISTS (%s)
`, labelScopedFilter)

var hostIDs []uint
if err := sqlx.SelectContext(ctx, ds.reader(ctx), &hostIDs, stmt, installerID, installerID, installerID); err != nil {
return nil, ctxerr.Wrap(ctx, err, "listing hosts excluded from software installer scope")
}

res := make(map[uint]struct{}, len(hostIDs))
for _, id := range hostIDs {
res[id] = struct{}{}
}

return res, nil
}
8 changes: 8 additions & 0 deletions server/datastore/mysql/software_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -5321,6 +5321,10 @@ func testListHostSoftwareWithLabelScoping(t *testing.T, ds *Datastore) {
require.NoError(t, err)
require.True(t, scoped)

hostsInScope, err := ds.GetIncludedHostIDMapForSoftwareInstaller(ctx, installerID1)
require.NoError(t, err)
require.Contains(t, hostsInScope, host.ID)

label1, err := ds.NewLabel(ctx, &fleet.Label{Name: "label1" + t.Name()})
require.NoError(t, err)

Expand All @@ -5343,6 +5347,10 @@ func testListHostSoftwareWithLabelScoping(t *testing.T, ds *Datastore) {
require.NoError(t, err)
require.Empty(t, software)

hostsNotInScope, err := ds.GetExcludedHostIDMapForSoftwareInstaller(ctx, installerID1)
require.NoError(t, err)
require.Contains(t, hostsNotInScope, host.ID)

// installer1 should be out of scope since the label is "exclude any"
scoped, err = ds.IsSoftwareInstallerLabelScoped(ctx, installerID1, host.ID)
require.NoError(t, err)
Expand Down
12 changes: 12 additions & 0 deletions server/fleet/datastore.go
Original file line number Diff line number Diff line change
Expand Up @@ -1706,6 +1706,18 @@ type Datastore interface {
// Software installers
//

// GetIncludedHostIDMapForSoftwareInstaller gets the set of hosts that are targeted/in scope for the
// given software installer, based label membership.
GetIncludedHostIDMapForSoftwareInstaller(ctx context.Context, installerID uint) (map[uint]struct{}, error)

// GetExcludedHostIDMapForSoftwareInstaller gets the set of hosts that are NOT targeted/in scope for the
// given software installer, based label membership.
GetExcludedHostIDMapForSoftwareInstaller(ctx context.Context, installerID uint) (map[uint]struct{}, error)

// ClearAutoInstallPolicyStatusForHosts clears out the status of the policy related to the given
// software installer for all the given hosts.
ClearAutoInstallPolicyStatusForHosts(ctx context.Context, installerID uint, hostIDs []uint) error

// GetSoftwareInstallDetails returns details required to fetch and
// run software installers
GetSoftwareInstallDetails(ctx context.Context, executionId string) (*SoftwareInstallDetails, error)
Expand Down
36 changes: 36 additions & 0 deletions server/mock/datastore_mock.go
Original file line number Diff line number Diff line change
Expand Up @@ -1077,6 +1077,12 @@ type WipeHostViaWindowsMDMFunc func(ctx context.Context, host *fleet.Host, cmd *

type UpdateHostLockWipeStatusFromAppleMDMResultFunc func(ctx context.Context, hostUUID string, cmdUUID string, requestType string, succeeded bool) error

type GetIncludedHostIDMapForSoftwareInstallerFunc func(ctx context.Context, installerID uint) (map[uint]struct{}, error)

type GetExcludedHostIDMapForSoftwareInstallerFunc func(ctx context.Context, installerID uint) (map[uint]struct{}, error)

type ClearAutoInstallPolicyStatusForHostsFunc func(ctx context.Context, installerID uint, hostIDs []uint) error

type GetSoftwareInstallDetailsFunc func(ctx context.Context, executionId string) (*fleet.SoftwareInstallDetails, error)

type ListPendingSoftwareInstallsFunc func(ctx context.Context, hostID uint) ([]string, error)
Expand Down Expand Up @@ -2772,6 +2778,15 @@ type DataStore struct {
UpdateHostLockWipeStatusFromAppleMDMResultFunc UpdateHostLockWipeStatusFromAppleMDMResultFunc
UpdateHostLockWipeStatusFromAppleMDMResultFuncInvoked bool

GetIncludedHostIDMapForSoftwareInstallerFunc GetIncludedHostIDMapForSoftwareInstallerFunc
GetIncludedHostIDMapForSoftwareInstallerFuncInvoked bool

GetExcludedHostIDMapForSoftwareInstallerFunc GetExcludedHostIDMapForSoftwareInstallerFunc
GetExcludedHostIDMapForSoftwareInstallerFuncInvoked bool

ClearAutoInstallPolicyStatusForHostsFunc ClearAutoInstallPolicyStatusForHostsFunc
ClearAutoInstallPolicyStatusForHostsFuncInvoked bool

GetSoftwareInstallDetailsFunc GetSoftwareInstallDetailsFunc
GetSoftwareInstallDetailsFuncInvoked bool

Expand Down Expand Up @@ -6636,6 +6651,27 @@ func (s *DataStore) UpdateHostLockWipeStatusFromAppleMDMResult(ctx context.Conte
return s.UpdateHostLockWipeStatusFromAppleMDMResultFunc(ctx, hostUUID, cmdUUID, requestType, succeeded)
}

func (s *DataStore) GetIncludedHostIDMapForSoftwareInstaller(ctx context.Context, installerID uint) (map[uint]struct{}, error) {
s.mu.Lock()
s.GetIncludedHostIDMapForSoftwareInstallerFuncInvoked = true
s.mu.Unlock()
return s.GetIncludedHostIDMapForSoftwareInstallerFunc(ctx, installerID)
}

func (s *DataStore) GetExcludedHostIDMapForSoftwareInstaller(ctx context.Context, installerID uint) (map[uint]struct{}, error) {
s.mu.Lock()
s.GetExcludedHostIDMapForSoftwareInstallerFuncInvoked = true
s.mu.Unlock()
return s.GetExcludedHostIDMapForSoftwareInstallerFunc(ctx, installerID)
}

func (s *DataStore) ClearAutoInstallPolicyStatusForHosts(ctx context.Context, installerID uint, hostIDs []uint) error {
s.mu.Lock()
s.ClearAutoInstallPolicyStatusForHostsFuncInvoked = true
s.mu.Unlock()
return s.ClearAutoInstallPolicyStatusForHostsFunc(ctx, installerID, hostIDs)
}

func (s *DataStore) GetSoftwareInstallDetails(ctx context.Context, executionId string) (*fleet.SoftwareInstallDetails, error) {
s.mu.Lock()
s.GetSoftwareInstallDetailsFuncInvoked = true
Expand Down
Loading
Loading