Skip to content

Commit

Permalink
Merge pull request #5 from boydjohnson/feature/more-primitives
Browse files Browse the repository at this point in the history
Feature/more primitives
  • Loading branch information
boydjohnson authored Dec 23, 2024
2 parents c97cc06 + e186100 commit eff0134
Show file tree
Hide file tree
Showing 10 changed files with 347 additions and 122 deletions.
3 changes: 3 additions & 0 deletions src/engine.rs
Original file line number Diff line number Diff line change
Expand Up @@ -80,3 +80,6 @@ impl Drop for Engine {
unsafe { dnnl_engine_destroy(self.handle) };
}
}

unsafe impl Sync for Engine {}
unsafe impl Send for Engine {}
3 changes: 3 additions & 0 deletions src/memory.rs
Original file line number Diff line number Diff line change
Expand Up @@ -202,3 +202,6 @@ impl Drop for Memory {
unsafe { dnnl_memory_destroy(self.handle) };
}
}

unsafe impl Sync for Memory {}
unsafe impl Send for Memory {}
47 changes: 41 additions & 6 deletions src/memory/descriptor.rs
Original file line number Diff line number Diff line change
@@ -1,17 +1,16 @@
use onednnl_sys::{
dnnl_data_type_t, dnnl_dim_t, dnnl_memory_desc_clone, dnnl_memory_desc_create_with_blob,
dnnl_memory_desc_create_with_tag, dnnl_memory_desc_destroy, dnnl_memory_desc_equal,
dnnl_memory_desc_get_blob, dnnl_memory_desc_get_size, dnnl_memory_desc_t, dnnl_status_t,
dnnl_data_type_t, dnnl_dim_t, dnnl_format_tag_t::dnnl_format_tag_any, dnnl_memory_desc_clone,
dnnl_memory_desc_create_with_blob, dnnl_memory_desc_create_with_tag, dnnl_memory_desc_destroy,
dnnl_memory_desc_equal, dnnl_memory_desc_get_blob, dnnl_memory_desc_get_size,
dnnl_memory_desc_t, dnnl_status_t,
};

#[derive(Debug)]
pub struct MemoryDescriptor {
pub(crate) handle: dnnl_memory_desc_t,
}

use crate::error::DnnlError;

use super::format_tag::FormatTag;
use {super::format_tag::FormatTag, crate::error::DnnlError};

impl MemoryDescriptor {
/// Create a new MemoryDescriptor
Expand Down Expand Up @@ -63,6 +62,35 @@ impl MemoryDescriptor {
}
}

/// Create a new MemoryDescriptor
/// ```
/// use onednnl::memory::descriptor::MemoryDescriptor;
/// use onednnl_sys::dnnl_data_type_t::dnnl_f32;
///
///
/// let md = MemoryDescriptor::new_any(&[15, 15], dnnl_f32);
///
/// assert!(md.is_ok());
/// ```
pub fn new_any(dims: &[i64], data_type: dnnl_data_type_t::Type) -> Result<Self, DnnlError> {
let mut handle: dnnl_memory_desc_t = std::ptr::null_mut();
let status = unsafe {
dnnl_memory_desc_create_with_tag(
&mut handle,
dims.len() as i32,
dims.as_ptr(),
data_type,
dnnl_format_tag_any,
)
};

if status == dnnl_status_t::dnnl_success {
Ok(Self { handle })
} else {
Err(status.into())
}
}

/// Clones the memory descriptor.
///
/// ```
Expand Down Expand Up @@ -213,3 +241,10 @@ impl Drop for MemoryDescriptor {
}
}
}

pub struct DataType;

impl DataType {
pub const F32: dnnl_data_type_t::Type = dnnl_data_type_t::dnnl_f32;
pub const F64: dnnl_data_type_t::Type = dnnl_data_type_t::dnnl_f64;
}
134 changes: 22 additions & 112 deletions src/primitive.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@ use {
au_gru::{BackwardAuGruConfig, ForwardAuGruConfig},
batch_norm::ForwardBatchNormConfig,
binary::ForwardBinaryConfig,
eltwise::ForwardEltwiseConfig,
matmul::ForwardMatMulConfig,
PrimitiveConfig,
},
descriptor::PrimitiveDescriptor,
Expand Down Expand Up @@ -139,117 +141,25 @@ impl<'a, P: PropType<Forward>> Operation<'a, Forward, P> for ForwardBatchNorm<P>
type OperationConfig = ForwardBatchNormConfig<'a>;
}

// impl<D: Direction, P: PropType<D>> Operation<D, P> for BatchNorm<D, P> {
// const TYPE: OperationType = OperationType::BatchNormalization;

// pub type OperationConfig = BatchNormConfig<D, P>;
// }

// pub struct Binary<D: Direction, P: PropType<D>> {
// pub direction: D,
// pub prop_type: P,
// }

// impl<D: Direction, P: PropType<D>> Operation for Binary<D, P> {
// const TYPE: OperationType = OperationType::Binary;
// }

// impl_operation!(ForwardConcat, Direction::Forward, OperationType::Concat);
// impl_operation!(BackwardConcat, Direction::Backward, OperationType::Concat);

// impl_operation!(
// ForwardConvolution,
// Direction::Forward,
// OperationType::Convolution
// );
// impl_operation!(
// BackwardConvolution,
// Direction::Backward,
// OperationType::Convolution
// );

// impl_operation!(
// ForwardDeconvolution,
// Direction::Forward,
// OperationType::Deconvolution
// );
// impl_operation!(
// BackwardDeconvolution,
// Direction::Backward,
// OperationType::Deconvolution
// );

// impl_operation!(ForwardEltwise, Direction::Forward, OperationType::Eltwise);
// impl_operation!(BackwardEltwise, Direction::Backward, OperationType::Eltwise);

// impl_operation!(
// ForwardGroupNorm,
// Direction::Forward,
// OperationType::GroupNormalization
// );
// impl_operation!(
// BackwardGroupNorm,
// Direction::Backward,
// OperationType::GroupNormalization
// );

// impl_operation!(ForwardGru, Direction::Forward, OperationType::Gru);
// impl_operation!(BackwardGru, Direction::Backward, OperationType::Gru);

// impl_operation!(
// ForwardInnerProduct,
// Direction::Forward,
// OperationType::InnerProduct
// );
// impl_operation!(
// BackwardInnerProduct,
// Direction::Backward,
// OperationType::InnerProduct
// );

// impl_operation!(
// ForwardLayerNorm,
// Direction::Forward,
// OperationType::LayerNormalization
// );
// impl_operation!(
// BackwardLayerNorm,
// Direction::Backward,
// OperationType::LayerNormalization
// );

// impl_operation!(ForwardLbrAuGru, Direction::Forward, OperationType::LbrAuGru);
// impl_operation!(
// BackwardLbrAuGru,
// Direction::Backward,
// OperationType::LbrAuGru
// );

// impl_operation!(ForwardLrn, Direction::Forward, OperationType::Lrn);
// impl_operation!(BackwardLrn, Direction::Backward, OperationType::Lrn);

// impl_operation!(ForwardLstm, Direction::Forward, OperationType::Lstm);
// impl_operation!(BackwardLstm, Direction::Backward, OperationType::Lstm);

// impl_operation!(ForwardMatMul, Direction::Forward, OperationType::MatMul);
// impl_operation!(BackwardMatMul, Direction::Backward, OperationType::MatMul);

// impl_operation!(ForwardPRelu, Direction::Forward, OperationType::PRelu);
// impl_operation!(BackwardPRelu, Direction::Backward, OperationType::PRelu);

// impl_operation!(ForwardShuffle, Direction::Forward, OperationType::Shuffle);
// impl_operation!(BackwardShuffle, Direction::Backward, OperationType::Shuffle);

// impl_operation!(
// ForwardVanillaRnn,
// Direction::Forward,
// OperationType::VanillaRnn
// );
// impl_operation!(
// BackwardVanillaRnn,
// Direction::Backward,
// OperationType::VanillaRnn
// );
pub struct ForwardEltwise<P: PropType<Forward>> {
pub prop_type: P,
}

impl<'a, P: PropType<Forward>> Operation<'a, Forward, P> for ForwardEltwise<P> {
const TYPE: OperationType = OperationType::Eltwise;

type OperationConfig = ForwardEltwiseConfig<'a>;
}

pub struct ForwardMatMul<P: PropType<Forward>> {
pub prop_type: P,
}

impl<'a, P: PropType<Forward>> Operation<'a, Forward, P> for ForwardMatMul<P> {
const TYPE: OperationType = OperationType::MatMul;

type OperationConfig = ForwardMatMulConfig<'a>;
}

pub struct Primitive {
pub(crate) handle: dnnl_primitive_t,
Expand Down Expand Up @@ -321,7 +231,7 @@ impl Primitive {
}
}

pub fn execute(&self, stream: &Stream, args: Vec<ExecArg>) -> Result<(), DnnlError> {
pub fn execute(&self, stream: &Stream, args: Vec<ExecArg<'_>>) -> Result<(), DnnlError> {
let c_args: Vec<dnnl_exec_arg_t> = args
.iter()
.map(|arg| dnnl_exec_arg_t {
Expand Down
3 changes: 3 additions & 0 deletions src/primitive/config.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,9 @@ use {
pub mod au_gru;
pub mod batch_norm;
pub mod binary;
pub mod eltwise;
pub mod inner_product;
pub mod matmul;

pub trait PrimitiveConfig<'a, D: Direction, P: PropType<D>> {
fn create_primitive_desc(&self, engine: Arc<Engine>) -> Result<PrimitiveDescriptor, DnnlError>;
Expand Down
17 changes: 17 additions & 0 deletions src/primitive/config/binary.rs
Original file line number Diff line number Diff line change
Expand Up @@ -42,3 +42,20 @@ impl<'a> PrimitiveConfig<'a, Forward, PropForwardInference> for ForwardBinaryCon
}
}
}

pub struct Binary;

impl Binary {
pub const ADD: dnnl_alg_kind_t::Type = dnnl_alg_kind_t::dnnl_binary_add;
pub const DIV: dnnl_alg_kind_t::Type = dnnl_alg_kind_t::dnnl_binary_div;
pub const EQ: dnnl_alg_kind_t::Type = dnnl_alg_kind_t::dnnl_binary_eq;
pub const GT: dnnl_alg_kind_t::Type = dnnl_alg_kind_t::dnnl_binary_gt;
pub const GE: dnnl_alg_kind_t::Type = dnnl_alg_kind_t::dnnl_binary_ge;
pub const LE: dnnl_alg_kind_t::Type = dnnl_alg_kind_t::dnnl_binary_le;
pub const LT: dnnl_alg_kind_t::Type = dnnl_alg_kind_t::dnnl_binary_lt;
pub const MAX: dnnl_alg_kind_t::Type = dnnl_alg_kind_t::dnnl_binary_max;
pub const MIN: dnnl_alg_kind_t::Type = dnnl_alg_kind_t::dnnl_binary_min;
pub const MUL: dnnl_alg_kind_t::Type = dnnl_alg_kind_t::dnnl_binary_mul;
pub const NE: dnnl_alg_kind_t::Type = dnnl_alg_kind_t::dnnl_binary_ne;
pub const SUB: dnnl_alg_kind_t::Type = dnnl_alg_kind_t::dnnl_binary_sub;
}
48 changes: 48 additions & 0 deletions src/primitive/config/eltwise.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
use {
super::PrimitiveConfig,
crate::{
memory::descriptor::MemoryDescriptor,
primitive::{descriptor::PrimitiveDescriptor, Forward, PropType},
},
onednnl_sys::{
dnnl_alg_kind_t, dnnl_eltwise_forward_primitive_desc_create, dnnl_primitive_attr_t,
dnnl_status_t,
},
};

pub struct ForwardEltwiseConfig<'a> {
pub alg_kind: dnnl_alg_kind_t::Type,
pub src_desc: &'a MemoryDescriptor,
pub dst_desc: &'a MemoryDescriptor,
pub alpha: f32,
pub beta: f32,
pub attr: dnnl_primitive_attr_t,
}

impl<'a, P: PropType<Forward>> PrimitiveConfig<'a, Forward, P> for ForwardEltwiseConfig<'a> {
fn create_primitive_desc(
&self,
engine: std::sync::Arc<crate::engine::Engine>,
) -> Result<crate::primitive::descriptor::PrimitiveDescriptor, crate::error::DnnlError> {
let mut handle = std::ptr::null_mut();
let status = unsafe {
dnnl_eltwise_forward_primitive_desc_create(
&mut handle,
engine.handle,
P::KIND,
self.alg_kind,
self.src_desc.handle,
self.dst_desc.handle,
self.alpha,
self.beta,
self.attr,
)
};

if status == dnnl_status_t::dnnl_success {
Ok(PrimitiveDescriptor { handle })
} else {
Err(status.into())
}
}
}
44 changes: 44 additions & 0 deletions src/primitive/config/inner_product.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
use {
super::PrimitiveConfig,
crate::{
memory::descriptor::MemoryDescriptor,
primitive::{descriptor::PrimitiveDescriptor, Forward, PropType},
},
onednnl_sys::{
dnnl_inner_product_forward_primitive_desc_create, dnnl_primitive_attr_t, dnnl_status_t,
},
};

pub struct ForwardInnerProductConfig<'a> {
pub src_desc: &'a MemoryDescriptor,
pub weights_desc: &'a MemoryDescriptor,
pub bias_desc: &'a MemoryDescriptor,
pub dst_desc: &'a MemoryDescriptor,
pub attr: dnnl_primitive_attr_t,
}

impl<'a, P: PropType<Forward>> PrimitiveConfig<'a, Forward, P> for ForwardInnerProductConfig<'a> {
fn create_primitive_desc(
&self,
engine: std::sync::Arc<crate::engine::Engine>,
) -> Result<crate::primitive::descriptor::PrimitiveDescriptor, crate::error::DnnlError> {
let mut handle = std::ptr::null_mut();
let status = unsafe {
dnnl_inner_product_forward_primitive_desc_create(
&mut handle,
engine.handle,
P::KIND,
self.src_desc.handle,
self.weights_desc.handle,
self.bias_desc.handle,
self.dst_desc.handle,
self.attr,
)
};
if status == dnnl_status_t::dnnl_success {
Ok(PrimitiveDescriptor { handle })
} else {
Err(status.into())
}
}
}
Loading

0 comments on commit eff0134

Please sign in to comment.