Skip to content

Commit

Permalink
Implement deletion of pending scripts related to policy
Browse files Browse the repository at this point in the history
  • Loading branch information
mna committed Jan 7, 2025
1 parent 6fa982b commit eb96e09
Show file tree
Hide file tree
Showing 2 changed files with 50 additions and 30 deletions.
27 changes: 22 additions & 5 deletions server/datastore/mysql/scripts.go
Original file line number Diff line number Diff line change
Expand Up @@ -560,8 +560,12 @@ func (ds *Datastore) deletePendingHostScriptExecutionsForPolicy(ctx context.Cont
globalOrTeamID = *teamID
}

// TODO(mna): must delete from the upcoming queue
deleteStmt := fmt.Sprintf(`
deletePendingFunc := func(stmt string, args ...any) error {
_, err := ds.writer(ctx).ExecContext(ctx, stmt, args...)
return ctxerr.Wrap(ctx, err, "delete pending host script executions for policy")
}

deleteHSRStmt := fmt.Sprintf(`
DELETE FROM
host_script_results
WHERE
Expand All @@ -573,9 +577,22 @@ func (ds *Datastore) deletePendingHostScriptExecutionsForPolicy(ctx context.Cont
`, whereFilterPendingScript)

seconds := int(constants.MaxServerWaitTime.Seconds())
_, err := ds.writer(ctx).ExecContext(ctx, deleteStmt, policyID, globalOrTeamID, seconds)
if err != nil {
return ctxerr.Wrap(ctx, err, "delete pending host script executions for policy")
if err := deletePendingFunc(deleteHSRStmt, policyID, globalOrTeamID, seconds); err != nil {
return err
}

deleteUAStmt := `
DELETE FROM
upcoming_activities
WHERE
policy_id = ? AND
activity_type = 'script' AND
script_id IN (
SELECT id FROM scripts WHERE scripts.global_or_team_id = ?
)
`
if err := deletePendingFunc(deleteUAStmt, policyID, globalOrTeamID); err != nil {
return err
}

return nil
Expand Down
53 changes: 28 additions & 25 deletions server/datastore/mysql/scripts_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -1545,32 +1545,35 @@ func testDeletePendingHostScriptExecutionsForPolicy(t *testing.T, ds *Datastore)
require.NoError(t, err)
require.Equal(t, 1, len(pending))

// test not pending host script execution for correct policy
scriptExecution, err := ds.NewHostScriptExecutionRequest(ctx, &fleet.HostScriptRequestPayload{
HostID: 1,
ScriptContents: "echo",
UserID: &user.ID,
PolicyID: &p1.ID,
SyncRequest: true,
ScriptID: &script1.ID,
})
require.NoError(t, err)
ExecAdhocSQL(t, ds, func(q sqlx.ExtContext) error {
_, err = q.ExecContext(ctx, `UPDATE host_script_results SET exit_code = 1 WHERE id = ?`, scriptExecution.ID)
// TODO(mna): adjust test once script execution via unified queue is implemented
/*
// test not pending host script execution for correct policy
scriptExecution, err := ds.NewHostScriptExecutionRequest(ctx, &fleet.HostScriptRequestPayload{
HostID: 1,
ScriptContents: "echo",
UserID: &user.ID,
PolicyID: &p1.ID,
SyncRequest: true,
ScriptID: &script1.ID,
})
require.NoError(t, err)
return nil
})
ExecAdhocSQL(t, ds, func(q sqlx.ExtContext) error {
_, err = q.ExecContext(ctx, `UPDATE host_script_results SET exit_code = 1 WHERE id = ?`, scriptExecution.ID)
require.NoError(t, err)
return nil
})
err = ds.deletePendingHostScriptExecutionsForPolicy(ctx, &team1.ID, p1.ID)
require.NoError(t, err)
err = ds.deletePendingHostScriptExecutionsForPolicy(ctx, &team1.ID, p1.ID)
require.NoError(t, err)
var count int
err = sqlx.GetContext(
ctx,
ds.reader(ctx),
&count,
"SELECT count(1) FROM host_script_results WHERE id = ?",
scriptExecution.ID,
)
require.Equal(t, 1, count)
var count int
err = sqlx.GetContext(
ctx,
ds.reader(ctx),
&count,
"SELECT count(1) FROM host_script_results WHERE id = ?",
scriptExecution.ID,
)
require.Equal(t, 1, count)
*/
}

0 comments on commit eb96e09

Please sign in to comment.