Skip to content

Commit

Permalink
nvproxy: try to propagate nvidia_poll()'s dataless_event_pending
Browse files Browse the repository at this point in the history
See new comment in nvproxy.frontendFD for context.

PiperOrigin-RevId: 656524586
  • Loading branch information
nixprime authored and gvisor-bot committed Jul 26, 2024
1 parent ed73825 commit 8db16e8
Show file tree
Hide file tree
Showing 3 changed files with 72 additions and 17 deletions.
1 change: 1 addition & 0 deletions pkg/sentry/devices/nvproxy/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,7 @@ go_library(
deps = [
"//pkg/abi/linux",
"//pkg/abi/nvgpu",
"//pkg/atomicbitops",
"//pkg/cleanup",
"//pkg/context",
"//pkg/devutil",
Expand Down
84 changes: 69 additions & 15 deletions pkg/sentry/devices/nvproxy/frontend.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,11 +16,11 @@ package nvproxy

import (
"fmt"
"sync/atomic"

"golang.org/x/sys/unix"
"gvisor.dev/gvisor/pkg/abi/linux"
"gvisor.dev/gvisor/pkg/abi/nvgpu"
"gvisor.dev/gvisor/pkg/atomicbitops"
"gvisor.dev/gvisor/pkg/cleanup"
"gvisor.dev/gvisor/pkg/context"
"gvisor.dev/gvisor/pkg/devutil"
Expand Down Expand Up @@ -76,7 +76,9 @@ func (dev *frontendDevice) Open(ctx context.Context, mnt *vfs.Mount, vfsd *vfs.D
unix.Close(hostFD)
return nil, err
}
if err := fdnotifier.AddFD(int32(hostFD), &fd.queue); err != nil {
fd.internalEntry.Init(fd, waiter.AllEvents)
fd.internalQueue.EventRegister(&fd.internalEntry)
if err := fdnotifier.AddFD(int32(hostFD), &fd.internalQueue); err != nil {
unix.Close(hostFD)
return nil, err
}
Expand All @@ -102,8 +104,31 @@ type frontendFD struct {
hostFD int32
memmapFile frontendFDMemmapFile

queue waiter.Queue
haveMmapContext atomic.Bool `state:"nosave"`
// The driver's implementation of poll() for these files,
// kernel-open/nvidia/nv.c:nvidia_poll(), unsets
// nv_linux_file_private_t::dataless_event_pending if it's set. This makes
// notifications from dataless_event_pending edge-triggered; a host poll()
// or epoll_wait() that returns the notification consumes it, preventing
// future calls to poll() or epoll_wait() from observing the same
// notification again.
//
// This is problematic in gVisor: fdnotifier, which epoll_wait()s on an
// epoll instance that includes our hostFD, will forward notifications to
// registered waiters, but this typically only wakes up blocked task
// goroutines which will later call vfs.FileDescription.Readiness() to get
// the FD's most up-to-date state. If our implementation of Readiness()
// just polls the underlying host FD, it will no longer observe the
// consumed notification.
//
// To work around this, intercept all events from fdnotifier and cache them
// for the first following call to Readiness(), essentially replicating the
// driver's behavior.
internalQueue waiter.Queue
internalEntry waiter.Entry
cachedEvents atomicbitops.Uint64
appQueue waiter.Queue

haveMmapContext atomicbitops.Bool `state:"nosave"`

// clients are handles of clients owned by this frontendFD. clients is
// protected by dev.nvp.objsMu.
Expand All @@ -113,7 +138,7 @@ type frontendFD struct {
// Release implements vfs.FileDescriptionImpl.Release.
func (fd *frontendFD) Release(ctx context.Context) {
fdnotifier.RemoveFD(fd.hostFD)
fd.queue.Notify(waiter.EventHUp)
fd.appQueue.Notify(waiter.EventHUp)

fd.dev.nvp.fdsMu.Lock()
delete(fd.dev.nvp.frontendFDs, fd)
Expand All @@ -131,25 +156,54 @@ func (fd *frontendFD) Release(ctx context.Context) {

// EventRegister implements waiter.Waitable.EventRegister.
//
// NOTE(review): this text is a commit-diff rendering whose +/- markers were
// lost. The fd.queue registration and the fdnotifier.UpdateFD rollback below
// look like the removed pre-change body, and the fd.appQueue registration
// like its replacement -- confirm against the commit before reading this as
// a single function body.
func (fd *frontendFD) EventRegister(e *waiter.Entry) error {
fd.queue.EventRegister(e)
if err := fdnotifier.UpdateFD(fd.hostFD); err != nil {
fd.queue.EventUnregister(e)
return err
}
fd.appQueue.EventRegister(e)
return nil
}

// EventUnregister implements waiter.Waitable.EventUnregister.
//
// NOTE(review): this text is a commit-diff rendering whose +/- markers were
// lost. The fd.queue unregistration and the panicking fdnotifier.UpdateFD
// call below look like the removed pre-change body, and the fd.appQueue
// unregistration like its replacement -- confirm against the commit.
func (fd *frontendFD) EventUnregister(e *waiter.Entry) {
fd.queue.EventUnregister(e)
if err := fdnotifier.UpdateFD(fd.hostFD); err != nil {
panic(fmt.Sprint("UpdateFD:", err))
}
fd.appQueue.EventUnregister(e)
}

// Readiness implements waiter.Waitable.Readiness.
//
// The loop below returns (and clears) any cached events that match mask;
// when none match, it polls the host FD for all events, caches the ones the
// caller did not ask for, and returns the rest.
//
// NOTE(review): this text is a commit-diff rendering whose +/- markers were
// lost. The one-line NonBlockingPoll return directly below looks like the
// removed pre-change body, with the loop after it as its replacement --
// confirm against the commit.
func (fd *frontendFD) Readiness(mask waiter.EventMask) waiter.EventMask {
return fdnotifier.NonBlockingPoll(fd.hostFD, mask)
for {
// Snapshot the cached events and pick out the ones the caller asked about.
cachedEvents := waiter.EventMask(fd.cachedEvents.Load())
maskedEvents := cachedEvents & mask
if maskedEvents == 0 {
// Poll for all events and cache any not consumed by this call.
events := fdnotifier.NonBlockingPoll(fd.hostFD, waiter.AllEvents)
if unmaskedEvents := events &^ mask; unmaskedEvents != 0 {
fd.cacheEvents(unmaskedEvents)
}
return events & mask
}
// Consume the matching cached events by clearing them from the cache. A
// failed CompareAndSwap falls through and retries with a fresh snapshot.
if fd.cachedEvents.CompareAndSwap(uint64(cachedEvents), uint64(cachedEvents&^maskedEvents)) {
return maskedEvents
}
}
}

// cacheEvents ORs mask into fd.cachedEvents. It retries until either the
// bits in mask are already present in the cached value or a compare-and-swap
// installing the merged value succeeds.
func (fd *frontendFD) cacheEvents(mask waiter.EventMask) {
	for {
		prev := waiter.EventMask(fd.cachedEvents.Load())
		merged := prev | mask
		// Nothing new to record, or the merged value was installed: done.
		if merged == prev || fd.cachedEvents.CompareAndSwap(uint64(prev), uint64(merged)) {
			return
		}
	}
}

// NotifyEvent implements waiter.EventListener.NotifyEvent.
//
// It records mask in fd.cachedEvents and then forwards the same mask to the
// waiters registered on fd.appQueue.
func (fd *frontendFD) NotifyEvent(mask waiter.EventMask) {
// Events must be cached before notifying fd.appQueue, in order to ensure
// that the first notified waiter to call fd.Readiness() sees the
// newly-cached events.
fd.cacheEvents(mask)
fd.appQueue.Notify(mask)
}

// Epollable implements vfs.FileDescriptionImpl.Epollable.
Expand Down
4 changes: 2 additions & 2 deletions pkg/waiter/waiter.go
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,7 @@ const (
EventInternal EventMask = 0x1000
EventRdHUp EventMask = 0x2000 // POLLRDHUP

allEvents EventMask = 0x1f | EventRdNorm | EventWrNorm | EventRdHUp
AllEvents EventMask = 0x1f | EventRdNorm | EventWrNorm | EventRdHUp
ReadableEvents EventMask = EventIn | EventRdNorm
WritableEvents EventMask = EventOut | EventWrNorm
)
Expand All @@ -86,7 +86,7 @@ const (
// from the Linux events e, which is in the format used by poll(2).
func EventMaskFromLinux(e uint32) EventMask {
// Our flag definitions are currently identical to Linux.
return EventMask(e) & allEvents
return EventMask(e) & AllEvents
}

// ToLinux returns e in the format used by Linux poll(2).
Expand Down

0 comments on commit 8db16e8

Please sign in to comment.