Skip to content

Commit

Permalink
Register TPU v5e devices when booting gVisor.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 587251672
  • Loading branch information
milantracy authored and gvisor-bot committed Dec 2, 2023
1 parent 8042c6f commit 126ee58
Show file tree
Hide file tree
Showing 2 changed files with 86 additions and 25 deletions.
64 changes: 39 additions & 25 deletions runsc/boot/vfs.go
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,11 @@ const (
// SelfFilestorePrefix is the prefix of the self filestore file name.
const SelfFilestorePrefix = ".gvisor.filestore."

const (
pciPathGlobTPUv4 = "/sys/devices/pci0000:00/*/accel/accel*"
pciPathGlobTPUv5 = "/sys/devices/pci0000:00/*/vfio-dev/vfio*"
)

// SelfFilestorePath returns the path at which the self filestore file is
// stored for a given mount.
func SelfFilestorePath(mountSrc, sandboxID string) string {
Expand Down Expand Up @@ -1250,37 +1255,46 @@ func registerTPUDevice(vfsObj *vfs.VirtualFilesystem, minor uint32, deviceID int
}
}

// pathGlobToPathRegex is a map that points a TPU PCI path glob to its path regex.
// TPU v4 devices are accessible via /sys/devices/pci0000:00/<pci_address>/accel/accel# on the host.
// TPU v5 devices are accessible via at /sys/devices/pci0000:00/<pci_address>/vfio-dev/vfio# on the host.
var pathGlobToPathRegex = map[string]string{
pciPathGlobTPUv4: `^/sys/devices/pci0000:00/\d+:\d+:\d+\.\d+/accel/accel(\d+)$`,
pciPathGlobTPUv5: `^/sys/devices/pci0000:00/\d+:\d+:\d+\.\d+/vfio-dev/vfio(\d+)$`,
}

func tpuProxyRegisterDevices(info *containerInfo, vfsObj *vfs.VirtualFilesystem) error {
if !specutils.TPUProxyIsEnabled(info.spec, info.conf) {
return nil
}
// At this point /sys/devices/pci0000:00/<pci_address>/accel/accel# contains
// all the TPU devices on the host. Enumerate them and register TPU devices.
pciAddrs, err := filepath.Glob("/sys/devices/pci0000:00/*/accel/accel*")
if err != nil {
return fmt.Errorf("enumerating PCI device files: %w", err)
}
pciPathRegex := regexp.MustCompile(`^/sys/devices/pci0000:00/\d+:\d+:\d+\.\d+/accel/accel(\d+)$`)
for _, pciPath := range pciAddrs {
ms := pciPathRegex.FindStringSubmatch(pciPath)
if ms == nil {
continue
}
deviceNum, err := strconv.ParseUint(ms[1], 10, 32)
// Enumerate all potential PCI paths where TPU devices are available and register the found TPU devices.
for pciPathGlobal, pathRegex := range pathGlobToPathRegex {
pciAddrs, err := filepath.Glob(pciPathGlobal)
if err != nil {
return fmt.Errorf("parsing PCI device number: %w", err)
return fmt.Errorf("enumerating PCI device files: %w", err)
}
var deviceIDBytes []byte
if deviceIDBytes, err = os.ReadFile(path.Join(pciPath, "device/device")); err != nil {
return fmt.Errorf("reading PCI device ID: %w", err)
}
deviceIDStr := strings.Replace(string(deviceIDBytes), "0x", "", -1)
deviceID, err := strconv.ParseInt(strings.TrimSpace(deviceIDStr), 16, 64)
if err != nil {
return fmt.Errorf("parsing PCI device ID: %w", err)
}
if err := registerTPUDevice(vfsObj, uint32(deviceNum), deviceID); err != nil {
return fmt.Errorf("registering accel driver: %w", err)
pciPathRegex := regexp.MustCompile(pathRegex)
for _, pciPath := range pciAddrs {
ms := pciPathRegex.FindStringSubmatch(pciPath)
if ms == nil {
continue
}
deviceNum, err := strconv.ParseUint(ms[1], 10, 32)
if err != nil {
return fmt.Errorf("parsing PCI device number: %w", err)
}
var deviceIDBytes []byte
if deviceIDBytes, err = os.ReadFile(path.Join(pciPath, "device/device")); err != nil {
return fmt.Errorf("reading PCI device ID: %w", err)
}
deviceIDStr := strings.Replace(string(deviceIDBytes), "0x", "", -1)
deviceID, err := strconv.ParseInt(strings.TrimSpace(deviceIDStr), 16, 64)
if err != nil {
return fmt.Errorf("parsing PCI device ID: %w", err)
}
if err := registerTPUDevice(vfsObj, uint32(deviceNum), deviceID); err != nil {
return fmt.Errorf("registering TPU driver: %w", err)
}
}
}
return nil
Expand Down
47 changes: 47 additions & 0 deletions runsc/boot/vfs_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,9 @@
package boot

import (
"path/filepath"
"regexp"
"slices"
"testing"

specs "github.com/opencontainers/runtime-spec/specs-go"
Expand Down Expand Up @@ -96,3 +99,47 @@ func TestGetMountAccessType(t *testing.T) {
})
}
}

func TestTPUPath(t *testing.T) {
for _, tst := range []struct {
name string
pathGlob string
path string
submatch []string
}{
{
name: "TPUv4PCIPathMatch",
pathGlob: pciPathGlobTPUv4,
path: "/sys/devices/pci0000:00/0000:00:01.0/accel/accel16",
submatch: []string{"/sys/devices/pci0000:00/0000:00:01.0/accel/accel16", "16"},
},
{
name: "TPUv4PCIPathNoMatch",
pathGlob: pciPathGlobTPUv4,
path: "/sys/devices/pci0000:00/0000:00:01.0/accel/123",
submatch: nil,
},
{
name: "TPUv5PCIPathMatch",
pathGlob: pciPathGlobTPUv5,
path: "/sys/devices/pci0000:00/0000:00:05.0/vfio-dev/vfio20",
submatch: []string{"/sys/devices/pci0000:00/0000:00:05.0/vfio-dev/vfio20", "20"},
},
{
name: "TPUv5PCIPathNoMatch",
pathGlob: pciPathGlobTPUv5,
path: "/sys/devices/pci0000:00/0000:00:05.0/vfio/vfio20",
submatch: nil,
},
} {
t.Run(tst.name, func(t *testing.T) {
if _, err := filepath.Glob(tst.pathGlob); err != nil {
t.Errorf("Malformed path glob: %v", err)
}
pathRegex := regexp.MustCompile(pathGlobToPathRegex[tst.pathGlob])
if submatch := pathRegex.FindStringSubmatch(tst.path); !slices.Equal(submatch, tst.submatch) {
t.Errorf("Match TPU PCI path, got: %v, want: %v", submatch, tst.submatch)
}
})
}
}

0 comments on commit 126ee58

Please sign in to comment.