From d696151868f48f74277ed0e776d170fff7e6bac0 Mon Sep 17 00:00:00 2001
From: Eric Long <i@hack3r.moe>
Date: Sun, 14 Apr 2024 10:40:30 +0800
Subject: [PATCH] feat: Use type parameters to allow `{get,set}regset` to use
 different register set structs

---
 src/sys/ptrace/linux.rs | 111 ++++++++++++++++++++++++++++++----------
 test/sys/test_ptrace.rs |  33 +++++++-----
 2 files changed, 106 insertions(+), 38 deletions(-)

diff --git a/src/sys/ptrace/linux.rs b/src/sys/ptrace/linux.rs
index c36bf05197..ddd0a71fd0 100644
--- a/src/sys/ptrace/linux.rs
+++ b/src/sys/ptrace/linux.rs
@@ -172,21 +172,21 @@ libc_enum! {
     }
 }
 
+#[cfg(all(
+    target_os = "linux",
+    target_env = "gnu",
+    any(
+        target_arch = "x86_64",
+        target_arch = "x86",
+        target_arch = "aarch64",
+        target_arch = "riscv64",
+    )
+))]
 libc_enum! {
-    #[cfg(all(
-        target_os = "linux",
-        target_env = "gnu",
-        any(
-            target_arch = "x86_64",
-            target_arch = "x86",
-            target_arch = "aarch64",
-            target_arch = "riscv64",
-        )
-    ))]
     #[repr(i32)]
-    /// Defining a specific register set, as used in [`getregset`] and [`setregset`].
+    /// Defines a specific register set, as used in `PTRACE_GETREGSET` and `PTRACE_SETREGSET`.
     #[non_exhaustive]
-    pub enum RegisterSet {
+    pub enum RegisterSetValue {
         NT_PRSTATUS,
         NT_PRFPREG,
         NT_PRPSINFO,
@@ -195,6 +195,69 @@ libc_enum! {
     }
 }
 
+#[cfg(all(
+    target_os = "linux",
+    target_env = "gnu",
+    any(
+        target_arch = "x86_64",
+        target_arch = "x86",
+        target_arch = "aarch64",
+        target_arch = "riscv64",
+    )
+))]
+/// Represents register set areas, such as general-purpose registers or
+/// floating-point registers.
+///
+/// # Safety
+///
+/// This trait is marked unsafe, since implementation of the trait must match
+/// ptrace's request `VALUE` and return data type `Regs`.
+pub unsafe trait RegisterSet {
+    /// Corresponding type of registers in the kernel.
+    const VALUE: RegisterSetValue;
+
+    /// Struct representing the register space.
+    type Regs;
+}
+
+#[cfg(all(
+    target_os = "linux",
+    target_env = "gnu",
+    any(
+        target_arch = "x86_64",
+        target_arch = "x86",
+        target_arch = "aarch64",
+        target_arch = "riscv64",
+    )
+))]
+/// Register sets used in [`getregset`] and [`setregset`]
+pub mod regset {
+    use super::*;
+
+    #[derive(Debug, Clone, Copy)]
+    /// General-purpose registers.
+    pub struct NT_PRSTATUS;
+
+    unsafe impl RegisterSet for NT_PRSTATUS {
+        const VALUE: RegisterSetValue = RegisterSetValue::NT_PRSTATUS;
+        type Regs = user_regs_struct;
+    }
+
+    #[derive(Debug, Clone, Copy)]
+    /// Floating-point registers.
+    pub struct NT_PRFPREG;
+
+    unsafe impl RegisterSet for NT_PRFPREG {
+        const VALUE: RegisterSetValue = RegisterSetValue::NT_PRFPREG;
+        #[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
+        type Regs = libc::user_fpregs_struct;
+        #[cfg(target_arch = "aarch64")]
+        type Regs = libc::user_fpsimd_struct;
+        #[cfg(target_arch = "riscv64")]
+        type Regs = libc::__riscv_mc_d_ext_state;
+    }
+}
+
 libc_bitflags! {
     /// Ptrace options used in conjunction with the PTRACE_SETOPTIONS request.
     /// See `man ptrace` for more details.
@@ -275,7 +338,7 @@ pub fn getregs(pid: Pid) -> Result<user_regs_struct> {
     any(target_arch = "aarch64", target_arch = "riscv64")
 ))]
 pub fn getregs(pid: Pid) -> Result<user_regs_struct> {
-    getregset(pid, RegisterSet::NT_PRSTATUS)
+    getregset::<regset::NT_PRSTATUS>(pid)
 }
 
 /// Get a particular set of user registers, as with `ptrace(PTRACE_GETREGSET, ...)`
@@ -289,18 +352,18 @@ pub fn getregs(pid: Pid) -> Result<user_regs_struct> {
         target_arch = "riscv64",
     )
 ))]
-pub fn getregset(pid: Pid, set: RegisterSet) -> Result<user_regs_struct> {
+pub fn getregset<S: RegisterSet>(pid: Pid) -> Result<S::Regs> {
     let request = Request::PTRACE_GETREGSET;
-    let mut data = mem::MaybeUninit::<user_regs_struct>::uninit();
+    let mut data = mem::MaybeUninit::<S::Regs>::uninit();
     let mut iov = libc::iovec {
         iov_base: data.as_mut_ptr().cast(),
-        iov_len: mem::size_of::<user_regs_struct>(),
+        iov_len: mem::size_of::<S::Regs>(),
     };
     unsafe {
         ptrace_other(
             request,
             pid,
-            set as i32 as AddressType,
+            S::VALUE as i32 as AddressType,
             (&mut iov as *mut libc::iovec).cast(),
         )?;
     };
@@ -349,7 +412,7 @@ pub fn setregs(pid: Pid, regs: user_regs_struct) -> Result<()> {
     any(target_arch = "aarch64", target_arch = "riscv64")
 ))]
 pub fn setregs(pid: Pid, regs: user_regs_struct) -> Result<()> {
-    setregset(pid, RegisterSet::NT_PRSTATUS, regs)
+    setregset::<regset::NT_PRSTATUS>(pid, regs)
 }
 
 /// Set a particular set of user registers, as with `ptrace(PTRACE_SETREGSET, ...)`
@@ -363,20 +426,16 @@ pub fn setregs(pid: Pid, regs: user_regs_struct) -> Result<()> {
         target_arch = "riscv64",
     )
 ))]
-pub fn setregset(
-    pid: Pid,
-    set: RegisterSet,
-    mut regs: user_regs_struct,
-) -> Result<()> {
+pub fn setregset<S: RegisterSet>(pid: Pid, mut regs: S::Regs) -> Result<()> {
     let mut iov = libc::iovec {
-        iov_base: (&mut regs as *mut user_regs_struct).cast(),
-        iov_len: mem::size_of::<user_regs_struct>(),
+        iov_base: (&mut regs as *mut S::Regs).cast(),
+        iov_len: mem::size_of::<S::Regs>(),
     };
     unsafe {
         ptrace_other(
             Request::PTRACE_SETREGSET,
             pid,
-            set as i32 as AddressType,
+            S::VALUE as i32 as AddressType,
             (&mut iov as *mut libc::iovec).cast(),
         )?;
     }
diff --git a/test/sys/test_ptrace.rs b/test/sys/test_ptrace.rs
index 5eb7e249f3..c99c6762c3 100644
--- a/test/sys/test_ptrace.rs
+++ b/test/sys/test_ptrace.rs
@@ -302,7 +302,7 @@ fn test_ptrace_syscall() {
 ))]
 #[test]
 fn test_ptrace_regsets() {
-    use nix::sys::ptrace::{self, getregset, setregset, RegisterSet};
+    use nix::sys::ptrace::{self, getregset, regset, setregset};
     use nix::sys::signal::*;
     use nix::sys::wait::{waitpid, WaitStatus};
     use nix::unistd::fork;
@@ -328,30 +328,39 @@ fn test_ptrace_regsets() {
                 Ok(WaitStatus::Stopped(child, Signal::SIGTRAP))
             );
             let mut regstruct =
-                getregset(child, RegisterSet::NT_PRSTATUS).unwrap();
+                getregset::<regset::NT_PRSTATUS>(child).unwrap();
+            let mut fpregstruct =
+                getregset::<regset::NT_PRFPREG>(child).unwrap();
 
             #[cfg(target_arch = "x86_64")]
-            let reg = &mut regstruct.r15;
+            let (reg, fpreg) =
+                (&mut regstruct.r15, &mut fpregstruct.st_space[5]);
             #[cfg(target_arch = "x86")]
-            let reg = &mut regstruct.edx;
+            let (reg, fpreg) =
+                (&mut regstruct.edx, &mut fpregstruct.st_space[5]);
             #[cfg(target_arch = "aarch64")]
-            let reg = &mut regstruct.regs[16];
+            let (reg, fpreg) =
+                (&mut regstruct.regs[16], &mut fpregstruct.vregs[5]);
             #[cfg(target_arch = "riscv64")]
-            let reg = &mut regstruct.regs[16];
+            let (reg, fpreg) = (&mut regstruct.t1, &mut fpregstruct.__f[5]);
 
             *reg = 0xdeadbeefu32 as _;
-            let _ = setregset(child, RegisterSet::NT_PRSTATUS, regstruct);
-            regstruct = getregset(child, RegisterSet::NT_PRSTATUS).unwrap();
+            *fpreg = 0xfeedfaceu32 as _;
+            let _ = setregset::<regset::NT_PRSTATUS>(child, regstruct);
+            regstruct = getregset::<regset::NT_PRSTATUS>(child).unwrap();
+            let _ = setregset::<regset::NT_PRFPREG>(child, fpregstruct);
+            fpregstruct = getregset::<regset::NT_PRFPREG>(child).unwrap();
 
             #[cfg(target_arch = "x86_64")]
-            let reg = regstruct.r15;
+            let (reg, fpreg) = (regstruct.r15, fpregstruct.st_space[5]);
             #[cfg(target_arch = "x86")]
-            let reg = regstruct.edx;
+            let (reg, fpreg) = (regstruct.edx, fpregstruct.st_space[5]);
             #[cfg(target_arch = "aarch64")]
-            let reg = regstruct.regs[16];
+            let (reg, fpreg) = (regstruct.regs[16], fpregstruct.vregs[5]);
             #[cfg(target_arch = "riscv64")]
-            let reg = regstruct.regs[16];
+            let (reg, fpreg) = (regstruct.t1, fpregstruct.__f[5]);
             assert_eq!(reg, 0xdeadbeefu32 as _);
+            assert_eq!(fpreg, 0xfeedfaceu32 as _);
 
             ptrace::cont(child, Some(Signal::SIGKILL)).unwrap();
             match waitpid(child, None) {