Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
96 changes: 96 additions & 0 deletions constensor-core/src/dtype/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -83,13 +83,109 @@ sqrt_integral!(u32);
sqrt_integral!(i32);
sqrt_integral!(i64);

/// Elementwise exponential functions, implemented for every tensor dtype.
///
/// Float types delegate to their native operations; integer types round-trip
/// through `f64` (see `exp_integral!` below).
pub trait Expable {
    /// Natural exponential, `e^self`.
    /// The `Self: Sized` bound restricts the method to sized implementors.
    fn exp(&self) -> Self
    where
        Self: Sized;
    /// Base-2 exponential, `2^self`.
    fn exp2(&self) -> Self
    where
        Self: Sized;
}

/// Single-precision floats delegate straight to the std float operations.
impl Expable for f32 {
    fn exp(&self) -> Self
    where
        Self: Sized,
    {
        // Inherent `f32::exp` takes precedence over this trait method in
        // method resolution, so this call does not recurse.
        (*self).exp()
    }
    fn exp2(&self) -> Self
    where
        Self: Sized,
    {
        (*self).exp2()
    }
}

/// Double-precision floats delegate straight to the std float operations.
impl Expable for f64 {
    fn exp(&self) -> Self
    where
        Self: Sized,
    {
        // Inherent `f64::exp` wins over the trait method, so no recursion.
        (*self).exp()
    }
    fn exp2(&self) -> Self
    where
        Self: Sized,
    {
        (*self).exp2()
    }
}

/// Implements [`Expable`] for integer dtypes by widening to `f64`, applying
/// the float operation, and casting back with `as` (which saturates on
/// overflow for float-to-int casts in modern Rust, so huge exponents clamp
/// to the integer's max rather than wrapping).
///
/// Generalized to accept a comma-separated list of types, so
/// `exp_integral!(u8, u32)` works as well as the original single-type form
/// `exp_integral!(u8);`.
macro_rules! exp_integral {
    ($($t:ty),+ $(,)?) => {
        $(
            impl Expable for $t {
                fn exp(&self) -> Self
                where
                    Self: Sized,
                {
                    (*self as f64).exp() as $t
                }
                fn exp2(&self) -> Self
                where
                    Self: Sized,
                {
                    (*self as f64).exp2() as $t
                }
            }
        )+
    };
}

// Integer dtypes gain exp/exp2 via the f64 round-trip defined by the macro.
exp_integral!(u8);
exp_integral!(u32);
exp_integral!(i32);
exp_integral!(i64);

#[cfg(feature = "bfloat")]
impl Expable for bf16 {
    /// Computed by widening to `f64`, exponentiating, and rounding back.
    fn exp(&self) -> Self
    where
        Self: Sized,
    {
        let wide = self.to_f64_const();
        Self::from_f64_const(wide.exp())
    }
    /// Computed by widening to `f64`, exponentiating, and rounding back.
    fn exp2(&self) -> Self
    where
        Self: Sized,
    {
        let wide = self.to_f64_const();
        Self::from_f64_const(wide.exp2())
    }
}

#[cfg(feature = "half")]
impl Expable for f16 {
    /// Computed by widening to `f64`, exponentiating, and rounding back.
    fn exp(&self) -> Self
    where
        Self: Sized,
    {
        let wide = self.to_f64_const();
        Self::from_f64_const(wide.exp())
    }
    /// Computed by widening to `f64`, exponentiating, and rounding back.
    fn exp2(&self) -> Self
    where
        Self: Sized,
    {
        let wide = self.to_f64_const();
        Self::from_f64_const(wide.exp2())
    }
}

pub trait DTypeOps:
Copy
+ Add<Output = Self>
+ Div<Output = Self>
+ Sub<Output = Self>
+ Mul<Output = Self>
+ Sqrtable
+ Expable
+ SimdSupported
+ GemmDispatch
+ RandDispatch
Expand Down
6 changes: 6 additions & 0 deletions constensor-core/src/graph.rs
Original file line number Diff line number Diff line change
Expand Up @@ -711,20 +711,26 @@ impl BinaryOpType {
/// Elementwise unary operators supported by the graph IR.
pub enum UnaryOpType {
    /// Arithmetic negation, `-x`.
    Neg,
    /// Square root.
    Sqrt,
    /// Natural exponential, `e^x`.
    Exp,
    /// Base-2 exponential, `2^x`.
    Exp2,
}

impl UnaryOpType {
    /// Render this operator as a C expression applied to `val`.
    pub fn fill_in_c_op(&self, val: impl Display) -> String {
        // Negation is plain `-x`; every other op shares the
        // "widen to double, call the C function, narrow back" template,
        // differing only in the function name.
        let c_fn = match self {
            Self::Neg => return format!("-{val}"),
            Self::Sqrt => "sqrt",
            Self::Exp => "exp",
            Self::Exp2 => "exp2",
        };
        format!("static_cast<T>( {c_fn}( static_cast<double>({val}) ) )")
    }

    /// Return the scalar closure implementing this operator on the CPU path.
    pub fn to_closure<T: DType>(&self) -> impl Fn(T) -> T {
        // Every arm is a non-capturing closure, so the match arms all
        // coerce to the same `fn(T) -> T` pointer type.
        match self {
            Self::Neg => |x: T| T::maybe_neg(x),
            Self::Sqrt => |x: T| x.sqrt(),
            Self::Exp => |x: T| x.exp(),
            Self::Exp2 => |x: T| x.exp2(),
        }
    }
}
Expand Down
40 changes: 40 additions & 0 deletions constensor-core/src/tensor/graphtensor.rs
Original file line number Diff line number Diff line change
Expand Up @@ -131,6 +131,46 @@ impl<S: Shape, T: DType, D: Dev> GraphTensor<S, T, D> {
}
}

/// Elementwise unary exponential function (`e^x`), recorded as a new
/// [`UnaryOpType::Exp`] node in the graph.
#[must_use]
pub fn exp(self) -> GraphTensor<S, T, D> {
    let v_id = self.id();
    // Take the graph write lock once: allocating the node id and
    // registering the op under the same guard avoids the back-to-back
    // double lock acquisition and keeps the two steps atomic with
    // respect to any concurrent graph writers.
    let id = {
        let mut graph = self.graph.write().unwrap();
        let id = graph.next_id();
        graph.add_op::<S>(
            Op::UnaryOp {
                v_id,
                operator: UnaryOpType::Exp,
            },
            &self.strides,
            &id,
        );
        id
    };
    Self {
        id,
        graph: self.graph.clone(),
        strides: self.strides.clone(),
        _ghost: PhantomData,
    }
}

/// Elementwise unary base-2 exponential function (`2^x`), recorded as a new
/// [`UnaryOpType::Exp2`] node in the graph.
#[must_use]
pub fn exp2(self) -> GraphTensor<S, T, D> {
    let v_id = self.id();
    // Take the graph write lock once: allocating the node id and
    // registering the op under the same guard avoids the back-to-back
    // double lock acquisition and keeps the two steps atomic with
    // respect to any concurrent graph writers.
    let id = {
        let mut graph = self.graph.write().unwrap();
        let id = graph.next_id();
        graph.add_op::<S>(
            Op::UnaryOp {
                v_id,
                operator: UnaryOpType::Exp2,
            },
            &self.strides,
            &id,
        );
        id
    };
    Self {
        id,
        graph: self.graph.clone(),
        strides: self.strides.clone(),
        _ghost: PhantomData,
    }
}

#[must_use]
/// Create a tensor filled with uniform random values in [0,1).
pub fn rand(graph: &mut Graph<T>) -> Self {
Expand Down
32 changes: 32 additions & 0 deletions constensor-core/tests/ops.rs
Original file line number Diff line number Diff line change
Expand Up @@ -370,6 +370,38 @@ test_for_device_sqrt!(Cpu, cpu_tests_sqrt);
#[cfg(feature = "cuda")]
test_for_device_sqrt!(Cuda<0>, cuda_tests_sqrt);

/// Generate a test module exercising `GraphTensor::exp` / `exp2` on the
/// given device type.
macro_rules! test_for_device_exp {
    ($dev:ty, $name:ident) => {
        mod $name {
            use super::*;

            /// e^0 == 1 for every element of a 3x4 tensor.
            #[test]
            fn exp_float() {
                let mut graph = Graph::empty();
                let x = GraphTensor::<R2<3, 4>, f32, $dev>::fill(&mut graph, 0.0);
                let _res = x.exp();
                let compiled: CompiledGraph<R2<3, 4>, f32, $dev> = graph.compile().unwrap();
                let tensor = compiled.run().unwrap();
                let expected = vec![vec![1.0; 4]; 3];
                assert_eq!(tensor.data().unwrap().to_vec(), expected);
            }

            /// 2^2 == 4 for every element of a 3x4 tensor.
            #[test]
            fn exp2_float() {
                let mut graph = Graph::empty();
                let x = GraphTensor::<R2<3, 4>, f32, $dev>::fill(&mut graph, 2.0);
                let _res = x.exp2();
                let compiled: CompiledGraph<R2<3, 4>, f32, $dev> = graph.compile().unwrap();
                let tensor = compiled.run().unwrap();
                let expected = vec![vec![4.0; 4]; 3];
                assert_eq!(tensor.data().unwrap().to_vec(), expected);
            }
        }
    };
}

// Instantiate the exp/exp2 test suite per device. Use `Cuda<0>` (device
// ordinal 0) to match the other device-test instantiations in this file
// (e.g. `test_for_device_sqrt!(Cuda<0>, ...)`); the bare `Cuda` here was
// inconsistent with them.
test_for_device_exp!(Cpu, cpu_tests_exp);
#[cfg(feature = "cuda")]
test_for_device_exp!(Cuda<0>, cuda_tests_exp);

macro_rules! test_for_device_rand {
($dev:ty, $name:ident) => {
mod $name {
Expand Down
Loading