Skip to content

Commit

Permalink
Delete pending script executions when the underlying script is edited…
Browse files Browse the repository at this point in the history
… or deleted (#23520)

#21888 

# Checklist for submitter

- [x] Changes file added for user-visible changes in `changes/`,
`orbit/changes/` or `ee/fleetd-chrome/changes`.
See [Changes
files](https://github.com/fleetdm/fleet/blob/main/docs/Contributing/Committing-Changes.md#changes-files)
for more information.
- [x] Added/updated tests
- [x] Manual QA for all new/changed functionality
  • Loading branch information
iansltx authored Nov 7, 2024
1 parent 48e1d7b commit c797fb7
Show file tree
Hide file tree
Showing 6 changed files with 138 additions and 39 deletions.
1 change: 1 addition & 0 deletions changes/21888-dequeue-pending-scripts
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
* Cancelled pending script executions when a script is edited or deleted.
Original file line number Diff line number Diff line change
Expand Up @@ -42,8 +42,14 @@ const DeleteSetupExperienceScriptModal = ({
<Modal className={baseClass} title="Delete setup script" onExit={onExit}>
<>
<p>
The script <b>{scriptName}</b> will still run on pending hosts.
This action will cancel any pending script execution for{" "}
<b>{scriptName}</b>.
</p>
<p>
If the script is currently running on a host it will still complete,
but results won&apos;t appear in Fleet.
</p>
<p>You cannot undo this action.</p>
<div className="modal-cta-wrap">
<Button type="button" onClick={onDelete} variant="alert">
Delete
Expand Down
88 changes: 64 additions & 24 deletions server/datastore/mysql/scripts.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ import (
"time"
"unicode/utf8"

"github.com/fleetdm/fleet/v4/pkg/scripts"
constants "github.com/fleetdm/fleet/v4/pkg/scripts"
"github.com/fleetdm/fleet/v4/server/contexts/ctxerr"
"github.com/fleetdm/fleet/v4/server/fleet"
"github.com/google/uuid"
Expand Down Expand Up @@ -223,7 +223,7 @@ func (ds *Datastore) ListPendingHostScriptExecutions(ctx context.Context, hostID
created_at ASC`

var results []*fleet.HostScriptResult
seconds := int(scripts.MaxServerWaitTime.Seconds())
seconds := int(constants.MaxServerWaitTime.Seconds())
if err := sqlx.SelectContext(ctx, ds.reader(ctx), &results, listStmt, hostID, seconds); err != nil {
return nil, ctxerr.Wrap(ctx, err, "list pending host script executions")
}
Expand Down Expand Up @@ -441,21 +441,32 @@ WHERE
var errDeleteScriptWithAssociatedPolicy = &fleet.ConflictError{Message: "Couldn't delete. Policy automation uses this script. Please remove this script from associated policy automations and try again."}

func (ds *Datastore) DeleteScript(ctx context.Context, id uint) error {
_, err := ds.writer(ctx).ExecContext(ctx, `DELETE FROM scripts WHERE id = ?`, id)
if err != nil {
if isMySQLForeignKey(err) {
// Check if the script is referenced by a policy automation.
var count int
if err := sqlx.GetContext(ctx, ds.reader(ctx), &count, `SELECT COUNT(*) FROM policies WHERE script_id = ?`, id); err != nil {
return ctxerr.Wrapf(ctx, err, "getting reference from policies")
}
if count > 0 {
return ctxerr.Wrap(ctx, errDeleteScriptWithAssociatedPolicy, "delete script")
return ds.withTx(ctx, func(tx sqlx.ExtContext) error {
_, err := tx.ExecContext(ctx, `DELETE FROM host_script_results WHERE script_id = ?
AND exit_code IS NULL AND (sync_request = 0 OR created_at >= NOW() - INTERVAL ? SECOND)`,
id, int(constants.MaxServerWaitTime.Seconds()),
)
if err != nil {
return ctxerr.Wrapf(ctx, err, "cancel pending script executions")
}

_, err = tx.ExecContext(ctx, `DELETE FROM scripts WHERE id = ?`, id)
if err != nil {
if isMySQLForeignKey(err) {
// Check if the script is referenced by a policy automation.
var count int
if err := sqlx.GetContext(ctx, tx, &count, `SELECT COUNT(*) FROM policies WHERE script_id = ?`, id); err != nil {
return ctxerr.Wrapf(ctx, err, "getting reference from policies")
}
if count > 0 {
return ctxerr.Wrap(ctx, errDeleteScriptWithAssociatedPolicy, "delete script")
}
}
return ctxerr.Wrap(ctx, err, "delete script")
}
return ctxerr.Wrap(ctx, err, "delete script")
}
return nil

return nil
})
}

func (ds *Datastore) ListScripts(ctx context.Context, teamID *uint, opt fleet.ListOptions) ([]*fleet.Script, *fleet.PaginationMetadata, error) {
Expand Down Expand Up @@ -637,6 +648,10 @@ WHERE
`
const unsetAllScriptsFromPolicies = `UPDATE policies SET script_id = NULL WHERE team_id = ?`

const clearAllPendingExecutions = `DELETE FROM host_script_results WHERE
exit_code IS NULL AND (sync_request = 0 OR created_at >= NOW() - INTERVAL ? SECOND)
AND script_id IN (SELECT id FROM scripts WHERE global_or_team_id = ?)`

const unsetScriptsNotInListFromPolicies = `
UPDATE policies SET script_id = NULL
WHERE script_id IN (SELECT id FROM scripts WHERE global_or_team_id = ? AND name NOT IN (?))
Expand All @@ -650,6 +665,10 @@ WHERE
name NOT IN (?)
`

const clearPendingExecutionsNotInList = `DELETE FROM host_script_results WHERE
exit_code IS NULL AND (sync_request = 0 OR created_at >= NOW() - INTERVAL ? SECOND)
AND script_id IN (SELECT id FROM scripts WHERE global_or_team_id = ? AND name NOT IN (?))`

const insertNewOrEditedScript = `
INSERT INTO
scripts (
Expand All @@ -658,9 +677,13 @@ INSERT INTO
VALUES
(?, ?, ?, ?)
ON DUPLICATE KEY UPDATE
script_content_id = VALUES(script_content_id)
script_content_id = VALUES(script_content_id), id=LAST_INSERT_ID(id)
`

const clearPendingExecutionsWithObsoleteScript = `DELETE FROM host_script_results WHERE
exit_code IS NULL AND (sync_request = 0 OR created_at >= NOW() - INTERVAL ? SECOND)
AND script_id = ? AND script_content_id != ?`

const loadInsertedScripts = `SELECT id, team_id, name FROM scripts WHERE global_or_team_id = ?`

// use a team id of 0 if no-team
Expand Down Expand Up @@ -704,11 +727,13 @@ ON DUPLICATE KEY UPDATE
}

var (
scriptsStmt string
scriptsArgs []any
policiesStmt string
policiesArgs []any
err error
scriptsStmt string
scriptsArgs []any
policiesStmt string
policiesArgs []any
executionsStmt string
executionsArgs []any
err error
)
if len(keepNames) > 0 {
// delete the obsolete scripts
Expand All @@ -721,16 +746,27 @@ ON DUPLICATE KEY UPDATE
if err != nil {
return ctxerr.Wrap(ctx, err, "build statement to unset obsolete scripts from policies")
}

executionsStmt, executionsArgs, err = sqlx.In(clearPendingExecutionsNotInList, int(constants.MaxServerWaitTime.Seconds()), globalOrTeamID, keepNames)
if err != nil {
return ctxerr.Wrap(ctx, err, "build statement to clear pending script executions from obsolete scripts")
}
} else {
scriptsStmt = deleteAllScriptsInTeam
scriptsArgs = []any{globalOrTeamID}

policiesStmt = unsetAllScriptsFromPolicies
policiesArgs = []any{globalOrTeamID}

executionsStmt = clearAllPendingExecutions
executionsArgs = []any{int(constants.MaxServerWaitTime.Seconds()), globalOrTeamID}
}
if _, err := tx.ExecContext(ctx, policiesStmt, policiesArgs...); err != nil {
return ctxerr.Wrap(ctx, err, "unset obsolete scripts from policies")
}
if _, err := tx.ExecContext(ctx, executionsStmt, executionsArgs...); err != nil {
return ctxerr.Wrap(ctx, err, "clear obsolete script pending executions")
}
if _, err := tx.ExecContext(ctx, scriptsStmt, scriptsArgs...); err != nil {
return ctxerr.Wrap(ctx, err, "delete obsolete scripts")
}
Expand All @@ -741,11 +777,15 @@ ON DUPLICATE KEY UPDATE
if err != nil {
return ctxerr.Wrapf(ctx, err, "inserting script contents for script with name %q", s.Name)
}
id, _ := scRes.LastInsertId()
if _, err := tx.ExecContext(ctx, insertNewOrEditedScript, tmID, globalOrTeamID, s.Name,
uint(id)); err != nil { //nolint:gosec // dismiss G115
contentID, _ := scRes.LastInsertId()
insertRes, err := tx.ExecContext(ctx, insertNewOrEditedScript, tmID, globalOrTeamID, s.Name, uint(contentID)) //nolint:gosec // dismiss G115
if err != nil {
return ctxerr.Wrapf(ctx, err, "insert new/edited script with name %q", s.Name)
}
scriptID, _ := insertRes.LastInsertId()
if _, err := tx.ExecContext(ctx, clearPendingExecutionsWithObsoleteScript, int(constants.MaxServerWaitTime.Seconds()), scriptID, contentID); err != nil {
return ctxerr.Wrapf(ctx, err, "clear obsolete pending script executions with name %q", s.Name)
}
}

if err := sqlx.SelectContext(ctx, tx, &insertedScripts, loadInsertedScripts, globalOrTeamID); err != nil {
Expand Down
60 changes: 60 additions & 0 deletions server/datastore/mysql/scripts_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -654,11 +654,48 @@ VALUES
require.NoError(t, err)
require.True(t, r)
})

t.Run("script deletion cancels pending script runs", func(t *testing.T) {
insertResults(t, 43, scripts[3], now.Add(-2*time.Minute), "execution-4-4", nil)
pending, err := ds.ListPendingHostScriptExecutions(ctx, 43)
require.NoError(t, err)
require.Len(t, pending, 1)

err = ds.DeleteScript(ctx, scripts[3].ID)
require.NoError(t, err)

pending, err = ds.ListPendingHostScriptExecutions(ctx, 43)
require.NoError(t, err)
require.Len(t, pending, 0)
})
}

func testBatchSetScripts(t *testing.T, ds *Datastore) {
ctx := context.Background()

now := time.Now().UTC().Truncate(time.Second)
insertResults := func(t *testing.T, hostID uint, scriptID uint, createdAt time.Time, execID string, exitCode *int64) {
stmt := `
INSERT INTO
host_script_results (%s host_id, created_at, execution_id, exit_code, output)
VALUES
(%s ?,?,?,?,?)`

args := []interface{}{}
if scriptID == 0 {
stmt = fmt.Sprintf(stmt, "", "")
} else {
stmt = fmt.Sprintf(stmt, "script_id,", "?,")
args = append(args, scriptID)
}
args = append(args, hostID, createdAt, execID, exitCode, "")

ExecAdhocSQL(t, ds, func(tx sqlx.ExtContext) error {
_, err := tx.ExecContext(ctx, stmt, args...)
return err
})
}

applyAndExpect := func(newSet []*fleet.Script, tmID *uint, want []*fleet.Script) map[string]uint {
responseFromSet, err := ds.BatchSetScripts(ctx, tmID, newSet)
require.NoError(t, err)
Expand Down Expand Up @@ -781,6 +818,16 @@ func testBatchSetScripts(t *testing.T, ds *Datastore) {
require.NoError(t, err)
require.Equal(t, n1WithTeamID, *teamPolicy.ScriptID)

// add pending scripts on team and no-team and confirm they're shown as pending
insertResults(t, 44, n1WithTeamID, now.Add(-2*time.Minute), "execution-n1t1-1", nil)
insertResults(t, 45, n1WithNoTeamId, now.Add(-2*time.Minute), "execution-n1nt1-1", nil)
pending, err := ds.ListPendingHostScriptExecutions(ctx, 44)
require.NoError(t, err)
require.Len(t, pending, 1)
pending, err = ds.ListPendingHostScriptExecutions(ctx, 45)
require.NoError(t, err)
require.Len(t, pending, 1)

// clear scripts for tm1
applyAndExpect(nil, ptr.Uint(1), nil)

Expand All @@ -794,6 +841,14 @@ func testBatchSetScripts(t *testing.T, ds *Datastore) {
require.NoError(t, err)
require.Equal(t, n1WithNoTeamId, *noTeamPolicy.ScriptID)

// team script should no longer be pending, no-team script should still be pending
pending, err = ds.ListPendingHostScriptExecutions(ctx, 44)
require.NoError(t, err)
require.Len(t, pending, 0)
pending, err = ds.ListPendingHostScriptExecutions(ctx, 45)
require.NoError(t, err)
require.Len(t, pending, 1)

// apply only new scripts to no-team
applyAndExpect([]*fleet.Script{
{Name: "N4", ScriptContents: "C4"},
Expand All @@ -812,6 +867,11 @@ func testBatchSetScripts(t *testing.T, ds *Datastore) {
noTeamPolicy, err = ds.Policy(ctx, noTeamPolicy.ID)
require.NoError(t, err)
require.Nil(t, noTeamPolicy.ScriptID)

// no-team script should no longer be pending
pending, err = ds.ListPendingHostScriptExecutions(ctx, 45)
require.NoError(t, err)
require.Len(t, pending, 0)
}

func testLockHostViaScript(t *testing.T, ds *Datastore) {
Expand Down
18 changes: 5 additions & 13 deletions server/service/integration_enterprise_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -6368,25 +6368,17 @@ func (s *integrationEnterpriseTestSuite) TestRunHostSavedScript() {
errMsg = extractServerErrorText(res.Body)
require.Contains(t, errMsg, `One of 'script_id', 'script_contents', or 'script_name' is required.`)

// deleting the saved script does not impact the pending script
// deleting the saved script should delete the pending script
s.Do("DELETE", fmt.Sprintf("/api/latest/fleet/scripts/%d", savedNoTmScript.ID), nil, http.StatusNoContent)

// script id is now nil, but otherwise execution request is the same
scriptResultResp = getScriptResultResponse{}
s.DoJSON("GET", "/api/latest/fleet/scripts/results/"+runSyncResp.ExecutionID, nil, http.StatusOK, &scriptResultResp)
require.Equal(t, host.ID, scriptResultResp.HostID)
require.Equal(t, "echo 'no team'", scriptResultResp.ScriptContents)
require.Nil(t, scriptResultResp.ExitCode)
require.False(t, scriptResultResp.HostTimeout)
require.Contains(t, scriptResultResp.Message, fleet.RunScriptAlreadyRunningErrMsg)
require.Nil(t, scriptResultResp.ScriptID)
s.DoJSON("GET", "/api/latest/fleet/scripts/results/"+runSyncResp.ExecutionID, nil, http.StatusNotFound, &scriptResultResp)

// Verify that we can't enqueue more than 1k scripts

// Make the host offline so that scripts enqueue
err = s.ds.MarkHostsSeen(ctx, []uint{host.ID}, time.Now().Add(-time.Hour))
require.NoError(t, err)
for i := 0; i < 1000; i++ {
for i := 1; i <= 1000; i++ {
script, err := s.ds.NewScript(ctx, &fleet.Script{
TeamID: nil,
Name: fmt.Sprintf("script_1k_%d.sh", i),
Expand All @@ -6400,8 +6392,8 @@ func (s *integrationEnterpriseTestSuite) TestRunHostSavedScript() {

script, err := s.ds.NewScript(ctx, &fleet.Script{
TeamID: nil,
Name: "script_1k_1000.sh",
ScriptContents: "echo 1000",
Name: "script_1k_1001.sh",
ScriptContents: "echo 1001",
})
require.NoError(t, err)

Expand Down
2 changes: 1 addition & 1 deletion server/service/scripts.go
Original file line number Diff line number Diff line change
Expand Up @@ -277,7 +277,7 @@ func (svc *Service) RunHostScript(ctx context.Context, request *fleet.HostScript
if err != nil {
return nil, ctxerr.Wrap(ctx, err, "list host pending script executions")
}
if len(pending) > maxPending {
if len(pending) >= maxPending {
return nil, fleet.NewInvalidArgumentError(
"script_id", "cannot queue more than 1000 scripts per host",
).WithStatus(http.StatusConflict)
Expand Down

0 comments on commit c797fb7

Please sign in to comment.