diff --git a/Cargo.lock b/Cargo.lock index ddd7ce2..c27ec81 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -12,89 +12,17 @@ dependencies = [ ] [[package]] -name = "anstream" -version = "0.6.13" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d96bd03f33fe50a863e394ee9718a706f988b9079b20c3784fb726e7678b62fb" -dependencies = [ - "anstyle", - "anstyle-parse", - "anstyle-query", - "anstyle-wincon", - "colorchoice", - "utf8parse", -] - -[[package]] -name = "anstyle" -version = "1.0.6" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8901269c6307e8d93993578286ac0edf7f195079ffff5ebdeea6a59ffb7e36bc" - -[[package]] -name = "anstyle-parse" -version = "0.2.3" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c75ac65da39e5fe5ab759307499ddad880d724eed2f6ce5b5e8a26f4f387928c" -dependencies = [ - "utf8parse", -] - -[[package]] -name = "anstyle-query" -version = "1.0.2" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e28923312444cdd728e4738b3f9c9cac739500909bb3d3c94b43551b16517648" -dependencies = [ - "windows-sys", -] - -[[package]] -name = "anstyle-wincon" -version = "3.0.2" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "1cd54b81ec8d6180e24654d0b371ad22fc3dd083b6ff8ba325b72e00c87660a7" -dependencies = [ - "anstyle", - "windows-sys", -] - -[[package]] -name = "colorchoice" +name = "cfg-if" version = "1.0.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "acbf1af155f9b9ef647e42cdc158db4b64a1b61f743629225fde6f3e0be2a7c7" - -[[package]] -name = "env_filter" -version = "0.1.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a009aa4810eb158359dda09d0c87378e4bbb89b5a801f016885a4707ba24f7ea" -dependencies = [ - "log", - "regex", -] - -[[package]] -name = "env_logger" -version = "0.11.3" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "38b35839ba51819680ba087cd351788c9a3c476841207e0b8cee0b04722343b9" -dependencies = [ - "anstream", - "anstyle", - "env_filter", - "humantime", - "log", -] +checksum = "baf1de4339761588bc0619e3cbc0120ee582ebb74b53b4efbf79117bd2da40fd" [[package]] name = "extensor" version = "0.1.1" dependencies = [ - "env_logger", "extensor-macros", - "log", + "rstest", ] [[package]] @@ -107,46 +35,40 @@ dependencies = [ ] [[package]] -name = "humantime" -version = "2.1.0" +name = "glob" +version = "0.3.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "9a3a5bfb195931eeb336b2a7b4d761daec841b97f947d34394601737a7bba5e4" - -[[package]] -name = "log" -version = "0.4.21" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "90ed8c1e510134f979dbc4f070f87d4313098b704861a105fe34231c70a3901c" +checksum = "a8d1add55171497b4705a648c6b583acafb01d58050a51727785f0b2c8e0a2b2" [[package]] name = "memchr" -version = "2.7.2" +version = "2.7.4" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "6c8640c5d730cb13ebd907d8d04b52f55ac9a2eec55b440c8892f40d56c76c1d" +checksum = "78ca9ab1a0babb1e7d5695e3530886289c18cf2f87ec19a575a0abdce112e3a3" [[package]] name = "proc-macro2" -version = "1.0.81" +version = "1.0.93" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "3d1597b0c024618f09a9c3b8655b7e430397a36d23fdafec26d6965e9eec3eba" +checksum = "60946a68e5f9d28b0dc1c21bb8a97ee7d018a8b322fa57838ba31cc878e22d99" dependencies = [ "unicode-ident", ] [[package]] name = "quote" -version = "1.0.36" +version = "1.0.38" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "0fa76aaf39101c457836aec0ce2316dbdc3ab723cdda1c6bd4e6ad4208acaca7" +checksum = "0e4dccaaaf89514f546c693ddc140f729f958c247918a13380cccc6078391acc" dependencies = [ "proc-macro2", ] [[package]] name = "regex" -version = "1.10.4" +version = "1.11.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c117dbdfde9c8308975b6a18d71f3f385c89461f7b3fb054288ecf2a2058ba4c" +checksum = "b544ef1b4eac5dc2db33ea63606ae9ffcfac26c1416a2806ae0bf5f56b201191" dependencies = [ "aho-corasick", "memchr", @@ -156,9 +78,9 @@ dependencies = [ [[package]] name = "regex-automata" -version = "0.4.6" +version = "0.4.9" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "86b83b8b9847f9bf95ef68afb0b8e6cdb80f498442f5179a29fad448fcc1eaea" +checksum = "809e8dc61f6de73b46c85f4c96486310fe304c434cfa43669d7b40f711150908" dependencies = [ "aho-corasick", "memchr", @@ -167,102 +89,71 @@ dependencies = [ [[package]] name = "regex-syntax" -version = "0.8.3" +version = "0.8.5" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "adad44e29e4c806119491a7f06f03de4d1af22c3a680dd47f1e6e179439d1f56" +checksum = "2b15c43186be67a4fd63bee50d0303afffcef381492ebe2c5d87f324e1b8815c" [[package]] -name = "syn" -version = "2.0.60" +name = "relative-path" +version = "1.9.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "909518bc7b1c9b779f1bbf07f2929d35af9f0f37e47c6e9ef7f9dddc1e1821f3" -dependencies = [ - "proc-macro2", - "quote", - "unicode-ident", -] +checksum = "ba39f3699c378cd8970968dcbff9c43159ea4cfbd88d43c00b22f2ef10a435d2" [[package]] -name = "unicode-ident" -version = "1.0.12" +name = "rstest" +version = "0.24.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "3354b9ac3fae1ff6755cb6db53683adb661634f67557942dea4facebec0fee4b" - -[[package]] -name = "utf8parse" -version = "0.2.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "711b9620af191e0cdc7468a8d14e709c3dcdb115b36f838e601583af800a370a" - -[[package]] -name = "windows-sys" -version = "0.52.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "282be5f36a8ce781fad8c8ae18fa3f9beff57ec1b52cb3de0789201425d9a33d" +checksum = "03e905296805ab93e13c1ec3a03f4b6c4f35e9498a3d5fa96dc626d22c03cd89" dependencies = [ - "windows-targets", + "rstest_macros", + "rustc_version", ] [[package]] -name = "windows-targets" -version = "0.52.5" +name = "rstest_macros" +version = "0.24.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "6f0713a46559409d202e70e28227288446bf7841d3211583a4b53e3f6d96e7eb" +checksum = "ef0053bbffce09062bee4bcc499b0fbe7a57b879f1efe088d6d8d4c7adcdef9b" dependencies = [ - "windows_aarch64_gnullvm", - "windows_aarch64_msvc", - "windows_i686_gnu", - "windows_i686_gnullvm", - "windows_i686_msvc", - "windows_x86_64_gnu", - "windows_x86_64_gnullvm", - "windows_x86_64_msvc", + "cfg-if", + "glob", + "proc-macro2", + "quote", + "regex", + "relative-path", + "rustc_version", + "syn", + "unicode-ident", ] [[package]] -name = "windows_aarch64_gnullvm" -version = "0.52.5" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7088eed71e8b8dda258ecc8bac5fb1153c5cffaf2578fc8ff5d61e23578d3263" - -[[package]] -name = "windows_aarch64_msvc" -version = "0.52.5" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "9985fd1504e250c615ca5f281c3f7a6da76213ebd5ccc9561496568a2752afb6" - -[[package]] -name = "windows_i686_gnu" -version = "0.52.5" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "88ba073cf16d5372720ec942a8ccbf61626074c6d4dd2e745299726ce8b89670" - -[[package]] -name = "windows_i686_gnullvm" -version = "0.52.5" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "87f4261229030a858f36b459e748ae97545d6f1ec60e5e0d6a3d32e0dc232ee9" - -[[package]] -name = "windows_i686_msvc" -version = "0.52.5" +name = "rustc_version" +version = "0.4.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "db3c2bf3d13d5b658be73463284eaf12830ac9a26a90c717b7f771dfe97487bf" +checksum = "cfcb3a22ef46e85b45de6ee7e79d063319ebb6594faafcf1c225ea92ab6e9b92" +dependencies = [ + "semver", +] [[package]] -name = "windows_x86_64_gnu" -version = "0.52.5" +name = "semver" +version = "1.0.25" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "4e4246f76bdeff09eb48875a0fd3e2af6aada79d409d33011886d3e1581517d9" +checksum = "f79dfe2d285b0488816f30e700a7438c5a73d816b5b7d3ac72fbc48b0d185e03" [[package]] -name = "windows_x86_64_gnullvm" -version = "0.52.5" +name = "syn" +version = "2.0.96" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "852298e482cd67c356ddd9570386e2862b5673c85bd5f88df9ab6802b334c596" +checksum = "d5d0adab1ae378d7f53bdebc67a39f1f151407ef230f0ce2883572f5d8985c80" +dependencies = [ + "proc-macro2", + "quote", + "unicode-ident", +] [[package]] -name = "windows_x86_64_msvc" -version = "0.52.5" +name = "unicode-ident" +version = "1.0.16" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "bec47e5bfd1bff0eeaf6d8b485cc1074891a197ab4225d504cb7a1ab88b02bf0" +checksum = "a210d160f08b701c8721ba1c726c11662f877ea6b7094007e1ca9a1041945034" diff --git a/Cargo.toml b/Cargo.toml index 2bdee11..f89cbb9 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -12,7 +12,5 @@ rust-version = "1.80.0" [dependencies] extensor-macros = { path = "macros/", version = "0.1.0" } - [dev-dependencies] -log = "0.4.21" -env_logger = "0.11.3" +rstest = { version = "0.24", default-features = false } diff --git a/macros/src/lib.rs b/macros/src/lib.rs index 8061aaf..d0504af 100644 --- a/macros/src/lib.rs +++ b/macros/src/lib.rs @@ -37,7 +37,7 @@ pub fn multilinear_map_derive(input: TokenStream) -> TokenStream { // // let param_name = Ident::new(&format!("v_{}", i), ident.span()); let param_name = Ident::new(&format!("v_{}", i), Span::call_site()); - quote! { #param_name: V<#ident, F> } + quote! { #param_name: Vector<#ident, F> } }); let loop_indices: Vec<_> = (0..const_generics.len()) @@ -86,3 +86,133 @@ pub fn multilinear_map_derive(input: TokenStream) -> TokenStream { TokenStream::from(expanded) } + +#[proc_macro] +pub fn tensor(input: TokenStream) -> TokenStream { + let n: usize = input.to_string().parse().expect("Expected a usize"); + + // Generate const generic parameters (N0, N1, ..., N{n-1}) + let const_params = (0..n).map(|i| { + let ident = Ident::new(&format!("N{}", i), proc_macro2::Span::call_site()); + quote! { const #ident: usize, } + }); + + // Collect const_params into a Vec to reuse it + let const_params_vec: Vec<_> = const_params.collect(); + + // Generate const generic parameters (N0, N1, ..., N{n-1}) + let constants = (0..n).map(|i| { + let ident = Ident::new(&format!("N{}", i), proc_macro2::Span::call_site()); + quote! { #ident, } + }); + + // Collect const_params into a Vec to reuse it + let constants_vec: Vec<_> = constants.collect(); + + // Generate the coefficients type (Vector>) + let coefficients_type = (0..n).rev().fold(quote! { F }, |acc, i| { + let ident = Ident::new(&format!("N{}", i), proc_macro2::Span::call_site()); + quote! { Vector<#ident, #acc> } + }); + + // Add multilinear_map implementation + let input_params = (0..n).map(|i| { + let param_name = Ident::new(&format!("v_{}", i), Span::call_site()); + let dim_name = Ident::new(&format!("N{}", i), Span::call_site()); + // Build the nested type for each parameter + let param_type = (i + 1..n).rev().fold(quote! { F }, |acc, j| { + let next_dim = Ident::new(&format!("N{}", j), Span::call_site()); + quote! { Vector<#next_dim, #acc> } + }); + quote! { #param_name: &Vector<#dim_name, #param_type> } + }); + + let loop_indices: Vec<_> = (0..n) + .map(|i| Ident::new(&format!("i_{}", i), Span::call_site())) + .collect(); + + // Build nested scalar products + let mut inner_computation = quote! { self.coefficients }; + for (i, _index) in loop_indices.iter().enumerate() { + let v_name = Ident::new(&format!("v_{}", i), Span::call_site()); + inner_computation = quote! { + #inner_computation.scalar_product(#v_name) + }; + } + + // Generate the struct definition and implementations + let expanded = quote! { + pub struct Tensor<#(#const_params_vec)* F> + where + F: ScalarProduct + Copy + Default, + F::Inner: Add + Default + Copy, + { + pub coefficients: #coefficients_type, + } + + impl<#(#const_params_vec)* F> Default for Tensor<#(#constants_vec)* F> + where + F: ScalarProduct + Copy + Default, + F::Inner: Add + Default + Copy, + { + fn default() -> Self { + Self { + coefficients: <#coefficients_type>::default(), + } + } + } + + impl<#(#const_params_vec)* F> core::fmt::Debug for Tensor<#(#constants_vec)* F> + where + F: ScalarProduct + Copy + Default + Debug, + F::Inner: Add + Default + Copy, + { + fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result { + f.debug_struct("Tensor") + .field("coefficients", &self.coefficients) + .finish() + } + } + + impl<#(#const_params_vec)* F> core::ops::Add for Tensor<#(#constants_vec)* F> + where + F: ScalarProduct + Default + Copy + Add, + F::Inner: Add + Default + Copy, + { + type Output = Self; + + fn add(self, other: Self) -> Self::Output { + Self { + coefficients: self.coefficients + other.coefficients, + } + } + } + + impl<#(#const_params_vec)* F> core::ops::Mul for Tensor<#(#constants_vec)* F> + where + F: ScalarProduct + core::ops::Mul + Copy + Default + AddAssign, + F::Inner: Add + Default + Copy, + { + type Output = Self; + + fn mul(self, scalar: F) -> Self::Output { + Self { + coefficients: self.coefficients * scalar, + } + } + } + + // impl<#(#const_params_vec)* F> Tensor<#(#constants_vec)* F> + // where + // F: ScalarProduct + Default + Copy + AddAssign + Mul + Add, + // F::Inner: Add + Default + Copy, + // { + // pub fn multilinear_map(&self, #(#input_params),*) -> F { + // #inner_computation + // } + // } + }; + + expanded.into() +} diff --git a/src/coproduct.rs b/src/coproduct.rs deleted file mode 100644 index 8ceaa5a..0000000 --- a/src/coproduct.rs +++ /dev/null @@ -1,93 +0,0 @@ -use super::*; - -pub trait Coproduct -where - Self: Sized, -{ - type X; - type Y; - - fn construct(x: Option, y: Option) -> Self; - - #[allow(non_snake_case)] - fn iota_X(x: Option) -> Self { - Self::construct(x, None) - } - - #[allow(non_snake_case)] - fn iota_Y(y: Option) -> Self { - Self::construct(None, y) - } - - #[allow(non_snake_case)] - fn get_X_via_tag(&self) -> Option; - - #[allow(non_snake_case)] - fn get_Y_via_tag(&self) -> Option; - - #[allow(non_snake_case)] - fn f>( - &self, - f_X: impl Fn(Option) -> Z, - f_Y: impl Fn(Option) -> Z, - ) -> Z { - f_X(self.get_X_via_tag()) + f_Y(self.get_Y_via_tag()) - } -} - -pub struct DirectSum { - v: Option>, - w: Option>, -} - -impl Coproduct for DirectSum -where - F: Copy, -{ - type X = V; - type Y = V; - - fn construct(v: Option, w: Option) -> Self { - assert!(v.is_some() || w.is_some()); - DirectSum { v, w } - } - - fn get_X_via_tag(&self) -> Option { - self.v - } - - fn get_Y_via_tag(&self) -> Option { - self.w - } -} - -impl Add for DirectSum -where - F: Add + Default + Copy, -{ - type Output = Self; - fn add(self, other: DirectSum) -> Self::Output { - DirectSum::construct( - self.v - .zip(other.v) - .map(|(v, other_v)| v + other_v) - .or(self.v) - .or(other.v), - self.w - .zip(other.w) - .map(|(w, other_w)| w + other_w) - .or(self.w) - .or(other.w), - ) - } -} - -impl Mul for DirectSum -where - F: Mul + Default + Copy, -{ - type Output = Self; - fn mul(self, scalar: F) -> Self::Output { - DirectSum::construct(self.v.map(|v| v * scalar), self.w.map(|w| w * scalar)) - } -} diff --git a/src/lib.rs b/src/lib.rs index ea73703..852f213 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -1,72 +1,16 @@ -#![allow(unstable_features)] -#![allow(incomplete_features)] -#![feature(generic_const_exprs)] #![no_std] +#![feature(generic_const_exprs)] use core::{ - fmt::{Debug, Formatter, Result}, + fmt::Debug, ops::{Add, Mul}, }; -use coproduct::{Coproduct, DirectSum}; -use product::{DirectProduct, ProductType}; - -pub mod coproduct; -pub mod product; +pub mod module; pub mod tensor; -pub mod unique_coproduct; - -#[derive(Copy, Clone, Debug)] -pub struct V([F; M]); - -impl Default for V -where - F: Default + Copy, -{ - fn default() -> Self { - V([F::default(); M]) - } -} - -impl + Default + Copy> Add for V { - type Output = Self; - fn add(self, other: V) -> Self::Output { - let mut sum = V::default(); - for i in 0..M { - sum.0[i] = self.0[i] + other.0[i]; - } - sum - } -} - -impl + Default + Copy> Mul for V { - type Output = Self; - fn mul(self, scalar: F) -> Self::Output { - let mut scalar_multiple = V::default(); - for i in 0..M { - scalar_multiple.0[i] = scalar * self.0[i]; - } - scalar_multiple - } -} -impl From> for DirectProduct -where - F: Add + Default + Copy, -{ - fn from(sum: DirectSum) -> DirectProduct { - DirectProduct::construct( - sum.get_X_via_tag().unwrap_or_default(), - sum.get_Y_via_tag().unwrap_or_default(), - ) - } -} +pub use extensor_macros::{tensor, MultilinearMap}; -impl From> for DirectSum -where - F: Add + Default + Copy, -{ - fn from(prod: DirectProduct) -> DirectSum { - DirectSum::iota_X(Some(prod.pi_X())) + DirectSum::iota_Y(Some(prod.pi_Y())) - } -} +#[cfg(test)] +#[macro_use] +extern crate std; diff --git a/src/module.rs b/src/module.rs new file mode 100644 index 0000000..2438a3f --- /dev/null +++ b/src/module.rs @@ -0,0 +1,376 @@ +use core::{ + mem::MaybeUninit, + ops::{Div, Neg}, +}; + +use super::*; + +#[derive(Clone, Copy, Debug, Default, PartialEq, Eq)] +pub struct Scalar(pub F); + +impl> Add for Scalar { + type Output = Self; + + fn add(self, rhs: Self) -> Self::Output { + Self(self.0 + rhs.0) + } +} + +impl> Mul for Scalar { + type Output = Self; + + fn mul(self, rhs: Self) -> Self::Output { + Self(self.0 * rhs.0) + } +} + +impl> Neg for Scalar { + type Output = Self; + + fn neg(self) -> Self::Output { + Self(-self.0) + } +} + +impl> Div for Scalar { + type Output = Self; + + fn div(self, rhs: Self) -> Self::Output { + Self(self.0 / rhs.0) + } +} + +pub trait Module: + Add + Neg + Mul + Copy +{ + type Ring: Add + Neg + Mul + Default + Copy; +} + +pub trait VectorSpace: Module +where + Self::Ring: Div, +{ +} + +#[derive(Copy, Clone)] +pub struct Vector(pub MaybeUninit<[F; M]>); + +impl Vector { + pub const fn new(arr: [F; M]) -> Self { + if M == 0 { + Self(MaybeUninit::uninit()) + } else { + Self(MaybeUninit::new(arr)) + } + } +} + +pub trait ScalarProduct { + type Inner; + fn scalar_product(&self, rhs: &Self) -> Self::Inner; +} + +impl ScalarProduct for Vector +where + T: ScalarProduct + Copy + Default, + T::Inner: Add + Default, +{ + type Inner = T::Inner; + + fn scalar_product(&self, rhs: &Self) -> Self::Inner { + if M == 0 { + T::Inner::default() + } else { + let mut scalar_product = T::Inner::default(); + let self_arr = unsafe { self.0.assume_init_ref() }; + let rhs_arr = unsafe { rhs.0.assume_init_ref() }; + + for i in 0..M { + scalar_product = scalar_product + self_arr[i].scalar_product(&rhs_arr[i]); + } + scalar_product + } + } +} + +// Base case - for scalar values +impl> ScalarProduct for Scalar { + type Inner = Self; + + fn scalar_product(&self, rhs: &Self) -> Self::Inner { + *self * *rhs // Use the Mul implementation for Scalar + } +} + +impl Debug for Vector { + fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result { + if M == 0 { + write!(f, "Vector([])") + } else { + let arr = unsafe { self.0.assume_init_ref() }; + write!(f, "Vector({arr:?})") + } + } +} + +// TODO: This could be const. +impl Default for Vector +where + F: Default + Copy, +{ + fn default() -> Self { + if M == 0 { + Self(MaybeUninit::uninit()) + } else { + Self(MaybeUninit::new([F::default(); M])) + } + } +} + +// TODO: Is it possible to have a compile-time known branch here so we don't +// check the const conditional `M==0` at runtime? +impl Add for Vector +where + F: Add + Default + Copy, +{ + type Output = Self; + fn add(self, other: Self) -> Self::Output { + if M == 0 { + Self(MaybeUninit::uninit()) + } else { + let mut sum = Self::default(); + let self_arr = unsafe { self.0.assume_init_ref() }; + let other_arr = unsafe { other.0.assume_init_ref() }; + let sum_arr = unsafe { sum.0.assume_init_mut() }; + + for i in 0..M { + sum_arr[i] = self_arr[i] + other_arr[i]; + } + sum + } + } +} + +// TODO: Is it possible to have a compile-time known branch here so we don't +// check the const conditional `M==0` at runtime? +impl Neg for Vector +where + F: Neg + Default + Copy, +{ + type Output = Self; + fn neg(self) -> Self::Output { + if M == 0 { + Self(MaybeUninit::uninit()) + } else { + let mut neg = Self::default(); + let self_arr = unsafe { self.0.assume_init_ref() }; + let neg_arr = unsafe { neg.0.assume_init_mut() }; + + for i in 0..M { + neg_arr[i] = -self_arr[i]; + } + neg + } + } +} + +impl Mul for Vector +where + F: Mul + Copy + Default, + Inner: Mul + Default + Copy, +{ + type Output = Self; + + fn mul(self, scalar: F) -> Self::Output { + if M == 0 { + Self(MaybeUninit::uninit()) + } else { + let mut scalar_multiple = Self::default(); + let self_arr = unsafe { self.0.assume_init_ref() }; + let scalar_multiple_arr = unsafe { scalar_multiple.0.assume_init_mut() }; + + for i in 0..M { + scalar_multiple_arr[i] = self_arr[i] * scalar; + } + scalar_multiple + } + } +} + +impl Module for Vector +where + F: Add + Neg + Mul + Default + Copy, +{ + type Ring = F; +} + +impl VectorSpace for Vector where + F: Add + Neg + Mul + Div + Default + Copy +{ +} + +#[cfg(test)] +mod tests { + use rstest::rstest; + + use super::*; + + #[test] + fn test_zero_dimensional_operations() { + let nil = Vector::<0, Scalar>::default(); + let nil2 = Vector::<0, Scalar>::default(); + + // Test all operations + let sum = nil + nil2; + let neg = -nil; + let scaled = nil * Scalar(2.0); + + assert_eq!(format!("{:?}", sum), "Vector([])"); + assert_eq!(format!("{:?}", neg), "Vector([])"); + assert_eq!(format!("{:?}", scaled), "Vector([])"); + } + + #[rstest] + #[case::dim_1([Scalar(1.0)], [Scalar(2.0)], [Scalar(3.0)])] + #[case::dim_2([Scalar(1.0), Scalar(2.0)], [Scalar(3.0), Scalar(4.0)], [Scalar(4.0), Scalar(6.0)])] + #[case::dim_3([Scalar(1.0), Scalar(2.0), Scalar(3.0)], [Scalar(4.0), Scalar(5.0), Scalar(6.0)], [Scalar(5.0), Scalar(7.0), Scalar(9.0)])] + fn test_vector_addition( + #[case] a: [Scalar; M], + #[case] b: [Scalar; M], + #[case] expected: [Scalar; M], + ) where + [(); M - 1]:, + { + let va = Vector::new(a); + let vb = Vector::new(b); + let sum = va + vb; + assert_eq!(unsafe { sum.0.assume_init_ref() }, &expected); + } + + #[rstest] + #[case::dim_1([Scalar(1.0)], [Scalar(-1.0)])] + #[case::dim_2([Scalar(1.0), Scalar(2.0)], [Scalar(-1.0), Scalar(-2.0)])] + #[case::dim_3([Scalar(1.0), Scalar(2.0), Scalar(3.0)], [Scalar(-1.0), Scalar(-2.0), Scalar(-3.0)])] + fn test_vector_negation( + #[case] input: [Scalar; M], + #[case] expected: [Scalar; M], + ) where + [(); M - 1]:, + { + let v = Vector::new(input); + let neg = -v; + assert_eq!(unsafe { neg.0.assume_init_ref() }, &expected); + } + + #[rstest] + #[case::dim_1([Scalar(1.0)], Scalar(2.0), [Scalar(2.0)])] + #[case::dim_2([Scalar(1.0), Scalar(2.0)], Scalar(3.0), [Scalar(3.0), Scalar(6.0)])] + #[case::dim_3([Scalar(1.0), Scalar(2.0), Scalar(3.0)], Scalar(2.0), [Scalar(2.0), Scalar(4.0), Scalar(6.0)])] + fn test_vector_scalar_multiplication( + #[case] input: [Scalar; M], + #[case] scalar: Scalar, + #[case] expected: [Scalar; M], + ) where + [(); M - 1]:, + { + let v = Vector::new(input); + let scaled = v * scalar; + assert_eq!(unsafe { scaled.0.assume_init_ref() }, &expected); + } + + #[test] + fn test_vector_default() { + let v0: Vector<0, Scalar> = Vector::default(); + let v1: Vector<1, Scalar> = Vector::default(); + let v3: Vector<3, Scalar> = Vector::default(); + + assert_eq!(format!("{:?}", v0), "Vector([])"); + assert_eq!(unsafe { v1.0.assume_init_ref() }, &[Scalar(0.0)]); + assert_eq!( + unsafe { v3.0.assume_init_ref() }, + &[Scalar(0.0), Scalar(0.0), Scalar(0.0)] + ); + } + + #[test] + fn test_debug_formatting() { + let v0 = Vector::<0, Scalar>::default(); + let v1 = Vector::new([Scalar(1.0)]); + let v2 = Vector::new([Scalar(1.0), Scalar(2.0)]); + + assert_eq!(format!("{:?}", v0), "Vector([])"); + assert_eq!(format!("{:?}", v1), "Vector([Scalar(1.0)])"); + assert_eq!(format!("{:?}", v2), "Vector([Scalar(1.0), Scalar(2.0)])"); + } + + #[test] + fn test_vector_traits() { + fn assert_module() {} + fn assert_vector_space() + where + T::Ring: Div, + { + } + + // Test that Vector implements Module and VectorSpace for f64 + assert_module::>>(); + assert_module::>>(); + assert_vector_space::>>(); + assert_vector_space::>>(); + } + + #[test] + fn test_copy_clone() { + let v0 = Vector::<0, Scalar>::default(); + let v1 = Vector::new([Scalar(1.0)]); + + let v0_copied = v0; + let v1_copied = v1; + + // Test that copies work correctly + assert_eq!(format!("{:?}", v0), format!("{:?}", v0_copied)); + assert_eq!(unsafe { v1.0.assume_init_ref() }, unsafe { + v1_copied.0.assume_init_ref() + }); + } + + #[test] + fn test_matrix_scalar_product() { + // Create two 2x2 matrices as Vector<2, Vector<2, Scalar>> + let m1 = Vector::new([ + Vector::new([Scalar(1.0), Scalar(2.0)]), + Vector::new([Scalar(3.0), Scalar(4.0)]), + ]); + + let m2 = Vector::new([ + Vector::new([Scalar(5.0), Scalar(6.0)]), + Vector::new([Scalar(7.0), Scalar(8.0)]), + ]); + + // The scalar product should be: + // (1*5 + 2*6) + (3*7 + 4*8) = (5 + 12) + (21 + 32) = 17 + 53 = 70 + let result: Scalar = m1.scalar_product(&m2); + assert_eq!(result, Scalar(70.0)); + } + + #[test] + fn test_nested_vector_operations() { + // Test creation and scalar product of vectors of different sizes + let v1 = Vector::new([ + Vector::new([Scalar(1.0), Scalar(2.0), Scalar(3.0)]), + Vector::new([Scalar(4.0), Scalar(5.0), Scalar(6.0)]), + ]); + + let v2 = Vector::new([ + Vector::new([Scalar(7.0), Scalar(8.0), Scalar(9.0)]), + Vector::new([Scalar(10.0), Scalar(11.0), Scalar(12.0)]), + ]); + + // (1*7 + 2*8 + 3*9) + (4*10 + 5*11 + 6*12) + // = (7 + 16 + 27) + (40 + 55 + 72) + // = 50 + 167 + // = 217 + let result = v1.scalar_product(&v2); + assert_eq!(result, Scalar(217.0)); + } +} diff --git a/src/product.rs b/src/product.rs deleted file mode 100644 index f6e734a..0000000 --- a/src/product.rs +++ /dev/null @@ -1,68 +0,0 @@ -use super::*; - -pub trait ProductType -where - Self: Sized, -{ - type X; - type Y; - - fn construct(x: Self::X, y: Self::Y) -> Self; - - #[allow(non_snake_case)] - fn pi_X(&self) -> Self::X; - - #[allow(non_snake_case)] - fn pi_Y(&self) -> Self::Y; - - #[allow(non_snake_case)] - fn f(z: &Z, f_X: impl Fn(&Z) -> Self::X, f_Y: impl Fn(&Z) -> Self::Y) -> Self { - Self::construct(f_X(z), f_Y(z)) - } -} - -#[derive(Copy, Clone)] -pub struct DirectProduct { - v: V, - w: V, -} - -impl ProductType for DirectProduct -where - F: Copy, -{ - type X = V; - type Y = V; - - fn construct(v: Self::X, w: Self::Y) -> Self { - DirectProduct { v, w } - } - - fn pi_X(&self) -> Self::X { - self.v - } - - fn pi_Y(&self) -> Self::Y { - self.w - } -} - -impl Add for DirectProduct -where - F: Add + Default + Copy, -{ - type Output = Self; - fn add(self, other: DirectProduct) -> Self::Output { - DirectProduct::construct(self.pi_X() + other.pi_X(), self.pi_Y() + other.pi_Y()) - } -} - -impl Mul for DirectProduct -where - F: Mul + Default + Copy, -{ - type Output = Self; - fn mul(self, scalar: F) -> Self::Output { - DirectProduct::construct(self.pi_X() * scalar, self.pi_Y() * scalar) - } -} diff --git a/src/tensor/macros.rs b/src/tensor/macros.rs index ed2ea32..40ed70c 100644 --- a/src/tensor/macros.rs +++ b/src/tensor/macros.rs @@ -1,205 +1,231 @@ +use module::{ScalarProduct, Vector}; + use super::*; -// TODO: Could probably just assign a valence to the tensors and use N0, N1, N2, -// etc. as dims - -#[macro_export] -macro_rules! tensor { - ($name:ident, $($consts:ident),+) => { - #[derive(extensor_macros::MultilinearMap)] - pub struct $name<$(const $consts: usize),+, F> - where F: Default + Copy + AddAssign + Mul, - { - pub coefficients: coeff_builder!($($consts),+; F), - } - - impl<$(const $consts: usize),+, F: Default + Copy + AddAssign + Mul> Default for $name<$($consts),+, F> { - fn default() -> Self { - let coefficients = ::default(); - $name { coefficients } - } - - } - - impl<$(const $consts: usize),+, F> Debug for $name<$($consts),+, F> - where - F: Default + Copy + Debug + AddAssign + Mul, - { - fn fmt(&self, f: &mut Formatter<'_>) -> Result { - f.debug_struct(stringify!($name)) - .field("coefficients", &self.coefficients) - .finish() - } - } - - impl<$(const $consts: usize),+, F> Add for $name<$($consts),+, F> - where - F: Add + Copy + Default + AddAssign + Mul, - { - type Output = Self; - - fn add(self, other: Self) -> Self::Output { - let mut result = Self::default(); - add_tensors!(result.coefficients, self.coefficients, other.coefficients; $($consts),+); - result - } - } - - impl<$(const $consts: usize),+, F> Mul for $name<$($consts),+, F> - where - F: Mul + Copy + Default + AddAssign, - { - type Output = Self; - - fn mul(self, scalar: F) -> Self::Output { - let mut result = Self::default(); - scalar_mul_tensor!(result.coefficients, self.coefficients, scalar; $($consts),+); - result - } - } - } -} - -macro_rules! coeff_builder { - ($const:ident; $expr:ty) => { - V<$const, $expr> - }; - ($const:ident, $($rest:ident),+; $expr:ty) => { - V<$const, coeff_builder!($($rest),+; $expr)> - }; -} - -macro_rules! def_builder { - ($const:ident; $expr:ty) => { - V::<$const, $expr> - }; - ($const:ident, $($rest:ident),+; $expr:ty) => { - V::<$const, def_builder!($($rest),+; $expr)> - }; -} - -macro_rules! add_tensors { - ($result:expr, $self:expr, $other:expr; $const:ident) => { - for i in 0..$const { - $result.0[i] = $self.0[i] + $other.0[i]; - } - }; - ($result:expr, $self:expr, $other:expr; $const:ident, $($rest:ident),+) => { - for i in 0..$const { - add_tensors!($result.0[i], $self.0[i], $other.0[i]; $($rest),+); - } - }; -} - -macro_rules! scalar_mul_tensor { - ($result:expr, $self:expr, $scalar:expr; $const:ident) => { - for i in 0..$const { - $result.0[i] = $self.0[i] * $scalar; - } - }; - ($result:expr, $self:expr, $scalar:expr; $const:ident, $($rest:ident),+) => { - for i in 0..$const { - scalar_mul_tensor!($result.0[i], $self.0[i], $scalar; $($rest),+); - } - }; -} - -tensor!(TensorTester, M, N, P); - -#[cfg(test)] -mod tests { - - use super::*; - tensor!(Tensor2, M, N); - - tensor!(Tensor3, M, N, P); - - use log::{debug, info}; - - fn log() { - env_logger::Builder::from_env(env_logger::Env::default().default_filter_or("trace")).init(); - } - - #[test] - fn create_arbitrary_tensor() { - // log(); - let tensor = Tensor2::<2, 3, f64>::default(); - debug!("{:?}", tensor.coefficients); - - let tensor = Tensor3::<2, 3, 4, f64>::default(); - debug!("{:?}", tensor.coefficients); - } - - #[test] - fn add_tensors() { - // log(); - let mut tensor1 = Tensor2::<2, 3, f64>::default(); - for i in 0..2 { - for j in 0..3 { - tensor1.coefficients.0[i].0[j] = (i + j) as f64; - } - } - debug!("tensor1: {:?}", tensor1.coefficients); - let mut tensor2 = Tensor2::<2, 3, f64>::default(); - for i in 0..2 { - for j in 0..3 { - tensor2.coefficients.0[i].0[j] = i as f64 - j as f64; - } - } - debug!("tensor2: {:?}", tensor2.coefficients); - let tensor3 = tensor1 + tensor2; - info!("output: {:?}", tensor3.coefficients); - } - - #[test] - fn scalar_mul_tensor() { - // log(); - let mut tensor1 = Tensor2::<2, 3, f64>::default(); - for i in 0..2 { - for j in 0..3 { - tensor1.coefficients.0[i].0[j] = (i + j) as f64; - } - } - debug!("tensor1: {:?}", tensor1.coefficients); - let scalar = 2.0; - let tensor2 = tensor1 * scalar; - info!("output: {:?}", tensor2.coefficients); - } - - #[test] - fn multilinear_map() { - log(); - // / 1 0 0 \ - // tensor = \ 0 1 0 / - let mut tensor = Tensor2::<2, 3, f64>::default(); - tensor.coefficients.0[0].0[0] = 1.0; - tensor.coefficients.0[1].0[1] = 1.0; - debug!("tensor: {:?}", tensor); - - // / -1 \ - // v_0 = \ 1 / - let mut v_0 = V::default(); - v_0.0[0] = -1.0; - v_0.0[1] = 1.0; - debug!("v_0: {:?}", v_0); - - // / 1 \ - // | 2 | - // v_1 = \ 3 / - let mut v_1 = V::default(); - v_1.0[0] = 1.0; - v_1.0[1] = 2.0; - v_1.0[2] = 3.0; - debug!("v_1: {:?}", v_1); - - // / 1 \ - // tensor.map(_,v_1) = \ 2 / - // - // then the next is: - // / 1 \ - // tensor.map(v_0, v_1) = < -1 1 > \ 2 / = -1 + 2 = 1 - let output = tensor.multilinear_map(v_0, v_1); - info!("output: {:?}", output); - assert_eq!(output, 1.0); - } -} +tensor!(3); + +// impl< +// const N0: usize, +// const N1: usize, +// const N2: usize, +// F: ScalarProduct +// + Default +// + Copy +// + AddAssign +// + Mul +// + core::ops::Add, +// > Tensor +// where +// F::Inner: Add + Default + Copy, +// { +// pub fn contract( +// &self, +// vector: Vector, +// ) -> Tensor< +// { (1 - (SLICE == 0) as usize) * N0 }, +// { (1 - (SLICE == 1) as usize) * N1 }, +// { (1 - (SLICE == 2) as usize) * N2 }, +// F, +// > +// where +// [(); (SLICE < 3) as usize - 1]:, +// [(); (SLICE == 0) as usize * ((DIM == N0) as usize) +// + (SLICE == 1) as usize * ((DIM == N1) as usize) +// + (SLICE == 2) as usize * ((DIM == N2) as usize) +// - 1]:, +// { +// let mut result = Tensor::default(); + +// match SLICE { +// 0 => { +// for i1 in 0..N1 { +// for i2 in 0..N2 { +// let mut sum = F::default(); +// for i0 in 0..N0 { +// sum = sum + self.coefficients.0[i0].0[i1].0[i2] * +// vector.0[i0]; } +// result.coefficients.0[0].0[i1].0[i2] = sum; +// } +// } +// } +// 1 => { +// for i0 in 0..N0 { +// for i2 in 0..N2 { +// let mut sum = F::default(); +// for i1 in 0..N1 { +// sum = sum + self.coefficients.0[i0].0[i1].0[i2] * +// vector.0[i1]; } +// result.coefficients.0[i0].0[0].0[i2] = sum; +// } +// } +// } +// 2 => { +// for i0 in 0..N0 { +// for i1 in 0..N1 { +// let mut sum = F::default(); +// for i2 in 0..N2 { +// sum = sum + self.coefficients.0[i0].0[i1].0[i2] * +// vector.0[i2]; } +// result.coefficients.0[i0].0[i1].0[0] = sum; +// } +// } +// } +// _ => unreachable!(), // Our where clause ensures this +// } + +// result +// } +// } + +// #[cfg(test)] +// mod tests { + +// use super::*; + +// // tensor!(2); + +// // #[test] +// // fn create_arbitrary_tensor() { +// // let tensor = Tensor::<2, 3, f64>::default(); +// // dbg!(tensor); +// // } + +// // #[test] +// // fn add_tensors() { +// // // log(); +// // let mut tensor1 = Tensor::<2, 3, f64>::default(); +// // for i in 0..2 { +// // for j in 0..3 { +// // tensor1.coefficients.0[i].0[j] = (i + j) as f64; +// // } +// // } +// // dbg!(tensor1.coefficients); +// // let mut tensor2 = Tensor::<2, 3, f64>::default(); +// // for i in 0..2 { +// // for j in 0..3 { +// // tensor2.coefficients.0[i].0[j] = i as f64 - j as f64; +// // } +// // } +// // dbg!(tensor2.coefficients); +// // let tensor3 = tensor1 + tensor2; +// // dbg!(tensor3.coefficients); +// // } + +// // #[test] +// // fn scalar_mul_tensor() { +// // // log(); +// // let mut tensor1 = Tensor::<2, 3, f64>::default(); +// // for i in 0..2 { +// // for j in 0..3 { +// // tensor1.coefficients.0[i].0[j] = (i + j) as f64; +// // } +// // } +// // dbg!(tensor1.coefficients); +// // let scalar = 2.0; +// // let tensor2 = tensor1 * scalar; +// // dbg!(tensor2.coefficients); +// // } + +// // #[test] +// // fn multilinear_map() { +// // // / 1 0 0 \ +// // // tensor = \ 0 1 0 / +// // let mut tensor = Tensor::<2, 3, f64>::default(); +// // tensor.coefficients.0[0].0[0] = 1.0; +// // tensor.coefficients.0[1].0[1] = 1.0; +// // dbg!(&tensor); + +// // // / -1 \ +// // // v_0 = \ 1 / +// // let mut v_0 = Vector::<2, _>::default(); +// // v_0.0[0] = -1.0; +// // v_0.0[1] = 1.0; +// // dbg!(v_0); + +// // // / 1 \ +// // // | 2 | +// // // v_1 = \ 3 / +// // let mut v_1 = Vector::<3, _>::default(); +// // v_1.0[0] = 1.0; +// // v_1.0[1] = 2.0; +// // v_1.0[2] = 3.0; +// // dbg!(v_1); + +// // // / 1 \ +// // // tensor.map(_,v_1) = \ 2 / +// // // +// // // then the next is: +// // // / 1 \ +// // // tensor.map(v_0, v_1) = < -1 1 > \ 2 / = -1 + 2 = 1 +// // let output = tensor.multilinear_map(v_0, v_1); +// // dbg!(output); +// // assert_eq!(output, 1.0); +// // } + +// #[test] +// fn test_contraction() { +// let mut tensor = Tensor::<2, 3, 4, f64>::default(); + +// // Fill tensor with some values... + +// let v = Vector::<4, f64>::default(); +// // Contract along M dimension +// let contracted: Tensor<2, 3, 0, f64> = tensor.contract::<2, 4>(v); +// } + +// #[test] +// fn test_rank3_contraction() { +// // Create a 2x3x2 tensor +// let mut tensor = Tensor::<2, 3, 2, f64>::default(); + +// // Fill tensor with some known values +// // Using a simple pattern: tensor[i][j][k] = i + j + k +// for i in 0..2 { +// for j in 0..3 { +// for k in 0..2 { +// tensor.coefficients.0[i].0[j].0[k] = (i + j + k) as f64; +// } +// } +// } + +// // Test contraction along slice 0 (first dimension) +// let v0 = Vector([1.0, 2.0]); // 2-dimensional vector for N0 +// let contracted0: Tensor<0, 3, 2, f64> = tensor.contract::<0, 2>(v0); +// // Expected: contracted0[j][k] = sum_i(tensor[i][j][k] * v0[i]) + +// // Test contraction along slice 1 (second dimension) +// let v1 = Vector([1.0, 2.0, 3.0]); // 3-dimensional vector for N1 +// let contracted1: Tensor<2, 0, 2, f64> = tensor.contract::<1, 3>(v1); +// // Expected: contracted1[i][k] = sum_j(tensor[i][j][k] * v1[j]) +// dbg!(contracted1); + +// // Test contraction along slice 2 (third dimension) +// let v2 = Vector([1.0, 2.0]); // 2-dimensional vector for N2 +// let contracted2: Tensor<2, 3, 0, f64> = tensor.contract::<2, 2>(v2); +// // Expected: contracted2[i][j] = sum_k(tensor[i][j][k] * v2[k]) + +// // Verify specific values +// // Let's check one value from each contraction + +// // For slice 0: contracted0[1][1] should be +// // tensor[0][1][1] * v0[0] + tensor[1][1][1] * v0[1] +// // assert_eq!( +// // contracted0.coefficients.0[0].0[1].0[1], +// // (2.0 * 1.0 + 3.0 * 2.0) +// // ); + +// // // For slice 1: contracted1[1][1] should be +// // // tensor[1][0][1] * v1[0] + tensor[1][1][1] * v1[1] + +// // tensor[1][2][1] * v1[2] assert_eq!( +// // contracted1.coefficients.0[1].0[0].0[1], +// // (2.0 * 1.0 + 3.0 * 2.0 + 4.0 * 3.0) +// // ); + +// // // For slice 2: contracted2[1][2] should be +// // // tensor[1][2][0] * v2[0] + tensor[1][2][1] * v2[1] +// // assert_eq!( +// // contracted2.coefficients.0[1].0[2].0[0], +// // (3.0 * 1.0 + 4.0 * 2.0) +// // ); +// } +// } diff --git a/src/tensor/mod.rs b/src/tensor/mod.rs index 44299e5..cb5bd6c 100644 --- a/src/tensor/mod.rs +++ b/src/tensor/mod.rs @@ -4,156 +4,156 @@ use super::*; pub mod macros; -pub struct Tensor -where - [(); M * N]:, -{ - /// This set up makes the first index into `coefficients` the "rows" and - /// second the "columns" - coefficients: V>, -} +// pub struct Tensor +// where +// [(); M * N]:, +// { +// /// This set up makes the first index into `coefficients` the "rows" and +// /// second the "columns" +// coefficients: V>, +// } -impl Default for Tensor -where - [(); M * N]:, - F: Default + Copy, -{ - fn default() -> Self { - let coefficients = V::>::default(); - Tensor { coefficients } - } -} +// impl Default for Tensor +// where +// [(); M * N]:, +// F: Default + Copy, +// { +// fn default() -> Self { +// let coefficients = V::>::default(); +// Tensor { coefficients } +// } +// } -impl Tensor -where - [(); M * N]:, - F: Mul + Default + Copy, -{ - pub fn tensor_product(v: [V; P], w: [V; P]) -> Tensor { - let mut tensor = Tensor::default(); +// impl Tensor +// where +// [(); M * N]:, +// F: Mul + Default + Copy, +// { +// pub fn tensor_product(v: [V; P], w: [V; P]) +// -> Tensor { let mut tensor = Tensor::default(); - for p in 0..P { - for i in 0..M { - for j in 0..N { - tensor.coefficients.0[i].0[j] = v[p].0[i] * w[p].0[j]; - } - } - } - tensor - } -} +// for p in 0..P { +// for i in 0..M { +// for j in 0..N { +// tensor.coefficients.0[i].0[j] = v[p].0[i] * w[p].0[j]; +// } +// } +// } +// tensor +// } +// } -impl Add for Tensor -where - [(); M * N]:, - F: Add + Copy + Default, -{ - type Output = Self; - fn add(self, other: Tensor) -> Self::Output { - let mut tensor = Tensor::default(); - for i in 0..M { - for j in 0..N { - tensor.coefficients.0[i].0[j] = - self.coefficients.0[i].0[j] + other.coefficients.0[i].0[j]; - } - } - tensor - } -} +// impl Add for Tensor +// where +// [(); M * N]:, +// F: Add + Copy + Default, +// { +// type Output = Self; +// fn add(self, other: Tensor) -> Self::Output { +// let mut tensor = Tensor::default(); +// for i in 0..M { +// for j in 0..N { +// tensor.coefficients.0[i].0[j] = +// self.coefficients.0[i].0[j] + +// other.coefficients.0[i].0[j]; } +// } +// tensor +// } +// } -impl Mul for Tensor -where - [(); M * N]:, - F: Mul + Default + Copy, -{ - type Output = Self; - fn mul(self, scalar: F) -> Self::Output { - let mut tensor = Tensor::default(); - for i in 0..M { - for j in 0..N { - tensor.coefficients.0[i].0[j] = self.coefficients.0[i].0[j] * scalar; - } - } - tensor - } -} +// impl Mul for Tensor +// where +// [(); M * N]:, +// F: Mul + Default + Copy, +// { +// type Output = Self; +// fn mul(self, scalar: F) -> Self::Output { +// let mut tensor = Tensor::default(); +// for i in 0..M { +// for j in 0..N { +// tensor.coefficients.0[i].0[j] = self.coefficients.0[i].0[j] * +// scalar; } +// } +// tensor +// } +// } -/// Below are more features of tensor that we can define for free! +// /// Below are more features of tensor that we can define for free! -impl Tensor -where - [(); M * N]:, - F: Add + Mul + AddAssign + Default + Copy, -{ - pub fn bilinear_map(&self, v: V, w: V) -> F { - let mut sum = F::default(); - for i in 0..M { - for j in 0..N { - sum += v.0[j] * self.coefficients.0[i].0[j] * w.0[i]; - } - } - sum - } +// impl Tensor +// where +// [(); M * N]:, +// F: Add + Mul + AddAssign + Default + Copy, +// { +// pub fn bilinear_map(&self, v: V, w: V) -> F { +// let mut sum = F::default(); +// for i in 0..M { +// for j in 0..N { +// sum += v.0[j] * self.coefficients.0[i].0[j] * w.0[i]; +// } +// } +// sum +// } - /// Here, for each choice of `w`, we get a distinct linear functional on `V` - /// that utilizes the tensor product. - #[allow(non_snake_case)] - pub fn get_functional_on_V(&self, w: V) -> impl Fn(V) -> F + '_ { - move |v| self.bilinear_map(v, w) - } +// /// Here, for each choice of `w`, we get a distinct linear functional on +// `V` /// that utilizes the tensor product. +// #[allow(non_snake_case)] +// pub fn get_functional_on_V(&self, w: V) -> impl Fn(V) -> F + +// '_ { move |v| self.bilinear_map(v, w) +// } - /// Here, for each choice of `v`, we get a distinct linear functional on `W` - /// that utilizes the tensor product. - #[allow(non_snake_case)] - pub fn get_functional_on_W(&self, v: V) -> impl Fn(V) -> F + '_ { - move |w| self.bilinear_map(v, w) - } +// /// Here, for each choice of `v`, we get a distinct linear functional on +// `W` /// that utilizes the tensor product. +// #[allow(non_snake_case)] +// pub fn get_functional_on_W(&self, v: V) -> impl Fn(V) -> F + +// '_ { move |w| self.bilinear_map(v, w) +// } - /// Matrix multiplication acting from the left :) - #[allow(non_snake_case)] - pub fn linear_map_V_to_W(&self, v: V) -> V { - let mut w = V([F::default(); N]); - for j in 0..N { - for i in 0..M { - w.0[j] += self.coefficients.0[i].0[j] * v.0[j]; - } - } - w - } +// /// Matrix multiplication acting from the left :) +// #[allow(non_snake_case)] +// pub fn linear_map_V_to_W(&self, v: V) -> V { +// let mut w = V([F::default(); N]); +// for j in 0..N { +// for i in 0..M { +// w.0[j] += self.coefficients.0[i].0[j] * v.0[j]; +// } +// } +// w +// } - /// Matrix multiplication acting from the right :) - #[allow(non_snake_case)] - pub fn linear_map_W_to_V(&self, w: V) -> V { - let mut v = V([F::default(); M]); - for j in 0..N { - for i in 0..M { - v.0[j] += self.coefficients.0[i].0[j] * w.0[i]; - } - } - v - } -} +// /// Matrix multiplication acting from the right :) +// #[allow(non_snake_case)] +// pub fn linear_map_W_to_V(&self, w: V) -> V { +// let mut v = V([F::default(); M]); +// for j in 0..N { +// for i in 0..M { +// v.0[j] += self.coefficients.0[i].0[j] * w.0[i]; +// } +// } +// v +// } +// } -/// This implementation makes `Tensor` an "Algebra" :) -/// In other words, we can multiply M x N matrices with N x P matrices to get an -/// M x P matrix. -impl Mul> for Tensor -where - [(); M * N]:, - [(); N * P]:, - F: Add + AddAssign + Mul + Default + Copy, -{ - type Output = Self; - fn mul(self, other: Tensor) -> Self::Output { - let mut product = Tensor::default(); - for i in 0..N { - for k in 0..P { - for j in 0..M { - product.coefficients.0[j].0[k] += - self.coefficients.0[i].0[j] * other.coefficients.0[j].0[k]; - } - } - } - product - } -} +// /// This implementation makes `Tensor` an "Algebra" :) +// /// In other words, we can multiply M x N matrices with N x P matrices to get +// an /// M x P matrix. +// impl Mul> +// for Tensor where +// [(); M * N]:, +// [(); N * P]:, +// F: Add + AddAssign + Mul + Default + Copy, +// { +// type Output = Self; +// fn mul(self, other: Tensor) -> Self::Output { +// let mut product = Tensor::default(); +// for i in 0..N { +// for k in 0..P { +// for j in 0..M { +// product.coefficients.0[j].0[k] += +// self.coefficients.0[i].0[j] * +// other.coefficients.0[j].0[k]; } +// } +// } +// product +// } +// } diff --git a/src/unique_coproduct.rs b/src/unique_coproduct.rs deleted file mode 100644 index 69dc666..0000000 --- a/src/unique_coproduct.rs +++ /dev/null @@ -1,33 +0,0 @@ -use super::*; - -pub enum UniqueDirectSum { - V(V), - W(V), -} - -impl Add for UniqueDirectSum -where - F: Add + Default + Copy, -{ - type Output = Self; - fn add(self, other: UniqueDirectSum) -> Self::Output { - match (self, other) { - (UniqueDirectSum::V(v), UniqueDirectSum::V(w)) => UniqueDirectSum::V(V::add(v, w)), - (UniqueDirectSum::W(v), UniqueDirectSum::W(w)) => UniqueDirectSum::W(V::add(v, w)), - _ => panic!("Cannot add V and W with Rust `UniqueDirectSum`!"), - } - } -} - -impl Mul for UniqueDirectSum -where - F: Mul + Default + Copy, -{ - type Output = Self; - fn mul(self, scalar: F) -> Self::Output { - match self { - UniqueDirectSum::V(v) => UniqueDirectSum::V(v * scalar), - UniqueDirectSum::W(w) => UniqueDirectSum::W(w * scalar), - } - } -}