diff --git a/drivers/shared/executor/procstats/list_test.go b/drivers/shared/executor/procstats/list_test.go index 9e9588d347b..4e34e6e3a8f 100644 --- a/drivers/shared/executor/procstats/list_test.go +++ b/drivers/shared/executor/procstats/list_test.go @@ -63,31 +63,26 @@ func Test_list(t *testing.T) { name string needles int haystack int - expect int }{ { name: "minimal", needles: 2, haystack: 10, - expect: 16, }, { name: "small needles small haystack", needles: 5, haystack: 200, - expect: 212, }, { name: "small needles large haystack", needles: 10, haystack: 1000, - expect: 1022, }, { name: "moderate needles giant haystack", needles: 20, haystack: 2000, - expect: 2042, }, } @@ -100,11 +95,10 @@ func Test_list(t *testing.T) { return procs, nil } - result, examined := list(executorPID, lister) + result := list(executorPID, lister) must.SliceContainsAll(t, expect, result.Slice(), must.Sprintf("exp: %v; got: %v", expect, result), ) - must.Eq(t, tc.expect, examined) }) } } diff --git a/drivers/shared/executor/procstats/procstats.go b/drivers/shared/executor/procstats/procstats.go index a21e97cefa6..a70f7d117cf 100644 --- a/drivers/shared/executor/procstats/procstats.go +++ b/drivers/shared/executor/procstats/procstats.go @@ -82,51 +82,40 @@ func Aggregate(systemStats *cpustats.Tracker, procStats ProcUsages) *drivers.Tas } } -func list(executorPID int, processes func() ([]ps.Process, error)) (set.Collection[ProcessID], int) { - family := set.From([]int{executorPID}) +func list(executorPID int, processes func() ([]ps.Process, error)) set.Collection[ProcessID] { + processFamily := set.From([]ProcessID{executorPID}) - all, err := processes() + allPids, err := processes() if err != nil { - return family, 0 + return processFamily } - parents, examined := mapping(all) - examined += gather(family, parents, executorPID) - - return family, examined -} - -func gather(family set.Collection[int], parents map[int]set.Collection[int], parent int) int { - examined := 0 - candidates, ok := parents[parent] - if !ok { - return examined - } - for _, candidate := range candidates.Slice() { - examined++ - family.Insert(candidate) - examined += gather(family, parents, candidate) + // A mapping of pids to their parent pids. It is used to build the process + // tree of the executing task + pidsRemaining := make(map[int]int, len(allPids)) + for _, pid := range allPids { + pidsRemaining[pid.Pid()] = pid.PPid() } - return examined -} - -// mapping builds a reverse map of parent to children -func mapping(all []ps.Process) (map[int]set.Collection[int], int) { + for { + // flag to indicate if we have found a match + foundNewPid := false - parents := map[int]set.Collection[int]{} - examined := 0 + for pid, ppid := range pidsRemaining { + childPid := processFamily.Contains(ppid) - for _, candidate := range all { - if candidate != nil { - examined++ - if children, ok := parents[candidate.PPid()]; ok { - children.Insert(candidate.Pid()) - } else { - parents[candidate.PPid()] = set.From([]int{candidate.Pid()}) + // checking if the pid is a child of any of the parents + if childPid { + processFamily.Insert(pid) + delete(pidsRemaining, pid) + foundNewPid = true } } + + if !foundNewPid { + break + } } - return parents, examined + return processFamily }