From fd5e938b9d23acdd7783c98b79bb3abec66bcf36 Mon Sep 17 00:00:00 2001
From: xiaoniaoyouhuajiang <2583473505@qq.com>
Date: Sun, 27 Apr 2025 15:08:50 +0800
Subject: [PATCH] support exp, exp2 unary ops

---
 constensor-core/src/dtype/mod.rs          | 96 +++++++++++++++++++++++
 constensor-core/src/graph.rs              |  6 ++
 constensor-core/src/tensor/graphtensor.rs | 40 ++++++++++
 constensor-core/tests/ops.rs              | 32 ++++++++
 4 files changed, 174 insertions(+)

diff --git a/constensor-core/src/dtype/mod.rs b/constensor-core/src/dtype/mod.rs
index bee9dc5..a0e8ec6 100644
--- a/constensor-core/src/dtype/mod.rs
+++ b/constensor-core/src/dtype/mod.rs
@@ -83,6 +83,101 @@ sqrt_integral!(u32);
 sqrt_integral!(i32);
 sqrt_integral!(i64);
 
+pub trait Expable {
+    fn exp(&self) -> Self
+    where
+        Self: Sized;
+    fn exp2(&self) -> Self
+    where
+        Self: Sized;
+}
+
+impl Expable for f32 {
+    fn exp(&self) -> Self
+    where
+        Self: Sized,
+    {
+        f32::exp(*self)
+    }
+    fn exp2(&self) -> Self
+    where
+        Self: Sized,
+    {
+        f32::exp2(*self)
+    }
+}
+
+impl Expable for f64 {
+    fn exp(&self) -> Self
+    where
+        Self: Sized,
+    {
+        f64::exp(*self)
+    }
+    fn exp2(&self) -> Self
+    where
+        Self: Sized,
+    {
+        f64::exp2(*self)
+    }
+}
+
+macro_rules! exp_integral {
+    ($t:ty) => {
+        impl Expable for $t {
+            fn exp(&self) -> Self
+            where
+                Self: Sized,
+            {
+                (*self as f64).exp() as $t
+            }
+            fn exp2(&self) -> Self
+            where
+                Self: Sized,
+            {
+                (*self as f64).exp2() as $t
+            }
+        }
+    };
+}
+
+exp_integral!(u8);
+exp_integral!(u32);
+exp_integral!(i32);
+exp_integral!(i64);
+
+#[cfg(feature = "bfloat")]
+impl Expable for bf16 {
+    fn exp(&self) -> Self
+    where
+        Self: Sized,
+    {
+        bf16::from_f64_const(self.to_f64_const().exp())
+    }
+    fn exp2(&self) -> Self
+    where
+        Self: Sized,
+    {
+        bf16::from_f64_const(self.to_f64_const().exp2())
+    }
+}
+
+#[cfg(feature = "half")]
+impl Expable for f16 {
+    fn exp(&self) -> Self
+    where
+        Self: Sized,
+    {
+        f16::from_f64_const(self.to_f64_const().exp())
+    }
+    fn exp2(&self) -> Self
+    where
+        Self: Sized,
+    {
+        f16::from_f64_const(self.to_f64_const().exp2())
+    }
+}
+
 pub trait DTypeOps:
     Copy
     + Add<Output = Self>
@@ -90,6 +185,7 @@ pub trait DTypeOps:
     + Sub<Output = Self>
     + Mul<Output = Self>
     + Sqrtable
+    + Expable
     + SimdSupported
     + GemmDispatch
     + RandDispatch
diff --git a/constensor-core/src/graph.rs b/constensor-core/src/graph.rs
index 6a9bd68..8a0b69d 100644
--- a/constensor-core/src/graph.rs
+++ b/constensor-core/src/graph.rs
@@ -711,6 +711,8 @@ impl BinaryOpType {
 pub enum UnaryOpType {
     Neg,
     Sqrt,
+    Exp,
+    Exp2,
 }
 
 impl UnaryOpType {
@@ -718,6 +720,8 @@
         match self {
             Self::Neg => format!("-{val}"),
             Self::Sqrt => format!("static_cast<T>( sqrt( static_cast<float>({val}) ) )"),
+            Self::Exp => format!("static_cast<T>( exp( static_cast<float>({val}) ) )"),
+            Self::Exp2 => format!("static_cast<T>( exp2( static_cast<float>({val}) ) )"),
         }
     }
 
@@ -725,6 +729,8 @@ impl UnaryOpType {
         match self {
             Self::Neg => T::maybe_neg,
             Self::Sqrt => |x: T| x.sqrt(),
+            Self::Exp => |x: T| x.exp(),
+            Self::Exp2 => |x: T| x.exp2(),
         }
     }
 }
diff --git a/constensor-core/src/tensor/graphtensor.rs b/constensor-core/src/tensor/graphtensor.rs
index 006ac18..aba8af3 100644
--- a/constensor-core/src/tensor/graphtensor.rs
+++ b/constensor-core/src/tensor/graphtensor.rs
@@ -131,6 +131,46 @@ impl GraphTensor {
         }
     }
 
+    #[must_use]
+    /// Elementwise unary exponential function.
+    pub fn exp(self) -> GraphTensor<S, T, D> {
+        let id = self.graph.write().unwrap().next_id();
+        self.graph.write().unwrap().add_op::<S>(
+            Op::UnaryOp {
+                v_id: self.id(),
+                operator: UnaryOpType::Exp,
+            },
+            &self.strides,
+            &id,
+        );
+        Self {
+            id,
+            graph: self.graph.clone(),
+            strides: self.strides.clone(),
+            _ghost: PhantomData,
+        }
+    }
+
+    #[must_use]
+    /// Elementwise unary base-2 exponential function.
+    pub fn exp2(self) -> GraphTensor<S, T, D> {
+        let id = self.graph.write().unwrap().next_id();
+        self.graph.write().unwrap().add_op::<S>(
+            Op::UnaryOp {
+                v_id: self.id(),
+                operator: UnaryOpType::Exp2,
+            },
+            &self.strides,
+            &id,
+        );
+        Self {
+            id,
+            graph: self.graph.clone(),
+            strides: self.strides.clone(),
+            _ghost: PhantomData,
+        }
+    }
+
     #[must_use]
     /// Create a tensor filled with uniform random values in [0,1).
     pub fn rand(graph: &mut Graph<T>) -> Self {
diff --git a/constensor-core/tests/ops.rs b/constensor-core/tests/ops.rs
index aaf9f03..633353e 100644
--- a/constensor-core/tests/ops.rs
+++ b/constensor-core/tests/ops.rs
@@ -370,6 +370,38 @@ test_for_device_sqrt!(Cpu, cpu_tests_sqrt);
 #[cfg(feature = "cuda")]
 test_for_device_sqrt!(Cuda<0>, cuda_tests_sqrt);
 
+macro_rules! test_for_device_exp {
+    ($dev:ty, $name:ident) => {
+        mod $name {
+            use super::*;
+
+            #[test]
+            fn exp_float() {
+                let mut graph = Graph::empty();
+                let x = GraphTensor::<R2<3, 4>, f32, $dev>::fill(&mut graph, 0.0);
+                let _res = x.exp();
+                let compiled: CompiledGraph<R2<3, 4>, f32, $dev> = graph.compile().unwrap();
+                let tensor = compiled.run().unwrap();
+                assert_eq!(tensor.data().unwrap().to_vec(), vec![vec![1.0; 4]; 3],);
+            }
+
+            #[test]
+            fn exp2_float() {
+                let mut graph = Graph::empty();
+                let x = GraphTensor::<R2<3, 4>, f32, $dev>::fill(&mut graph, 2.0);
+                let _res = x.exp2();
+                let compiled: CompiledGraph<R2<3, 4>, f32, $dev> = graph.compile().unwrap();
+                let tensor = compiled.run().unwrap();
+                assert_eq!(tensor.data().unwrap().to_vec(), vec![vec![4.0; 4]; 3],);
+            }
+        }
+    };
+}
+
+test_for_device_exp!(Cpu, cpu_tests_exp);
+#[cfg(feature = "cuda")]
+test_for_device_exp!(Cuda<0>, cuda_tests_exp);
+
 macro_rules! test_for_device_rand {
     ($dev:ty, $name:ident) => {
         mod $name {
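
Usage sketch (reviewer note, not part of the diff): the snippet below mirrors the CPU exp test added in tests/ops.rs, written as a standalone program. The `use constensor_core::...` import list and the `main` wrapper are assumptions for illustration only; `Graph`, `GraphTensor`, `CompiledGraph`, `R2`, `Cpu`, and the fill/compile/run/data calls are taken directly from the test code above.

    use constensor_core::{CompiledGraph, Cpu, Graph, GraphTensor, R2};

    fn main() {
        // Record a 3x4 f32 tensor filled with 0.0, then the new elementwise exp op.
        let mut graph = Graph::empty();
        let x = GraphTensor::<R2<3, 4>, f32, Cpu>::fill(&mut graph, 0.0);
        let _res = x.exp(); // exp2() is called the same way for the base-2 variant

        // Compile the recorded graph and execute it on the CPU backend.
        let compiled: CompiledGraph<R2<3, 4>, f32, Cpu> = graph.compile().unwrap();
        let tensor = compiled.run().unwrap();

        // exp(0.0) == 1.0, so every entry of the 3x4 result is 1.0.
        assert_eq!(tensor.data().unwrap().to_vec(), vec![vec![1.0; 4]; 3]);
    }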