diff --git a/pkg/sentry/kernel/kernel.go b/pkg/sentry/kernel/kernel.go
index 5494726293..752367adac 100644
--- a/pkg/sentry/kernel/kernel.go
+++ b/pkg/sentry/kernel/kernel.go
@@ -36,6 +36,7 @@ import (
 	"errors"
 	"fmt"
 	"io"
+	"math"
 	"path/filepath"
 	"time"
 
@@ -1439,12 +1440,26 @@ func (k *Kernel) decRunningTasks() {
 // WaitExited blocks until all tasks in k have exited. No tasks can be created
 // after WaitExited returns.
 func (k *Kernel) WaitExited() {
-	k.tasks.mu.Lock()
-	defer k.tasks.mu.Unlock()
-	k.tasks.noNewTasksIfZeroLive = true
-	for k.tasks.liveTasks != 0 {
-		k.tasks.zeroLiveTasksCond.Wait()
+	// Ensure that the most significant bit of k.tasks.liveTasks is unset,
+	// preventing k.tasks.newTask() from transitioning k.tasks.liveTasks out of
+	// 0.
+	for {
+		liveTasks := k.tasks.liveTasks.Load()
+		if liveTasks == 0 {
+			return
+		}
+		if liveTasks > 0 {
+			break
+		}
+		newLiveTasks := liveTasks &^ math.MinInt64
+		if k.tasks.liveTasks.CompareAndSwap(liveTasks, newLiveTasks) {
+			if newLiveTasks == 0 {
+				close(k.tasks.zeroLiveTasksC)
+			}
+			break
+		}
 	}
+	<-k.tasks.zeroLiveTasksC
 }
 
 // Kill requests that all tasks in k immediately exit as if group exiting with
diff --git a/pkg/sentry/kernel/kernel_state.go b/pkg/sentry/kernel/kernel_state.go
index a47d164e93..a7b781dc20 100644
--- a/pkg/sentry/kernel/kernel_state.go
+++ b/pkg/sentry/kernel/kernel_state.go
@@ -16,13 +16,27 @@ package kernel
 
 import (
 	"context"
+	"math"
 
 	"gvisor.dev/gvisor/pkg/tcpip"
 )
 
+// saveLiveTasks is invoked by stateify.
+func (ts *TaskSet) saveLiveTasks() int64 {
+	// The MSB, which is cleared by Kernel.WaitExited(), is never saved and is
+	// always set again after restore, since whether Kernel.WaitExited() was
+	// called before checkpointing is not intended to apply after restore.
+	return ts.liveTasks.Load() &^ math.MinInt64
+}
+
+// loadLiveTasks is invoked by stateify.
+func (ts *TaskSet) loadLiveTasks(_ context.Context, liveTasks int64) {
+	ts.liveTasks.Store(liveTasks | math.MinInt64)
+}
+
 // afterLoad is invoked by stateify.
 func (ts *TaskSet) afterLoad(_ context.Context) {
-	ts.zeroLiveTasksCond.L = &ts.mu
+	ts.zeroLiveTasksC = make(chan struct{}, 0)
 }
 
 // saveDanglingEndpoints is invoked by stateify.
diff --git a/pkg/sentry/kernel/task_run.go b/pkg/sentry/kernel/task_run.go
index 10b616f025..255d2042d7 100644
--- a/pkg/sentry/kernel/task_run.go
+++ b/pkg/sentry/kernel/task_run.go
@@ -103,12 +103,9 @@ func (t *Task) run(threadID uintptr) {
 			t.p.Release()
 
 			ts := t.tg.pidns.owner
-			ts.mu.Lock()
-			ts.liveTasks--
-			if ts.liveTasks == 0 {
-				ts.zeroLiveTasksCond.Broadcast()
+			if ts.liveTasks.Add(-1) == 0 {
+				close(ts.zeroLiveTasksC)
 			}
-			ts.mu.Unlock()
 			ts.runningGoroutines.Done()
 
 			// Deferring this store triggers a false positive in the race
diff --git a/pkg/sentry/kernel/task_start.go b/pkg/sentry/kernel/task_start.go
index ded77133f7..7d378240a4 100644
--- a/pkg/sentry/kernel/task_start.go
+++ b/pkg/sentry/kernel/task_start.go
@@ -226,20 +226,24 @@ func (ts *TaskSet) newTask(ctx context.Context, cfg *TaskConfig) (*Task, error)
 		// we're in uncharted territory and can return whatever we want.
 		return nil, linuxerr.EINTR
 	}
-	if ts.liveTasks == 0 && ts.noNewTasksIfZeroLive {
+	if ts.liveTasks.Add(1) == 1 {
+		// ts.mu is locked, so this can't race with concurrent calls to
+		// ts.newTask().
+		ts.liveTasks.Add(-1)
 		// Since liveTasks == 0, our caller cannot be a task goroutine invoking
 		// a syscall, so it's safe to return a non-errno error that is more
 		// explanatory.
 		return nil, fmt.Errorf("task creation disabled after Kernel.WaitExited() may have returned")
 	}
 	if err := ts.assignTIDsLocked(t); err != nil {
+		if ts.liveTasks.Add(-1) == 0 {
+			close(ts.zeroLiveTasksC)
+		}
 		return nil, err
 	}
 	// Below this point, newTask is expected not to fail (there is no rollback
 	// of assignTIDsLocked or any of the following).
 
-	ts.liveTasks++
-
 	// Logging on t's behalf will panic if t.logPrefix hasn't been
 	// initialized. This is the earliest point at which we can do so
 	// (since t now has thread IDs).
diff --git a/pkg/sentry/kernel/threads.go b/pkg/sentry/kernel/threads.go
index 6c7d7f7829..3f195b1b3d 100644
--- a/pkg/sentry/kernel/threads.go
+++ b/pkg/sentry/kernel/threads.go
@@ -16,6 +16,7 @@ package kernel
 
 import (
 	"fmt"
+	"math"
 
 	"gvisor.dev/gvisor/pkg/atomicbitops"
 	"gvisor.dev/gvisor/pkg/sentry/kernel/auth"
@@ -77,16 +78,12 @@ type TaskSet struct {
 	stopCount int32 `state:"nosave"`
 
 	// liveTasks is the number of tasks in the TaskSet whose goroutines have
-	// not exited. liveTasks is protected by mu.
-	liveTasks uint32
+	// not exited.
+	liveTasks atomicbitops.Int64 `state:".(int64)"`
 
-	// If noNewTasksIfZeroLive is true and liveTasks is zero, calls to
-	// Kernel.NewTask() will fail. noNewTasksIfZeroLive is protected by mu.
-	noNewTasksIfZeroLive bool
-
-	// zeroLiveTasksCond is broadcast when liveTasks transitions from non-zero
-	// to zero.
-	zeroLiveTasksCond sync.Cond `state:"nosave"`
+	// zeroLiveTasksC is closed when liveTasks transitions from non-zero to
+	// zero.
+	zeroLiveTasksC chan struct{} `state:"manual"`
 
 	// runningGoroutines is the number of running task goroutines in the
 	// TaskSet.
@@ -106,8 +103,14 @@
 
 // newTaskSet returns a new, empty TaskSet.
 func newTaskSet(pidns *PIDNamespace) *TaskSet {
-	ts := &TaskSet{Root: pidns}
-	ts.zeroLiveTasksCond.L = &ts.mu
+	ts := &TaskSet{
+		Root:           pidns,
+		zeroLiveTasksC: make(chan struct{}, 0),
+	}
+	// Set the most significant bit of ts.liveTasks to make it non-zero, and
+	// therefore allow TaskSet.newTask() to proceed even if there are no live
+	// tasks; it will be cleared by Kernel.WaitExited().
+	ts.liveTasks.Store(math.MinInt64)
 	pidns.owner = ts
 	return ts
 }