From 126ee58746d10ae5064e51dec96cf580b8116875 Mon Sep 17 00:00:00 2001 From: Jing Chen Date: Sat, 2 Dec 2023 00:01:59 -0800 Subject: [PATCH] Register TPU v5e devices when booting gVisor. PiperOrigin-RevId: 587251672 --- runsc/boot/vfs.go | 64 +++++++++++++++++++++++++----------------- runsc/boot/vfs_test.go | 47 +++++++++++++++++++++++++++++++ 2 files changed, 86 insertions(+), 25 deletions(-) diff --git a/runsc/boot/vfs.go b/runsc/boot/vfs.go index da33a97990..6a9080043d 100644 --- a/runsc/boot/vfs.go +++ b/runsc/boot/vfs.go @@ -71,6 +71,11 @@ const ( // SelfFilestorePrefix is the prefix of the self filestore file name. const SelfFilestorePrefix = ".gvisor.filestore." +const ( + pciPathGlobTPUv4 = "/sys/devices/pci0000:00/*/accel/accel*" + pciPathGlobTPUv5 = "/sys/devices/pci0000:00/*/vfio-dev/vfio*" +) + // SelfFilestorePath returns the path at which the self filestore file is // stored for a given mount. func SelfFilestorePath(mountSrc, sandboxID string) string { @@ -1250,37 +1255,46 @@ func registerTPUDevice(vfsObj *vfs.VirtualFilesystem, minor uint32, deviceID int } } +// pathGlobToPathRegex is a map that points a TPU PCI path glob to its path regex. +// TPU v4 devices are accessible via /sys/devices/pci0000:00//accel/accel# on the host. +// TPU v5 devices are accessible via at /sys/devices/pci0000:00//vfio-dev/vfio# on the host. +var pathGlobToPathRegex = map[string]string{ + pciPathGlobTPUv4: `^/sys/devices/pci0000:00/\d+:\d+:\d+\.\d+/accel/accel(\d+)$`, + pciPathGlobTPUv5: `^/sys/devices/pci0000:00/\d+:\d+:\d+\.\d+/vfio-dev/vfio(\d+)$`, +} + func tpuProxyRegisterDevices(info *containerInfo, vfsObj *vfs.VirtualFilesystem) error { if !specutils.TPUProxyIsEnabled(info.spec, info.conf) { return nil } - // At this point /sys/devices/pci0000:00//accel/accel# contains - // all the TPU devices on the host. Enumerate them and register TPU devices. - pciAddrs, err := filepath.Glob("/sys/devices/pci0000:00/*/accel/accel*") - if err != nil { - return fmt.Errorf("enumerating PCI device files: %w", err) - } - pciPathRegex := regexp.MustCompile(`^/sys/devices/pci0000:00/\d+:\d+:\d+\.\d+/accel/accel(\d+)$`) - for _, pciPath := range pciAddrs { - ms := pciPathRegex.FindStringSubmatch(pciPath) - if ms == nil { - continue - } - deviceNum, err := strconv.ParseUint(ms[1], 10, 32) + // Enumerate all potential PCI paths where TPU devices are available and register the found TPU devices. + for pciPathGlobal, pathRegex := range pathGlobToPathRegex { + pciAddrs, err := filepath.Glob(pciPathGlobal) if err != nil { - return fmt.Errorf("parsing PCI device number: %w", err) + return fmt.Errorf("enumerating PCI device files: %w", err) } - var deviceIDBytes []byte - if deviceIDBytes, err = os.ReadFile(path.Join(pciPath, "device/device")); err != nil { - return fmt.Errorf("reading PCI device ID: %w", err) - } - deviceIDStr := strings.Replace(string(deviceIDBytes), "0x", "", -1) - deviceID, err := strconv.ParseInt(strings.TrimSpace(deviceIDStr), 16, 64) - if err != nil { - return fmt.Errorf("parsing PCI device ID: %w", err) - } - if err := registerTPUDevice(vfsObj, uint32(deviceNum), deviceID); err != nil { - return fmt.Errorf("registering accel driver: %w", err) + pciPathRegex := regexp.MustCompile(pathRegex) + for _, pciPath := range pciAddrs { + ms := pciPathRegex.FindStringSubmatch(pciPath) + if ms == nil { + continue + } + deviceNum, err := strconv.ParseUint(ms[1], 10, 32) + if err != nil { + return fmt.Errorf("parsing PCI device number: %w", err) + } + var deviceIDBytes []byte + if deviceIDBytes, err = os.ReadFile(path.Join(pciPath, "device/device")); err != nil { + return fmt.Errorf("reading PCI device ID: %w", err) + } + deviceIDStr := strings.Replace(string(deviceIDBytes), "0x", "", -1) + deviceID, err := strconv.ParseInt(strings.TrimSpace(deviceIDStr), 16, 64) + if err != nil { + return fmt.Errorf("parsing PCI device ID: %w", err) + } + if err := registerTPUDevice(vfsObj, uint32(deviceNum), deviceID); err != nil { + return fmt.Errorf("registering TPU driver: %w", err) + } } } return nil diff --git a/runsc/boot/vfs_test.go b/runsc/boot/vfs_test.go index 85380c489f..570b4b2991 100644 --- a/runsc/boot/vfs_test.go +++ b/runsc/boot/vfs_test.go @@ -15,6 +15,9 @@ package boot import ( + "path/filepath" + "regexp" + "slices" "testing" specs "github.com/opencontainers/runtime-spec/specs-go" @@ -96,3 +99,47 @@ func TestGetMountAccessType(t *testing.T) { }) } } + +func TestTPUPath(t *testing.T) { + for _, tst := range []struct { + name string + pathGlob string + path string + submatch []string + }{ + { + name: "TPUv4PCIPathMatch", + pathGlob: pciPathGlobTPUv4, + path: "/sys/devices/pci0000:00/0000:00:01.0/accel/accel16", + submatch: []string{"/sys/devices/pci0000:00/0000:00:01.0/accel/accel16", "16"}, + }, + { + name: "TPUv4PCIPathNoMatch", + pathGlob: pciPathGlobTPUv4, + path: "/sys/devices/pci0000:00/0000:00:01.0/accel/123", + submatch: nil, + }, + { + name: "TPUv5PCIPathMatch", + pathGlob: pciPathGlobTPUv5, + path: "/sys/devices/pci0000:00/0000:00:05.0/vfio-dev/vfio20", + submatch: []string{"/sys/devices/pci0000:00/0000:00:05.0/vfio-dev/vfio20", "20"}, + }, + { + name: "TPUv5PCIPathNoMatch", + pathGlob: pciPathGlobTPUv5, + path: "/sys/devices/pci0000:00/0000:00:05.0/vfio/vfio20", + submatch: nil, + }, + } { + t.Run(tst.name, func(t *testing.T) { + if _, err := filepath.Glob(tst.pathGlob); err != nil { + t.Errorf("Malformed path glob: %v", err) + } + pathRegex := regexp.MustCompile(pathGlobToPathRegex[tst.pathGlob]) + if submatch := pathRegex.FindStringSubmatch(tst.path); !slices.Equal(submatch, tst.submatch) { + t.Errorf("Match TPU PCI path, got: %v, want: %v", submatch, tst.submatch) + } + }) + } +}