96 changes: 96 additions & 0 deletions constensor-core/src/dtype/mod.rs
@@ -178,6 +178,101 @@ impl Expable for f16 {
}
}

pub trait Loggable {
fn log(&self) -> Self
where
Self: Sized;
fn log1p(&self) -> Self
where
Self: Sized;
}

impl Loggable for f32 {
fn log(&self) -> Self
where
Self: Sized,
{
f32::ln(*self)
}
fn log1p(&self) -> Self
where
Self: Sized,
{
f32::ln_1p(*self)
}
}

impl Loggable for f64 {
fn log(&self) -> Self
where
Self: Sized,
{
f64::ln(*self)
}
fn log1p(&self) -> Self
where
Self: Sized,
{
f64::ln_1p(*self)
}
}

macro_rules! log_integral {
($t:ty) => {
impl Loggable for $t {
fn log(&self) -> Self
where
Self: Sized,
{
(*self as f64).ln() as $t
}
fn log1p(&self) -> Self
where
Self: Sized,
{
(*self as f64).ln_1p() as $t
}
}
};
}

log_integral!(u8);
log_integral!(u32);
log_integral!(i32);
log_integral!(i64);

#[cfg(feature = "bfloat")]
impl Loggable for bf16 {
fn log(&self) -> Self
where
Self: Sized,
{
bf16::from_f64_const(self.to_f64_const().ln())
}
fn log1p(&self) -> Self
where
Self: Sized,
{
bf16::from_f64_const(self.to_f64_const().ln_1p())
}
}

#[cfg(feature = "half")]
impl Loggable for f16 {
fn log(&self) -> Self
where
Self: Sized,
{
f16::from_f64_const(self.to_f64_const().ln())
}
fn log1p(&self) -> Self
where
Self: Sized,
{
f16::from_f64_const(self.to_f64_const().ln_1p())
}
}

pub trait DTypeOps:
Copy
+ Add<Output = Self>
@@ -186,6 +281,7 @@ pub trait DTypeOps:
+ Mul<Output = Self>
+ Sqrtable
+ Expable
+ Loggable
+ SimdSupported
+ GemmDispatch
+ RandDispatch
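For illustration: the integral impls above route through `f64` and cast back to the integer type, so integer `log`/`log1p` results are truncated toward zero, and `ln(0)` (negative infinity) saturates to `0` under Rust's float-to-int casts. A minimal standalone sketch, not part of the PR; the helper name is hypothetical:

```rust
// Mirrors the body generated by log_integral! for u8: go through f64, truncate back.
fn log_like_integral_u8(x: u8) -> u8 {
    (x as f64).ln() as u8
}

fn main() {
    assert_eq!(log_like_integral_u8(1), 0); // ln(1) = 0
    assert_eq!(log_like_integral_u8(3), 1); // ln(3) ≈ 1.0986, truncates to 1
    assert_eq!(log_like_integral_u8(0), 0); // ln(0) = -inf, saturating cast gives 0
}
```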
6 changes: 6 additions & 0 deletions constensor-core/src/graph.rs
@@ -713,6 +713,8 @@ pub enum UnaryOpType {
Sqrt,
Exp,
Exp2,
Log,
Log1p,
}

impl UnaryOpType {
@@ -722,6 +724,8 @@ impl UnaryOpType {
Self::Sqrt => format!("static_cast<T>( sqrt( static_cast<double>({val}) ) )"),
Self::Exp => format!("static_cast<T>( exp( static_cast<double>({val}) ) )"),
Self::Exp2 => format!("static_cast<T>( exp2( static_cast<double>({val}) ) )"),
Self::Log => format!("static_cast<T>( log( static_cast<double>({val}) ) )"),
Self::Log1p => format!("static_cast<T>( log1p( static_cast<double>({val}) ) )"),
}
}

@@ -731,6 +735,8 @@ impl UnaryOpType {
Self::Sqrt => |x: T| x.sqrt(),
Self::Exp => |x: T| x.exp(),
Self::Exp2 => |x: T| x.exp2(),
Self::Log => |x: T| x.log(),
Self::Log1p => |x: T| x.log1p(),
}
}
}
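The two new `UnaryOpType` variants only contribute a codegen string and a CPU closure. A self-contained sketch of the string side; the real method name in graph.rs is not visible in this hunk, so `to_cuda_expr` below is an assumed stand-in:

```rust
// Standalone copy of just the new match arms from the diff, to show the generated text.
enum UnaryOpType {
    Log,
    Log1p,
}

impl UnaryOpType {
    // Hypothetical name; the PR adds these arms to an existing codegen method.
    fn to_cuda_expr(&self, val: &str) -> String {
        match self {
            Self::Log => format!("static_cast<T>( log( static_cast<double>({val}) ) )"),
            Self::Log1p => format!("static_cast<T>( log1p( static_cast<double>({val}) ) )"),
        }
    }
}

fn main() {
    assert_eq!(
        UnaryOpType::Log.to_cuda_expr("x[i]"),
        "static_cast<T>( log( static_cast<double>(x[i]) ) )"
    );
    assert_eq!(
        UnaryOpType::Log1p.to_cuda_expr("x[i]"),
        "static_cast<T>( log1p( static_cast<double>(x[i]) ) )"
    );
}
```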
40 changes: 40 additions & 0 deletions constensor-core/src/tensor/graphtensor.rs
@@ -171,6 +171,46 @@ impl<S: Shape, T: DType, D: Dev> GraphTensor<S, T, D> {
}
}

#[must_use]
/// Elementwise unary natural logarithm function.
pub fn log(self) -> GraphTensor<S, T, D> {
let id = self.graph.write().unwrap().next_id();
self.graph.write().unwrap().add_op::<S>(
Op::UnaryOp {
v_id: self.id(),
operator: UnaryOpType::Log,
},
&self.strides,
&id,
);
Self {
id,
graph: self.graph.clone(),
strides: self.strides.clone(),
_ghost: PhantomData,
}
}

#[must_use]
/// Elementwise unary natural logarithm of (1+x) function.
pub fn log1p(self) -> GraphTensor<S, T, D> {
let id = self.graph.write().unwrap().next_id();
self.graph.write().unwrap().add_op::<S>(
Op::UnaryOp {
v_id: self.id(),
operator: UnaryOpType::Log1p,
},
&self.strides,
&id,
);
Self {
id,
graph: self.graph.clone(),
strides: self.strides.clone(),
_ghost: PhantomData,
}
}

#[must_use]
/// Create a tensor filled with uniform random values in [0,1).
pub fn rand(graph: &mut Graph<T>) -> Self {
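A hedged end-to-end usage sketch for the new `GraphTensor::log`, modeled on the tests below; the crate-root re-exports (`Graph`, `GraphTensor`, `CompiledGraph`, `R2`, `Cpu`) are assumed, since tests/ops.rs uses them unqualified:

```rust
use constensor_core::{CompiledGraph, Cpu, Graph, GraphTensor, R2};

fn main() {
    let mut graph = Graph::empty();
    // Fill a 2x2 f32 tensor with e, then record a Log node in the graph.
    let x = GraphTensor::<R2<2, 2>, f32, Cpu>::fill(&mut graph, std::f32::consts::E);
    let _y = x.log();
    let compiled: CompiledGraph<R2<2, 2>, f32, Cpu> = graph.compile().unwrap();
    let out = compiled.run().unwrap();
    println!("{:?}", out.data().unwrap()); // expected ~1.0 in every element
}
```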
48 changes: 47 additions & 1 deletion constensor-core/tests/ops.rs
@@ -400,7 +400,53 @@ macro_rules! test_for_device_exp {

test_for_device_exp!(Cpu, cpu_tests_exp);
#[cfg(feature = "cuda")]
test_for_device_exp!(Cuda, cuda_tests_exp);
test_for_device_exp!(Cuda<0>, cuda_tests_exp);

macro_rules! test_for_device_log {
($dev:ty, $name:ident) => {
mod $name {
use super::*;

#[test]
fn log_float() {
let mut graph = Graph::empty();
let x = GraphTensor::<R2<3, 4>, f32, $dev>::fill(&mut graph, 1.0);
let _res = x.log();
let compiled: CompiledGraph<R2<3, 4>, f32, $dev> = graph.compile().unwrap();
let tensor = compiled.run().unwrap();
assert_eq!(tensor.data().unwrap().to_vec(), vec![vec![0.0; 4]; 3],);
}

#[test]
fn log1p_float() {
let mut graph = Graph::empty();
let x = GraphTensor::<R2<3, 4>, f32, $dev>::fill(&mut graph, 0.0);
let _res = x.log1p();
let compiled: CompiledGraph<R2<3, 4>, f32, $dev> = graph.compile().unwrap();
let tensor = compiled.run().unwrap();
assert_eq!(tensor.data().unwrap().to_vec(), vec![vec![0.0; 4]; 3],);
}

#[test]
fn log_e_float() {
let mut graph = Graph::empty();
let x = GraphTensor::<R2<3, 4>, f32, $dev>::fill(&mut graph, std::f32::consts::E);
let _res = x.log();
let compiled: CompiledGraph<R2<3, 4>, f32, $dev> = graph.compile().unwrap();
let tensor = compiled.run().unwrap();
for row in tensor.data().unwrap().iter() {
for &val in row.iter() {
assert!((val - 1.0).abs() < 1e-6);
}
}
}
}
};
}

test_for_device_log!(Cpu, cpu_tests_log);
#[cfg(feature = "cuda")]
test_for_device_log!(Cuda<0>, cuda_tests_log);
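Aside on why `log1p` earns its own op rather than being composed from an add followed by `log`: for small x, computing `ln(1 + x)` in floating point loses most of the significant digits of the result, while a dedicated `ln_1p` keeps them. A standalone demonstration, independent of constensor:

```rust
fn main() {
    let x = 1e-10_f64;
    let naive = (1.0 + x).ln(); // rounding in (1.0 + x) contaminates the result
    let stable = x.ln_1p();     // accurate to full precision
    println!("naive  = {naive:e}");
    println!("stable = {stable:e}");
    assert!((stable - x).abs() < (naive - x).abs());
}
```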

macro_rules! test_for_device_rand {
($dev:ty, $name:ident) => {