Skip to content

Feature/more primitives #5

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 6 commits into from
Dec 23, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions src/engine.rs
Original file line number Diff line number Diff line change
Expand Up @@ -80,3 +80,6 @@ impl Drop for Engine {
unsafe { dnnl_engine_destroy(self.handle) };
}
}

unsafe impl Sync for Engine {}
unsafe impl Send for Engine {}
3 changes: 3 additions & 0 deletions src/memory.rs
Original file line number Diff line number Diff line change
Expand Up @@ -202,3 +202,6 @@ impl Drop for Memory {
unsafe { dnnl_memory_destroy(self.handle) };
}
}

unsafe impl Sync for Memory {}
unsafe impl Send for Memory {}
47 changes: 41 additions & 6 deletions src/memory/descriptor.rs
Original file line number Diff line number Diff line change
@@ -1,17 +1,16 @@
use onednnl_sys::{
dnnl_data_type_t, dnnl_dim_t, dnnl_memory_desc_clone, dnnl_memory_desc_create_with_blob,
dnnl_memory_desc_create_with_tag, dnnl_memory_desc_destroy, dnnl_memory_desc_equal,
dnnl_memory_desc_get_blob, dnnl_memory_desc_get_size, dnnl_memory_desc_t, dnnl_status_t,
dnnl_data_type_t, dnnl_dim_t, dnnl_format_tag_t::dnnl_format_tag_any, dnnl_memory_desc_clone,
dnnl_memory_desc_create_with_blob, dnnl_memory_desc_create_with_tag, dnnl_memory_desc_destroy,
dnnl_memory_desc_equal, dnnl_memory_desc_get_blob, dnnl_memory_desc_get_size,
dnnl_memory_desc_t, dnnl_status_t,
};

#[derive(Debug)]
pub struct MemoryDescriptor {
pub(crate) handle: dnnl_memory_desc_t,
}

use crate::error::DnnlError;

use super::format_tag::FormatTag;
use {super::format_tag::FormatTag, crate::error::DnnlError};

impl MemoryDescriptor {
/// Create a new MemoryDescriptor
Expand Down Expand Up @@ -63,6 +62,35 @@ impl MemoryDescriptor {
}
}

/// Create a new MemoryDescriptor
/// ```
/// use onednnl::memory::descriptor::MemoryDescriptor;
/// use onednnl_sys::dnnl_data_type_t::dnnl_f32;
///
///
/// let md = MemoryDescriptor::new_any(&[15, 15], dnnl_f32);
///
/// assert!(md.is_ok());
/// ```
pub fn new_any(dims: &[i64], data_type: dnnl_data_type_t::Type) -> Result<Self, DnnlError> {
let mut handle: dnnl_memory_desc_t = std::ptr::null_mut();
let status = unsafe {
dnnl_memory_desc_create_with_tag(
&mut handle,
dims.len() as i32,
dims.as_ptr(),
data_type,
dnnl_format_tag_any,
)
};

if status == dnnl_status_t::dnnl_success {
Ok(Self { handle })
} else {
Err(status.into())
}
}

/// Clones the memory descriptor.
///
/// ```
Expand Down Expand Up @@ -213,3 +241,10 @@ impl Drop for MemoryDescriptor {
}
}
}

pub struct DataType;

impl DataType {
pub const F32: dnnl_data_type_t::Type = dnnl_data_type_t::dnnl_f32;
pub const F64: dnnl_data_type_t::Type = dnnl_data_type_t::dnnl_f64;
}
134 changes: 22 additions & 112 deletions src/primitive.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@ use {
au_gru::{BackwardAuGruConfig, ForwardAuGruConfig},
batch_norm::ForwardBatchNormConfig,
binary::ForwardBinaryConfig,
eltwise::ForwardEltwiseConfig,
matmul::ForwardMatMulConfig,
PrimitiveConfig,
},
descriptor::PrimitiveDescriptor,
Expand Down Expand Up @@ -139,117 +141,25 @@ impl<'a, P: PropType<Forward>> Operation<'a, Forward, P> for ForwardBatchNorm<P>
type OperationConfig = ForwardBatchNormConfig<'a>;
}

// impl<D: Direction, P: PropType<D>> Operation<D, P> for BatchNorm<D, P> {
// const TYPE: OperationType = OperationType::BatchNormalization;

// pub type OperationConfig = BatchNormConfig<D, P>;
// }

// pub struct Binary<D: Direction, P: PropType<D>> {
// pub direction: D,
// pub prop_type: P,
// }

// impl<D: Direction, P: PropType<D>> Operation for Binary<D, P> {
// const TYPE: OperationType = OperationType::Binary;
// }

// impl_operation!(ForwardConcat, Direction::Forward, OperationType::Concat);
// impl_operation!(BackwardConcat, Direction::Backward, OperationType::Concat);

// impl_operation!(
// ForwardConvolution,
// Direction::Forward,
// OperationType::Convolution
// );
// impl_operation!(
// BackwardConvolution,
// Direction::Backward,
// OperationType::Convolution
// );

// impl_operation!(
// ForwardDeconvolution,
// Direction::Forward,
// OperationType::Deconvolution
// );
// impl_operation!(
// BackwardDeconvolution,
// Direction::Backward,
// OperationType::Deconvolution
// );

// impl_operation!(ForwardEltwise, Direction::Forward, OperationType::Eltwise);
// impl_operation!(BackwardEltwise, Direction::Backward, OperationType::Eltwise);

// impl_operation!(
// ForwardGroupNorm,
// Direction::Forward,
// OperationType::GroupNormalization
// );
// impl_operation!(
// BackwardGroupNorm,
// Direction::Backward,
// OperationType::GroupNormalization
// );

// impl_operation!(ForwardGru, Direction::Forward, OperationType::Gru);
// impl_operation!(BackwardGru, Direction::Backward, OperationType::Gru);

// impl_operation!(
// ForwardInnerProduct,
// Direction::Forward,
// OperationType::InnerProduct
// );
// impl_operation!(
// BackwardInnerProduct,
// Direction::Backward,
// OperationType::InnerProduct
// );

// impl_operation!(
// ForwardLayerNorm,
// Direction::Forward,
// OperationType::LayerNormalization
// );
// impl_operation!(
// BackwardLayerNorm,
// Direction::Backward,
// OperationType::LayerNormalization
// );

// impl_operation!(ForwardLbrAuGru, Direction::Forward, OperationType::LbrAuGru);
// impl_operation!(
// BackwardLbrAuGru,
// Direction::Backward,
// OperationType::LbrAuGru
// );

// impl_operation!(ForwardLrn, Direction::Forward, OperationType::Lrn);
// impl_operation!(BackwardLrn, Direction::Backward, OperationType::Lrn);

// impl_operation!(ForwardLstm, Direction::Forward, OperationType::Lstm);
// impl_operation!(BackwardLstm, Direction::Backward, OperationType::Lstm);

// impl_operation!(ForwardMatMul, Direction::Forward, OperationType::MatMul);
// impl_operation!(BackwardMatMul, Direction::Backward, OperationType::MatMul);

// impl_operation!(ForwardPRelu, Direction::Forward, OperationType::PRelu);
// impl_operation!(BackwardPRelu, Direction::Backward, OperationType::PRelu);

// impl_operation!(ForwardShuffle, Direction::Forward, OperationType::Shuffle);
// impl_operation!(BackwardShuffle, Direction::Backward, OperationType::Shuffle);

// impl_operation!(
// ForwardVanillaRnn,
// Direction::Forward,
// OperationType::VanillaRnn
// );
// impl_operation!(
// BackwardVanillaRnn,
// Direction::Backward,
// OperationType::VanillaRnn
// );
pub struct ForwardEltwise<P: PropType<Forward>> {
pub prop_type: P,
}

impl<'a, P: PropType<Forward>> Operation<'a, Forward, P> for ForwardEltwise<P> {
const TYPE: OperationType = OperationType::Eltwise;

type OperationConfig = ForwardEltwiseConfig<'a>;
}

pub struct ForwardMatMul<P: PropType<Forward>> {
pub prop_type: P,
}

impl<'a, P: PropType<Forward>> Operation<'a, Forward, P> for ForwardMatMul<P> {
const TYPE: OperationType = OperationType::MatMul;

type OperationConfig = ForwardMatMulConfig<'a>;
}

pub struct Primitive {
pub(crate) handle: dnnl_primitive_t,
Expand Down Expand Up @@ -321,7 +231,7 @@ impl Primitive {
}
}

pub fn execute(&self, stream: &Stream, args: Vec<ExecArg>) -> Result<(), DnnlError> {
pub fn execute(&self, stream: &Stream, args: Vec<ExecArg<'_>>) -> Result<(), DnnlError> {
let c_args: Vec<dnnl_exec_arg_t> = args
.iter()
.map(|arg| dnnl_exec_arg_t {
Expand Down
3 changes: 3 additions & 0 deletions src/primitive/config.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,9 @@ use {
pub mod au_gru;
pub mod batch_norm;
pub mod binary;
pub mod eltwise;
pub mod inner_product;
pub mod matmul;

pub trait PrimitiveConfig<'a, D: Direction, P: PropType<D>> {
fn create_primitive_desc(&self, engine: Arc<Engine>) -> Result<PrimitiveDescriptor, DnnlError>;
Expand Down
17 changes: 17 additions & 0 deletions src/primitive/config/binary.rs
Original file line number Diff line number Diff line change
Expand Up @@ -42,3 +42,20 @@ impl<'a> PrimitiveConfig<'a, Forward, PropForwardInference> for ForwardBinaryCon
}
}
}

pub struct Binary;

impl Binary {
pub const ADD: dnnl_alg_kind_t::Type = dnnl_alg_kind_t::dnnl_binary_add;
pub const DIV: dnnl_alg_kind_t::Type = dnnl_alg_kind_t::dnnl_binary_div;
pub const EQ: dnnl_alg_kind_t::Type = dnnl_alg_kind_t::dnnl_binary_eq;
pub const GT: dnnl_alg_kind_t::Type = dnnl_alg_kind_t::dnnl_binary_gt;
pub const GE: dnnl_alg_kind_t::Type = dnnl_alg_kind_t::dnnl_binary_ge;
pub const LE: dnnl_alg_kind_t::Type = dnnl_alg_kind_t::dnnl_binary_le;
pub const LT: dnnl_alg_kind_t::Type = dnnl_alg_kind_t::dnnl_binary_lt;
pub const MAX: dnnl_alg_kind_t::Type = dnnl_alg_kind_t::dnnl_binary_max;
pub const MIN: dnnl_alg_kind_t::Type = dnnl_alg_kind_t::dnnl_binary_min;
pub const MUL: dnnl_alg_kind_t::Type = dnnl_alg_kind_t::dnnl_binary_mul;
pub const NE: dnnl_alg_kind_t::Type = dnnl_alg_kind_t::dnnl_binary_ne;
pub const SUB: dnnl_alg_kind_t::Type = dnnl_alg_kind_t::dnnl_binary_sub;
}
48 changes: 48 additions & 0 deletions src/primitive/config/eltwise.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
use {
super::PrimitiveConfig,
crate::{
memory::descriptor::MemoryDescriptor,
primitive::{descriptor::PrimitiveDescriptor, Forward, PropType},
},
onednnl_sys::{
dnnl_alg_kind_t, dnnl_eltwise_forward_primitive_desc_create, dnnl_primitive_attr_t,
dnnl_status_t,
},
};

pub struct ForwardEltwiseConfig<'a> {
pub alg_kind: dnnl_alg_kind_t::Type,
pub src_desc: &'a MemoryDescriptor,
pub dst_desc: &'a MemoryDescriptor,
pub alpha: f32,
pub beta: f32,
pub attr: dnnl_primitive_attr_t,
}

impl<'a, P: PropType<Forward>> PrimitiveConfig<'a, Forward, P> for ForwardEltwiseConfig<'a> {
fn create_primitive_desc(
&self,
engine: std::sync::Arc<crate::engine::Engine>,
) -> Result<crate::primitive::descriptor::PrimitiveDescriptor, crate::error::DnnlError> {
let mut handle = std::ptr::null_mut();
let status = unsafe {
dnnl_eltwise_forward_primitive_desc_create(
&mut handle,
engine.handle,
P::KIND,
self.alg_kind,
self.src_desc.handle,
self.dst_desc.handle,
self.alpha,
self.beta,
self.attr,
)
};

if status == dnnl_status_t::dnnl_success {
Ok(PrimitiveDescriptor { handle })
} else {
Err(status.into())
}
}
}
44 changes: 44 additions & 0 deletions src/primitive/config/inner_product.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
use {
super::PrimitiveConfig,
crate::{
memory::descriptor::MemoryDescriptor,
primitive::{descriptor::PrimitiveDescriptor, Forward, PropType},
},
onednnl_sys::{
dnnl_inner_product_forward_primitive_desc_create, dnnl_primitive_attr_t, dnnl_status_t,
},
};

pub struct ForwardInnerProductConfig<'a> {
pub src_desc: &'a MemoryDescriptor,
pub weights_desc: &'a MemoryDescriptor,
pub bias_desc: &'a MemoryDescriptor,
pub dst_desc: &'a MemoryDescriptor,
pub attr: dnnl_primitive_attr_t,
}

impl<'a, P: PropType<Forward>> PrimitiveConfig<'a, Forward, P> for ForwardInnerProductConfig<'a> {
fn create_primitive_desc(
&self,
engine: std::sync::Arc<crate::engine::Engine>,
) -> Result<crate::primitive::descriptor::PrimitiveDescriptor, crate::error::DnnlError> {
let mut handle = std::ptr::null_mut();
let status = unsafe {
dnnl_inner_product_forward_primitive_desc_create(
&mut handle,
engine.handle,
P::KIND,
self.src_desc.handle,
self.weights_desc.handle,
self.bias_desc.handle,
self.dst_desc.handle,
self.attr,
)
};
if status == dnnl_status_t::dnnl_success {
Ok(PrimitiveDescriptor { handle })
} else {
Err(status.into())
}
}
}
Loading