Skip to content

Commit

Permalink
feat(cndrv): 实现 context、mem、queue、notifier
Browse files Browse the repository at this point in the history
Signed-off-by: YdrMaster <ydrml@hotmail.com>
  • Loading branch information
YdrMaster committed Jun 6, 2024
1 parent 454741f commit fb46d04
Show file tree
Hide file tree
Showing 6 changed files with 414 additions and 18 deletions.
29 changes: 14 additions & 15 deletions cndrv/src/context.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
use crate::{bindings as cn, AsRaw, Device};
use std::{ffi::c_uint, ptr::null_mut};
use crate::{bindings as cn, AsRaw, Device, ResourceWrapper};
use std::{ffi::c_uint, marker::PhantomData, ptr::null_mut};

#[derive(PartialEq, Eq, Hash, Debug)]
pub struct Context {
Expand Down Expand Up @@ -74,24 +74,16 @@ impl Context {
/// Runs `f` with a [`ContextGuard`] obtained from `bound()` and returns the
/// closure's result.
///
/// The guard is kept alive for the whole call and dropped once `f` has
/// returned.
pub fn apply<T>(&self, f: impl FnOnce(&ContextGuard) -> T) -> T {
    let guard = self.bound();
    f(&guard)
}

/// Returns `true` when `a` and `b` expose the same raw `CNcontext` handle.
#[inline]
pub fn check_eq(
    a: &impl AsRaw<Raw = cn::CNcontext>,
    b: &impl AsRaw<Raw = cn::CNcontext>,
) -> bool {
    // The raw handles are only compared here, never used for a driver call.
    let (lhs, rhs) = unsafe { (a.as_raw(), b.as_raw()) };
    lhs == rhs
}
}

#[repr(transparent)]
pub struct ContextGuard<'a>(&'a Context);
pub struct ContextGuard<'a>(cn::CNcontext, PhantomData<&'a ()>);

impl Context {
#[inline]
fn bound(&self) -> ContextGuard {
cndrv!(cnCtxSetCurrent(self.ctx));
ContextGuard(self)
ContextGuard(self.ctx, PhantomData)
}
}

Expand All @@ -104,7 +96,7 @@ impl Drop for ContextGuard<'_> {
cndrv!(cnCtxGetCurrent(&mut current));
current
},
self.0.ctx
self.0
);
cndrv!(cnCtxSetCurrent(null_mut()));
}
Expand All @@ -114,20 +106,27 @@ impl AsRaw for ContextGuard<'_> {
type Raw = cn::CNcontext;
#[inline]
unsafe fn as_raw(&self) -> Self::Raw {
self.0.ctx
self.0
}
}

impl ContextGuard<'_> {
#[inline]
pub fn dev(&self) -> Device {
Device(self.0.dev)
let mut dev = 0;
cndrv!(cnCtxGetDevice(&mut dev));
Device(dev)
}

/// Issues `cnCtxSync` through the `cndrv!` macro.
///
/// NOTE(review): presumably blocks until the work submitted to the current
/// context has completed — confirm against the CNDrv documentation.
#[inline]
pub fn synchronize(&self) {
cndrv!(cnCtxSync());
}

/// Pairs `res` with the raw context handle stored in this guard, producing a
/// [`ResourceWrapper`].
///
/// # Safety
///
/// NOTE(review): the contract is not stated in this file. Presumably `res`
/// must be a resource that was created while this context was current, since
/// the wrapper records this context's handle next to it — confirm and spell
/// out the exact requirement here.
#[inline]
pub unsafe fn wrap_resource<T>(&self, res: T) -> ResourceWrapper<T> {
ResourceWrapper { ctx: self.0, res }
}
}

#[test]
Expand Down
14 changes: 12 additions & 2 deletions cndrv/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,8 @@ pub mod bindings {
}
}

mod spore;

/// §4.3 Version Management
mod version;

Expand All @@ -41,10 +43,10 @@ mod context;
mod memory;

/// §4.7 Queue Management
// mod queue;
mod queue;

/// §4.8 Notifier Management
// mod notifier;
mod notifier;

/// §4.9 Atomic Operation Management
// mod atomic;
Expand Down Expand Up @@ -77,4 +79,12 @@ pub fn init() {
pub use context::{Context, ContextGuard};
pub use device::Device;
pub use memory::{memcpy_d2d, memcpy_d2h, memcpy_h2d, DevByte};
pub use notifier::{Notifier, NotifierSpore};
pub use queue::{Queue, QueueSpore};
pub use spore::{ContextResource, ContextSpore, ResourceWrapper};
pub use version::{driver_version, library_version, Version};

/// A pointer-like handle of type `P` paired with a length.
///
/// The allocation helpers in `memory.rs` fill `len` with a size in bytes.
struct Blob<P> {
/// Handle or address of the allocation.
ptr: P,
/// Length associated with `ptr`.
len: usize,
}
190 changes: 189 additions & 1 deletion cndrv/src/memory.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,13 @@
use std::mem::size_of_val;
use crate::{bindings as cn, impl_spore, AsRaw, Blob, ContextGuard, Queue};
use std::{
alloc::Layout,
ffi::c_void,
marker::PhantomData,
mem::size_of_val,
ops::{Deref, DerefMut},
ptr::null_mut,
slice::{from_raw_parts, from_raw_parts_mut},
};

#[repr(transparent)]
pub struct DevByte(#[allow(unused)] u8);
Expand All @@ -25,3 +34,182 @@ pub fn memcpy_d2d(dst: &mut [DevByte], src: &[DevByte]) {
assert_eq!(len, size_of_val(dst));
cndrv!(cnMemcpyDtoD(dst.as_ptr() as _, src.as_ptr() as _, len as _));
}

impl Queue<'_> {
    /// Enqueues a host-to-device copy of `src` into `dst` on this queue
    /// (`cnMemcpyHtoDAsync_V2`).
    ///
    /// # Panics
    ///
    /// Panics if `src` and `dst` do not have the same size in bytes.
    ///
    /// NOTE(review): the driver call is the asynchronous variant, so `src`
    /// presumably has to stay alive and unmodified until the queue has been
    /// synchronized; the borrow taken here does not enforce that — confirm.
    #[inline]
    pub fn memcpy_h2d<T: Copy>(&self, dst: &mut [DevByte], src: &[T]) {
        let len = size_of_val(src);
        assert_eq!(len, size_of_val(dst));
        let src = src.as_ptr().cast();
        cndrv!(cnMemcpyHtoDAsync_V2(
            // The destination is written to, so derive its address from the
            // mutable borrow rather than from `as_ptr()`.
            dst.as_mut_ptr() as _,
            src,
            len as _,
            self.as_raw()
        ));
    }

    /// Enqueues a device-to-device copy of `src` into `dst` on this queue
    /// (`cnMemcpyDtoDAsync`).
    ///
    /// # Panics
    ///
    /// Panics if `src` and `dst` do not have the same size in bytes.
    #[inline]
    pub fn memcpy_d2d(&self, dst: &mut [DevByte], src: &[DevByte]) {
        let len = size_of_val(src);
        assert_eq!(len, size_of_val(dst));
        cndrv!(cnMemcpyDtoDAsync(
            // See `memcpy_h2d`: the written-to side uses `as_mut_ptr()`.
            dst.as_mut_ptr() as _,
            src.as_ptr() as _,
            len as _,
            self.as_raw()
        ));
    }
}

impl_spore!(DevMem and DevMemSpore by Blob<cn::CNaddr>);

impl ContextGuard<'_> {
    /// Allocates device memory large enough for `len` values of type `T`.
    ///
    /// The returned [`DevMem`] records the allocation size in bytes.
    ///
    /// # Panics
    ///
    /// Panics if the byte size of `[T; len]` cannot be represented as a
    /// [`Layout`].
    pub fn malloc<T: Copy>(&self, len: usize) -> DevMem<'_> {
        let len = Layout::array::<T>(len).unwrap().size();
        let mut ptr = 0;
        cndrv!(cnMalloc(&mut ptr, len as _));
        DevMem(
            unsafe { self.wrap_resource(Blob { ptr, len }) },
            PhantomData,
        )
    }

    /// Allocates device memory of the same byte size as `slice` and copies
    /// the host data into it with `cnMemcpyHtoD`.
    pub fn from_host<T: Copy>(&self, slice: &[T]) -> DevMem<'_> {
        // Allocate through `malloc` so the address is owned by a `DevMem`
        // before the copy is issued; if the copy unwinds, `Drop` releases it.
        let mem = self.malloc::<T>(slice.len());
        let src = slice.as_ptr().cast();
        cndrv!(cnMemcpyHtoD(mem.0.res.ptr, src, mem.0.res.len as _));
        mem
    }
}

impl Drop for DevMem<'_> {
    /// Releases the device allocation with `cnFree`.
    #[inline]
    fn drop(&mut self) {
        let addr = self.0.res.ptr;
        cndrv!(cnFree(addr));
    }
}

impl Deref for DevMem<'_> {
    type Target = [DevByte];

    /// Views the allocation as a slice of [`DevByte`] covering its recorded
    /// length; a zero-length allocation yields an empty slice without
    /// touching the stored address.
    #[inline]
    fn deref(&self) -> &Self::Target {
        match self.0.res.len {
            0 => &[],
            len => unsafe { from_raw_parts(self.0.res.ptr as _, len) },
        }
    }
}

impl DerefMut for DevMem<'_> {
    /// Mutable counterpart of `deref`: a [`DevByte`] slice over the recorded
    /// length, or an empty slice when that length is zero.
    #[inline]
    fn deref_mut(&mut self) -> &mut Self::Target {
        match self.0.res.len {
            0 => &mut [],
            len => unsafe { from_raw_parts_mut(self.0.res.ptr as _, len) },
        }
    }
}

/// Exposes the stored device address of a `DevMemSpore`.
impl AsRaw for DevMemSpore {
type Raw = cn::CNaddr;
/// Returns the `ptr` field of the wrapped blob.
#[inline]
unsafe fn as_raw(&self) -> Self::Raw {
self.0.res.ptr
}
}

impl DevMemSpore {
    /// Returns the length recorded for this allocation.
    #[inline]
    pub const fn len(&self) -> usize {
        self.0.res.len
    }

    /// Returns `true` when the recorded length is zero.
    #[inline]
    pub const fn is_empty(&self) -> bool {
        self.len() == 0
    }
}

impl_spore!(HostMem and HostMemSpore by Blob<*mut c_void>);

impl<'ctx> ContextGuard<'ctx> {
    /// Allocates host memory through `cnMallocHost`, sized for `len` values
    /// of type `T`.
    ///
    /// # Panics
    ///
    /// Panics if the byte size of `[T; len]` cannot be represented as a
    /// [`Layout`].
    pub fn malloc_host<T: Copy>(&'ctx self, len: usize) -> HostMem<'ctx> {
        let bytes = Layout::array::<T>(len).unwrap().size();
        let mut host = null_mut();
        cndrv!(cnMallocHost(&mut host, bytes as _));
        let blob = Blob {
            ptr: host,
            len: bytes,
        };
        HostMem(unsafe { self.wrap_resource(blob) }, PhantomData)
    }
}

impl Drop for HostMem<'_> {
    /// Releases the host allocation with `cnFreeHost`.
    #[inline]
    fn drop(&mut self) {
        let host = self.0.res.ptr;
        cndrv!(cnFreeHost(host));
    }
}

/// Exposes the stored host pointer of a `HostMem`.
impl AsRaw for HostMem<'_> {
type Raw = *mut c_void;
/// Returns the `ptr` field of the wrapped blob.
#[inline]
unsafe fn as_raw(&self) -> Self::Raw {
self.0.res.ptr
}
}

impl Deref for HostMem<'_> {
    type Target = [u8];

    /// Views the host allocation as a byte slice over its recorded length.
    #[inline]
    fn deref(&self) -> &Self::Target {
        // `from_raw_parts` requires a non-null pointer even for a length of
        // zero, so an empty allocation is answered without touching `ptr`
        // (the same guard `DevMem` uses).
        if self.0.res.len == 0 {
            &[]
        } else {
            unsafe { from_raw_parts(self.0.res.ptr.cast(), self.0.res.len) }
        }
    }
}

impl DerefMut for HostMem<'_> {
    /// Views the host allocation as a mutable byte slice over its recorded
    /// length.
    #[inline]
    fn deref_mut(&mut self) -> &mut Self::Target {
        // `from_raw_parts_mut` requires a non-null pointer even for a length
        // of zero, so an empty allocation is answered without touching `ptr`.
        if self.0.res.len == 0 {
            &mut []
        } else {
            unsafe { from_raw_parts_mut(self.0.res.ptr.cast(), self.0.res.len) }
        }
    }
}

impl Deref for HostMemSpore {
    type Target = [u8];

    /// Views the host allocation as a byte slice over its recorded length.
    #[inline]
    fn deref(&self) -> &Self::Target {
        // `from_raw_parts` requires a non-null pointer even for a length of
        // zero, so an empty allocation is answered without touching `ptr`.
        if self.0.res.len == 0 {
            &[]
        } else {
            unsafe { from_raw_parts(self.0.res.ptr.cast(), self.0.res.len) }
        }
    }
}

impl DerefMut for HostMemSpore {
    /// Views the host allocation as a mutable byte slice over its recorded
    /// length.
    #[inline]
    fn deref_mut(&mut self) -> &mut Self::Target {
        // `from_raw_parts_mut` requires a non-null pointer even for a length
        // of zero, so an empty allocation is answered without touching `ptr`.
        if self.0.res.len == 0 {
            &mut []
        } else {
            unsafe { from_raw_parts_mut(self.0.res.ptr.cast(), self.0.res.len) }
        }
    }
}

// Probes `cnMallocHost` / `cnFreeHost` directly: one allocate-and-free pair
// inside an applied context, then `cnFreeHost` on a null pointer after the
// `apply` call has returned. Returns early when no device is available.
#[test]
fn test_behavior() {
crate::init();
let Some(dev) = crate::Device::fetch() else {
return;
};
let mut ptr = null_mut();
dev.context().apply(|_| {
cndrv!(cnMallocHost(&mut ptr, 128));
cndrv!(cnFreeHost(ptr));
});
// Reset the pointer so the call below frees null, not the released block.
ptr = null_mut();
cndrv!(cnFreeHost(ptr));
}
60 changes: 60 additions & 0 deletions cndrv/src/notifier.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,60 @@
use crate::{bindings as cn, impl_spore, AsRaw, Queue};
use std::{marker::PhantomData, ptr::null_mut, time::Duration};

impl_spore!(Notifier and NotifierSpore by cn::CNnotifier);

impl<'ctx> Queue<'ctx> {
    /// Creates a notifier with the default flags, places it on this queue,
    /// and returns it wrapped together with the queue's context.
    pub fn record(&self) -> Notifier<'ctx> {
        let mut notifier = null_mut();
        cndrv!(cnCreateNotifier(
            &mut notifier,
            CNNotifierFlags::CN_NOTIFIER_DEFAULT as _
        ));
        cndrv!(cnPlaceNotifier(notifier, self.as_raw()));
        let wrapped = unsafe { self.ctx().wrap_resource(notifier) };
        Notifier(wrapped, PhantomData)
    }
}

impl Drop for Notifier<'_> {
    /// Destroys the underlying notifier with `cnDestroyNotifier`.
    #[inline]
    fn drop(&mut self) {
        let handle = self.0.res;
        cndrv!(cnDestroyNotifier(handle));
    }
}

/// Exposes the stored `CNnotifier` handle of a `Notifier`.
impl AsRaw for Notifier<'_> {
type Raw = cn::CNnotifier;
/// Returns the wrapped handle.
#[inline]
unsafe fn as_raw(&self) -> Self::Raw {
self.0.res
}
}

impl Queue<'_> {
    /// Measures the average time of one call to `f` on this queue.
    ///
    /// `f` is first called `warm_up` times without being measured, then
    /// `times` more times between two recorded notifiers. Each call receives
    /// its iteration index and this queue. The elapsed time between the two
    /// notifiers is divided by `times`.
    ///
    /// Returns [`Duration::ZERO`] when `times` is zero: dividing the elapsed
    /// time by zero is not finite and would otherwise panic in
    /// [`Duration::div_f32`].
    pub fn bench(&self, f: impl Fn(usize, &Self), times: usize, warm_up: usize) -> Duration {
        for i in 0..warm_up {
            f(i, self);
        }
        if times == 0 {
            // Nothing to measure; skip the notifiers and the division.
            return Duration::ZERO;
        }
        let start = self.record();
        for i in 0..times {
            f(i, self);
        }
        let end = self.record();
        // Wait for the closing notifier before reading the elapsed time.
        end.synchronize();
        end.elapse_from(&start).div_f32(times as _)
    }
}

impl Notifier<'_> {
    /// Waits on this notifier with `cnWaitNotifier`.
    #[inline]
    pub fn synchronize(&self) {
        let handle = self.0.res;
        cndrv!(cnWaitNotifier(handle));
    }

    /// Returns the time elapsed from `start` to this notifier.
    ///
    /// The value reported by `cnNotifierElapsedTime` is scaled by `1e-3`
    /// before being converted to a [`Duration`] in seconds.
    #[inline]
    pub fn elapse_from(&self, start: &Self) -> Duration {
        let (begin, end) = (start.0.res, self.0.res);
        let mut millis = 0.0;
        cndrv!(cnNotifierElapsedTime(&mut millis, begin, end));
        Duration::from_secs_f32(millis * 1e-3)
    }
}
Loading

0 comments on commit fb46d04

Please sign in to comment.