diff --git a/constensor-core/src/dtype/mod.rs b/constensor-core/src/dtype/mod.rs
index a0e8ec6..3529fc2 100644
--- a/constensor-core/src/dtype/mod.rs
+++ b/constensor-core/src/dtype/mod.rs
@@ -178,6 +178,101 @@ impl Expable for f16 {
     }
 }
 
+pub trait Loggable {
+    fn log(&self) -> Self
+    where
+        Self: Sized;
+    fn log1p(&self) -> Self
+    where
+        Self: Sized;
+}
+
+impl Loggable for f32 {
+    fn log(&self) -> Self
+    where
+        Self: Sized,
+    {
+        f32::ln(*self)
+    }
+    fn log1p(&self) -> Self
+    where
+        Self: Sized,
+    {
+        f32::ln_1p(*self)
+    }
+}
+
+impl Loggable for f64 {
+    fn log(&self) -> Self
+    where
+        Self: Sized,
+    {
+        f64::ln(*self)
+    }
+    fn log1p(&self) -> Self
+    where
+        Self: Sized,
+    {
+        f64::ln_1p(*self)
+    }
+}
+
+macro_rules! log_integral {
+    ($t:ty) => {
+        impl Loggable for $t {
+            fn log(&self) -> Self
+            where
+                Self: Sized,
+            {
+                (*self as f64).ln() as $t
+            }
+            fn log1p(&self) -> Self
+            where
+                Self: Sized,
+            {
+                (*self as f64).ln_1p() as $t
+            }
+        }
+    };
+}
+
+log_integral!(u8);
+log_integral!(u32);
+log_integral!(i32);
+log_integral!(i64);
+
+#[cfg(feature = "bfloat")]
+impl Loggable for bf16 {
+    fn log(&self) -> Self
+    where
+        Self: Sized,
+    {
+        bf16::from_f64_const(self.to_f64_const().ln())
+    }
+    fn log1p(&self) -> Self
+    where
+        Self: Sized,
+    {
+        bf16::from_f64_const(self.to_f64_const().ln_1p())
+    }
+}
+
+#[cfg(feature = "half")]
+impl Loggable for f16 {
+    fn log(&self) -> Self
+    where
+        Self: Sized,
+    {
+        f16::from_f64_const(self.to_f64_const().ln())
+    }
+    fn log1p(&self) -> Self
+    where
+        Self: Sized,
+    {
+        f16::from_f64_const(self.to_f64_const().ln_1p())
+    }
+}
+
 pub trait DTypeOps:
     Copy
     + Add<Output = Self>
@@ -186,6 +281,7 @@ pub trait DTypeOps:
     + Mul<Output = Self>
     + Sqrtable
     + Expable
+    + Loggable
     + SimdSupported
     + GemmDispatch
     + RandDispatch
diff --git a/constensor-core/src/graph.rs b/constensor-core/src/graph.rs
index 8a0b69d..ab7f4f2 100644
--- a/constensor-core/src/graph.rs
+++ b/constensor-core/src/graph.rs
@@ -713,6 +713,8 @@ pub enum UnaryOpType {
     Sqrt,
     Exp,
     Exp2,
+    Log,
+    Log1p,
 }
 
 impl UnaryOpType {
@@ -722,6 +724,8 @@ impl UnaryOpType {
             Self::Sqrt => format!("static_cast<T>( sqrt( static_cast<float>({val}) ) )"),
             Self::Exp => format!("static_cast<T>( exp( static_cast<float>({val}) ) )"),
             Self::Exp2 => format!("static_cast<T>( exp2( static_cast<float>({val}) ) )"),
+            Self::Log => format!("static_cast<T>( log( static_cast<float>({val}) ) )"),
+            Self::Log1p => format!("static_cast<T>( log1p( static_cast<float>({val}) ) )"),
         }
     }
 
@@ -731,6 +735,8 @@ impl UnaryOpType {
             Self::Sqrt => |x: T| x.sqrt(),
             Self::Exp => |x: T| x.exp(),
             Self::Exp2 => |x: T| x.exp2(),
+            Self::Log => |x: T| x.log(),
+            Self::Log1p => |x: T| x.log1p(),
         }
     }
 }
diff --git a/constensor-core/src/tensor/graphtensor.rs b/constensor-core/src/tensor/graphtensor.rs
index aba8af3..452f58d 100644
--- a/constensor-core/src/tensor/graphtensor.rs
+++ b/constensor-core/src/tensor/graphtensor.rs
@@ -171,6 +171,46 @@ impl<S: Shape, T: DType, D: Dev> GraphTensor<S, T, D> {
         }
     }
 
+    #[must_use]
+    /// Elementwise unary natural logarithm function.
+    pub fn log(self) -> GraphTensor<S, T, D> {
+        let id = self.graph.write().unwrap().next_id();
+        self.graph.write().unwrap().add_op::<T>(
+            Op::UnaryOp {
+                v_id: self.id(),
+                operator: UnaryOpType::Log,
+            },
+            &self.strides,
+            &id,
+        );
+        Self {
+            id,
+            graph: self.graph.clone(),
+            strides: self.strides.clone(),
+            _ghost: PhantomData,
+        }
+    }
+
+    #[must_use]
+    /// Elementwise unary natural logarithm of (1+x) function.
+    pub fn log1p(self) -> GraphTensor<S, T, D> {
+        let id = self.graph.write().unwrap().next_id();
+        self.graph.write().unwrap().add_op::<T>(
+            Op::UnaryOp {
+                v_id: self.id(),
+                operator: UnaryOpType::Log1p,
+            },
+            &self.strides,
+            &id,
+        );
+        Self {
+            id,
+            graph: self.graph.clone(),
+            strides: self.strides.clone(),
+            _ghost: PhantomData,
+        }
+    }
+
     #[must_use]
     /// Create a tensor filled with uniform random values in [0,1).
     pub fn rand(graph: &mut Graph<T>) -> Self {
diff --git a/constensor-core/tests/ops.rs b/constensor-core/tests/ops.rs
index 633353e..c75a87a 100644
--- a/constensor-core/tests/ops.rs
+++ b/constensor-core/tests/ops.rs
@@ -400,7 +400,53 @@ macro_rules! test_for_device_exp {
 
 test_for_device_exp!(Cpu, cpu_tests_exp);
 #[cfg(feature = "cuda")]
-test_for_device_exp!(Cuda, cuda_tests_exp);
+test_for_device_exp!(Cuda<0>, cuda_tests_exp);
+
+macro_rules! test_for_device_log {
+    ($dev:ty, $name:ident) => {
+        mod $name {
+            use super::*;
+
+            #[test]
+            fn log_float() {
+                let mut graph = Graph::empty();
+                let x = GraphTensor::<R2<3, 4>, f32, $dev>::fill(&mut graph, 1.0);
+                let _res = x.log();
+                let compiled: CompiledGraph<R2<3, 4>, f32, $dev> = graph.compile().unwrap();
+                let tensor = compiled.run().unwrap();
+                assert_eq!(tensor.data().unwrap().to_vec(), vec![vec![0.0; 4]; 3],);
+            }
+
+            #[test]
+            fn log1p_float() {
+                let mut graph = Graph::empty();
+                let x = GraphTensor::<R2<3, 4>, f32, $dev>::fill(&mut graph, 0.0);
+                let _res = x.log1p();
+                let compiled: CompiledGraph<R2<3, 4>, f32, $dev> = graph.compile().unwrap();
+                let tensor = compiled.run().unwrap();
+                assert_eq!(tensor.data().unwrap().to_vec(), vec![vec![0.0; 4]; 3],);
+            }
+
+            #[test]
+            fn log_e_float() {
+                let mut graph = Graph::empty();
+                let x = GraphTensor::<R2<3, 4>, f32, $dev>::fill(&mut graph, std::f32::consts::E);
+                let _res = x.log();
+                let compiled: CompiledGraph<R2<3, 4>, f32, $dev> = graph.compile().unwrap();
+                let tensor = compiled.run().unwrap();
+                for row in tensor.data().unwrap().iter() {
+                    for &val in row.iter() {
+                        assert!((val - 1.0).abs() < 1e-6);
+                    }
+                }
+            }
+        }
+    };
+}
+
+test_for_device_log!(Cpu, cpu_tests_log);
+#[cfg(feature = "cuda")]
+test_for_device_log!(Cuda<0>, cuda_tests_log);
 
 macro_rules! test_for_device_rand {
     ($dev:ty, $name:ident) => {