Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: Use type parameters to allow {get,set}regset to use different register set structs #2373

Merged
merged 1 commit into from
Apr 25, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
111 changes: 85 additions & 26 deletions src/sys/ptrace/linux.rs
Original file line number Diff line number Diff line change
Expand Up @@ -172,21 +172,21 @@ libc_enum! {
}
}

#[cfg(all(
target_os = "linux",
target_env = "gnu",
any(
target_arch = "x86_64",
target_arch = "x86",
target_arch = "aarch64",
target_arch = "riscv64",
)
))]
libc_enum! {
#[cfg(all(
target_os = "linux",
target_env = "gnu",
any(
target_arch = "x86_64",
target_arch = "x86",
target_arch = "aarch64",
target_arch = "riscv64",
)
))]
#[repr(i32)]
/// Defining a specific register set, as used in [`getregset`] and [`setregset`].
/// Defines a specific register set, as used in `PTRACE_GETREGSET` and `PTRACE_SETREGSET`.
#[non_exhaustive]
pub enum RegisterSet {
pub enum RegisterSetValue {
NT_PRSTATUS,
NT_PRFPREG,
NT_PRPSINFO,
Expand All @@ -195,6 +195,69 @@ libc_enum! {
}
}

#[cfg(all(
target_os = "linux",
target_env = "gnu",
any(
target_arch = "x86_64",
target_arch = "x86",
target_arch = "aarch64",
target_arch = "riscv64",
)
))]
/// Represents register set areas, such as general-purpose registers or
/// floating-point registers.
///
/// # Safety
///
/// This trait is marked unsafe, since implementation of the trait must match
/// ptrace's request `VALUE` and return data type `Regs`.
pub unsafe trait RegisterSet {
/// Corresponding type of registers in the kernel.
const VALUE: RegisterSetValue;

/// Struct representing the register space.
type Regs;
}

#[cfg(all(
target_os = "linux",
target_env = "gnu",
any(
target_arch = "x86_64",
target_arch = "x86",
target_arch = "aarch64",
target_arch = "riscv64",
)
))]
/// Register sets used in [`getregset`] and [`setregset`]
pub mod regset {
use super::*;

#[derive(Debug, Clone, Copy)]
/// General-purpose registers.
pub struct NT_PRSTATUS;

unsafe impl RegisterSet for NT_PRSTATUS {
const VALUE: RegisterSetValue = RegisterSetValue::NT_PRSTATUS;
type Regs = user_regs_struct;
}

#[derive(Debug, Clone, Copy)]
/// Floating-point registers.
pub struct NT_PRFPREG;

unsafe impl RegisterSet for NT_PRFPREG {
const VALUE: RegisterSetValue = RegisterSetValue::NT_PRFPREG;
#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
type Regs = libc::user_fpregs_struct;
#[cfg(target_arch = "aarch64")]
type Regs = libc::user_fpsimd_struct;
#[cfg(target_arch = "riscv64")]
type Regs = libc::__riscv_mc_d_ext_state;
}
}

libc_bitflags! {
/// Ptrace options used in conjunction with the PTRACE_SETOPTIONS request.
/// See `man ptrace` for more details.
Expand Down Expand Up @@ -275,7 +338,7 @@ pub fn getregs(pid: Pid) -> Result<user_regs_struct> {
any(target_arch = "aarch64", target_arch = "riscv64")
))]
pub fn getregs(pid: Pid) -> Result<user_regs_struct> {
getregset(pid, RegisterSet::NT_PRSTATUS)
getregset::<regset::NT_PRSTATUS>(pid)
}

/// Get a particular set of user registers, as with `ptrace(PTRACE_GETREGSET, ...)`
Expand All @@ -289,18 +352,18 @@ pub fn getregs(pid: Pid) -> Result<user_regs_struct> {
target_arch = "riscv64",
)
))]
pub fn getregset(pid: Pid, set: RegisterSet) -> Result<user_regs_struct> {
pub fn getregset<S: RegisterSet>(pid: Pid) -> Result<S::Regs> {
let request = Request::PTRACE_GETREGSET;
let mut data = mem::MaybeUninit::<user_regs_struct>::uninit();
let mut data = mem::MaybeUninit::<S::Regs>::uninit();
let mut iov = libc::iovec {
iov_base: data.as_mut_ptr().cast(),
iov_len: mem::size_of::<user_regs_struct>(),
iov_len: mem::size_of::<S::Regs>(),
};
unsafe {
ptrace_other(
request,
pid,
set as i32 as AddressType,
S::VALUE as i32 as AddressType,
(&mut iov as *mut libc::iovec).cast(),
)?;
};
Expand Down Expand Up @@ -349,7 +412,7 @@ pub fn setregs(pid: Pid, regs: user_regs_struct) -> Result<()> {
any(target_arch = "aarch64", target_arch = "riscv64")
))]
pub fn setregs(pid: Pid, regs: user_regs_struct) -> Result<()> {
setregset(pid, RegisterSet::NT_PRSTATUS, regs)
setregset::<regset::NT_PRSTATUS>(pid, regs)
}

/// Set a particular set of user registers, as with `ptrace(PTRACE_SETREGSET, ...)`
Expand All @@ -363,20 +426,16 @@ pub fn setregs(pid: Pid, regs: user_regs_struct) -> Result<()> {
target_arch = "riscv64",
)
))]
pub fn setregset(
pid: Pid,
set: RegisterSet,
mut regs: user_regs_struct,
) -> Result<()> {
pub fn setregset<S: RegisterSet>(pid: Pid, mut regs: S::Regs) -> Result<()> {
let mut iov = libc::iovec {
iov_base: (&mut regs as *mut user_regs_struct).cast(),
iov_len: mem::size_of::<user_regs_struct>(),
iov_base: (&mut regs as *mut S::Regs).cast(),
iov_len: mem::size_of::<S::Regs>(),
};
unsafe {
ptrace_other(
Request::PTRACE_SETREGSET,
pid,
set as i32 as AddressType,
S::VALUE as i32 as AddressType,
(&mut iov as *mut libc::iovec).cast(),
)?;
}
Expand Down
33 changes: 21 additions & 12 deletions test/sys/test_ptrace.rs
Original file line number Diff line number Diff line change
Expand Up @@ -302,7 +302,7 @@ fn test_ptrace_syscall() {
))]
#[test]
fn test_ptrace_regsets() {
use nix::sys::ptrace::{self, getregset, setregset, RegisterSet};
use nix::sys::ptrace::{self, getregset, regset, setregset};
use nix::sys::signal::*;
use nix::sys::wait::{waitpid, WaitStatus};
use nix::unistd::fork;
Expand All @@ -328,30 +328,39 @@ fn test_ptrace_regsets() {
Ok(WaitStatus::Stopped(child, Signal::SIGTRAP))
);
let mut regstruct =
getregset(child, RegisterSet::NT_PRSTATUS).unwrap();
getregset::<regset::NT_PRSTATUS>(child).unwrap();
let mut fpregstruct =
getregset::<regset::NT_PRFPREG>(child).unwrap();

#[cfg(target_arch = "x86_64")]
let reg = &mut regstruct.r15;
let (reg, fpreg) =
(&mut regstruct.r15, &mut fpregstruct.st_space[5]);
#[cfg(target_arch = "x86")]
let reg = &mut regstruct.edx;
let (reg, fpreg) =
(&mut regstruct.edx, &mut fpregstruct.st_space[5]);
#[cfg(target_arch = "aarch64")]
let reg = &mut regstruct.regs[16];
let (reg, fpreg) =
(&mut regstruct.regs[16], &mut fpregstruct.vregs[5]);
#[cfg(target_arch = "riscv64")]
let reg = &mut regstruct.regs[16];
let (reg, fpreg) = (&mut regstruct.t1, &mut fpregstruct.__f[5]);

*reg = 0xdeadbeefu32 as _;
let _ = setregset(child, RegisterSet::NT_PRSTATUS, regstruct);
regstruct = getregset(child, RegisterSet::NT_PRSTATUS).unwrap();
*fpreg = 0xfeedfaceu32 as _;
let _ = setregset::<regset::NT_PRSTATUS>(child, regstruct);
regstruct = getregset::<regset::NT_PRSTATUS>(child).unwrap();
let _ = setregset::<regset::NT_PRFPREG>(child, fpregstruct);
fpregstruct = getregset::<regset::NT_PRFPREG>(child).unwrap();

#[cfg(target_arch = "x86_64")]
let reg = regstruct.r15;
let (reg, fpreg) = (regstruct.r15, fpregstruct.st_space[5]);
#[cfg(target_arch = "x86")]
let reg = regstruct.edx;
let (reg, fpreg) = (regstruct.edx, fpregstruct.st_space[5]);
#[cfg(target_arch = "aarch64")]
let reg = regstruct.regs[16];
let (reg, fpreg) = (regstruct.regs[16], fpregstruct.vregs[5]);
#[cfg(target_arch = "riscv64")]
let reg = regstruct.regs[16];
let (reg, fpreg) = (regstruct.t1, fpregstruct.__f[5]);
assert_eq!(reg, 0xdeadbeefu32 as _);
assert_eq!(fpreg, 0xfeedfaceu32 as _);

ptrace::cont(child, Some(Signal::SIGKILL)).unwrap();
match waitpid(child, None) {
Expand Down