From c4995f59719a9cc96b4747e1266178fb2ea7141e Mon Sep 17 00:00:00 2001 From: Jakob Kraus Date: Thu, 19 Sep 2024 17:14:15 +0200 Subject: [PATCH] feat: implement more operations --- Cargo.lock | 10 + crates/medmodels-core/Cargo.toml | 1 + .../src/medrecord/datatypes/attribute.rs | 223 ++- .../src/medrecord/datatypes/mod.rs | 75 +- .../src/medrecord/datatypes/value.rs | 107 +- .../medmodels-core/src/medrecord/graph/mod.rs | 31 +- crates/medmodels-core/src/medrecord/mod.rs | 73 +- .../src/medrecord/querying/attributes/mod.rs | 132 ++ .../medrecord/querying/attributes/operand.rs | 874 +++++++++++ .../querying/attributes/operation.rs | 1357 +++++++++++++++++ .../src/medrecord/querying/edges/mod.rs | 52 +- .../src/medrecord/querying/edges/operand.rs | 528 ++++++- .../src/medrecord/querying/edges/operation.rs | 592 ++++++- .../src/medrecord/querying/edges/values.rs | 1 - .../src/medrecord/querying/mod.rs | 3 + .../src/medrecord/querying/nodes/mod.rs | 62 +- .../src/medrecord/querying/nodes/operand.rs | 610 +++++++- .../src/medrecord/querying/nodes/operation.rs | 806 +++++++++- .../src/medrecord/querying/nodes/values.rs | 1 - .../src/medrecord/querying/traits.rs | 2 +- .../src/medrecord/querying/values/mod.rs | 183 ++- .../src/medrecord/querying/values/operand.rs | 601 ++++++-- .../medrecord/querying/values/operation.rs | 934 ++++++++++-- rustmodels/src/medrecord/mod.rs | 2 +- 24 files changed, 6801 insertions(+), 459 deletions(-) create mode 100644 crates/medmodels-core/src/medrecord/querying/attributes/mod.rs create mode 100644 crates/medmodels-core/src/medrecord/querying/attributes/operand.rs create mode 100644 crates/medmodels-core/src/medrecord/querying/attributes/operation.rs delete mode 100644 crates/medmodels-core/src/medrecord/querying/edges/values.rs delete mode 100644 crates/medmodels-core/src/medrecord/querying/nodes/values.rs diff --git a/Cargo.lock b/Cargo.lock index 45a17b3f..280c2399 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -433,6 +433,15 @@ version = "2.0.4" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "1e186cfbae8084e513daff4240b4797e342f988cecda4fb6c939150f96315fd8" +[[package]] +name = "itertools" +version = "0.13.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "413ee7dfc52ee1a4949ceeb7dbc8a33f2d6c088194d9f922fb8318faf1f01186" +dependencies = [ + "either", +] + [[package]] name = "itoa" version = "1.0.10" @@ -523,6 +532,7 @@ name = "medmodels-core" version = "0.1.2" dependencies = [ "chrono", + "itertools", "medmodels-utils", "polars", "roaring", diff --git a/crates/medmodels-core/Cargo.toml b/crates/medmodels-core/Cargo.toml index 8b515b3b..48225587 100644 --- a/crates/medmodels-core/Cargo.toml +++ b/crates/medmodels-core/Cargo.toml @@ -15,3 +15,4 @@ serde = { workspace = true } chrono = { workspace = true } ron = "0.8.1" roaring = "0.10.6" +itertools = "0.13.0" diff --git a/crates/medmodels-core/src/medrecord/datatypes/attribute.rs b/crates/medmodels-core/src/medrecord/datatypes/attribute.rs index f02f12d4..9ec71520 100644 --- a/crates/medmodels-core/src/medrecord/datatypes/attribute.rs +++ b/crates/medmodels-core/src/medrecord/datatypes/attribute.rs @@ -1,8 +1,16 @@ -use super::{Contains, EndsWith, MedRecordValue, StartsWith}; -use crate::errors::MedRecordError; +use super::{ + Abs, Contains, EndsWith, Lowercase, MedRecordValue, Mod, Pow, Slice, StartsWith, Trim, TrimEnd, + TrimStart, Uppercase, +}; +use crate::errors::{MedRecordError, MedRecordResult}; use medmodels_utils::implement_from_for_wrapper; use serde::{Deserialize, Serialize}; -use std::{cmp::Ordering, fmt::Display, hash::Hash}; +use std::{ + cmp::Ordering, + fmt::Display, + hash::Hash, + ops::{Add, Mul, Sub}, +}; #[derive(Debug, Clone, Serialize, Deserialize)] pub enum MedRecordAttribute { @@ -43,15 +51,6 @@ impl TryFrom for MedRecordAttribute { } } -impl Display for MedRecordAttribute { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - match self { - Self::String(value) => write!(f, "{}", value), - Self::Int(value) => write!(f, "{}", value), - } - } -} - impl PartialEq for MedRecordAttribute { fn eq(&self, other: &Self) -> bool { match (self, other) { @@ -80,6 +79,130 @@ impl PartialOrd for MedRecordAttribute { } } +impl Display for MedRecordAttribute { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match self { + Self::String(value) => write!(f, "{}", value), + Self::Int(value) => write!(f, "{}", value), + } + } +} + +// TODO: Add tests +impl Add for MedRecordAttribute { + type Output = MedRecordResult; + + fn add(self, rhs: Self) -> Self::Output { + match (self, rhs) { + (MedRecordAttribute::String(value), MedRecordAttribute::String(rhs)) => { + Ok(MedRecordAttribute::String(value + rhs.as_str())) + } + (MedRecordAttribute::String(value), MedRecordAttribute::Int(rhs)) => Err( + MedRecordError::AssertionError(format!("Cannot add {} to {}", rhs, value)), + ), + (MedRecordAttribute::Int(value), MedRecordAttribute::String(rhs)) => Err( + MedRecordError::AssertionError(format!("Cannot add {} to {}", rhs, value)), + ), + (MedRecordAttribute::Int(value), MedRecordAttribute::Int(rhs)) => { + Ok(MedRecordAttribute::Int(value + rhs)) + } + } + } +} + +// TODO: Add tests +impl Sub for MedRecordAttribute { + type Output = MedRecordResult; + + fn sub(self, rhs: Self) -> Self::Output { + match (self, rhs) { + (MedRecordAttribute::String(value), MedRecordAttribute::String(rhs)) => Err( + MedRecordError::AssertionError(format!("Cannot subtract {} from {}", rhs, value)), + ), + (MedRecordAttribute::String(value), MedRecordAttribute::Int(rhs)) => Err( + MedRecordError::AssertionError(format!("Cannot subtract {} from {}", rhs, value)), + ), + (MedRecordAttribute::Int(value), MedRecordAttribute::String(rhs)) => Err( + MedRecordError::AssertionError(format!("Cannot subtract {} from {}", rhs, value)), + ), + (MedRecordAttribute::Int(value), MedRecordAttribute::Int(rhs)) => { + Ok(MedRecordAttribute::Int(value - rhs)) + } + } + } +} + +// TODO: Add tests +impl Mul for MedRecordAttribute { + type Output = MedRecordResult; + + fn mul(self, rhs: Self) -> Self::Output { + match (self, rhs) { + (MedRecordAttribute::String(value), MedRecordAttribute::String(rhs)) => Err( + MedRecordError::AssertionError(format!("Cannot multiply {} by {}", value, rhs)), + ), + (MedRecordAttribute::String(value), MedRecordAttribute::Int(rhs)) => Err( + MedRecordError::AssertionError(format!("Cannot multiply {} by {}", value, rhs)), + ), + (MedRecordAttribute::Int(value), MedRecordAttribute::String(rhs)) => Err( + MedRecordError::AssertionError(format!("Cannot multiply {} by {}", value, rhs)), + ), + (MedRecordAttribute::Int(value), MedRecordAttribute::Int(rhs)) => { + Ok(MedRecordAttribute::Int(value * rhs)) + } + } + } +} + +// TODO: Add tests +impl Pow for MedRecordAttribute { + fn pow(self, rhs: Self) -> MedRecordResult { + match (self, rhs) { + (MedRecordAttribute::String(value), MedRecordAttribute::String(rhs)) => { + Err(MedRecordError::AssertionError(format!( + "Cannot raise {} to the power of {}", + value, rhs + ))) + } + (MedRecordAttribute::String(value), MedRecordAttribute::Int(rhs)) => { + Err(MedRecordError::AssertionError(format!( + "Cannot raise {} to the power of {}", + value, rhs + ))) + } + (MedRecordAttribute::Int(value), MedRecordAttribute::String(rhs)) => { + Err(MedRecordError::AssertionError(format!( + "Cannot raise {} to the power of {}", + value, rhs + ))) + } + (MedRecordAttribute::Int(value), MedRecordAttribute::Int(rhs)) => { + Ok(MedRecordAttribute::Int(value.pow(rhs as u32))) + } + } + } +} + +// TODO: Add tests +impl Mod for MedRecordAttribute { + fn r#mod(self, rhs: Self) -> MedRecordResult { + match (self, rhs) { + (MedRecordAttribute::String(value), MedRecordAttribute::String(rhs)) => Err( + MedRecordError::AssertionError(format!("Cannot mod {} by {}", value, rhs)), + ), + (MedRecordAttribute::String(value), MedRecordAttribute::Int(rhs)) => Err( + MedRecordError::AssertionError(format!("Cannot mod {} by {}", value, rhs)), + ), + (MedRecordAttribute::Int(value), MedRecordAttribute::String(rhs)) => Err( + MedRecordError::AssertionError(format!("Cannot mod {} by {}", value, rhs)), + ), + (MedRecordAttribute::Int(value), MedRecordAttribute::Int(rhs)) => { + Ok(MedRecordAttribute::Int(value % rhs)) + } + } + } +} + impl StartsWith for MedRecordAttribute { fn starts_with(&self, other: &Self) -> bool { match (self, other) { @@ -137,6 +260,82 @@ impl Contains for MedRecordAttribute { } } +// TODO: Add tests +impl Slice for MedRecordAttribute { + fn slice(self, range: std::ops::Range) -> Self { + match self { + MedRecordAttribute::String(value) => value[range].into(), + MedRecordAttribute::Int(value) => value.to_string()[range].into(), + } + } +} + +// TODO: Add tests +impl Abs for MedRecordAttribute { + fn abs(self) -> Self { + match self { + MedRecordAttribute::Int(value) => MedRecordAttribute::Int(value.abs()), + _ => self, + } + } +} + +// TODO: Add tests +impl Trim for MedRecordAttribute { + fn trim(self) -> Self { + match self { + MedRecordAttribute::String(value) => { + MedRecordAttribute::String(value.trim().to_string()) + } + _ => self, + } + } +} + +// TODO: Add tests +impl TrimStart for MedRecordAttribute { + fn trim_start(self) -> Self { + match self { + MedRecordAttribute::String(value) => { + MedRecordAttribute::String(value.trim_start().to_string()) + } + _ => self, + } + } +} + +// TODO: Add tests +impl TrimEnd for MedRecordAttribute { + fn trim_end(self) -> Self { + match self { + MedRecordAttribute::String(value) => { + MedRecordAttribute::String(value.trim_end().to_string()) + } + _ => self, + } + } +} + +// TODO: Add tests +impl Lowercase for MedRecordAttribute { + fn lowercase(self) -> Self { + match self { + MedRecordAttribute::String(value) => MedRecordAttribute::String(value.to_lowercase()), + _ => self, + } + } +} + +// TODO: Add tests +impl Uppercase for MedRecordAttribute { + fn uppercase(self) -> Self { + match self { + MedRecordAttribute::String(value) => MedRecordAttribute::String(value.to_uppercase()), + _ => self, + } + } +} + #[cfg(test)] mod test { use super::MedRecordAttribute; diff --git a/crates/medmodels-core/src/medrecord/datatypes/mod.rs b/crates/medmodels-core/src/medrecord/datatypes/mod.rs index 001b51dc..ada0f6c0 100644 --- a/crates/medmodels-core/src/medrecord/datatypes/mod.rs +++ b/crates/medmodels-core/src/medrecord/datatypes/mod.rs @@ -1,10 +1,8 @@ -#![allow(dead_code)] -// TODO: Remove the above line after query engine is implemented - mod attribute; mod value; pub use self::{attribute::MedRecordAttribute, value::MedRecordValue}; +use super::EdgeIndex; use crate::errors::MedRecordError; use serde::{Deserialize, Serialize}; use std::{fmt::Display, ops::Range}; @@ -54,6 +52,24 @@ impl From<&MedRecordValue> for DataType { } } +impl From for DataType { + fn from(value: MedRecordAttribute) -> Self { + match value { + MedRecordAttribute::String(_) => DataType::String, + MedRecordAttribute::Int(_) => DataType::Int, + } + } +} + +impl From<&MedRecordAttribute> for DataType { + fn from(value: &MedRecordAttribute) -> Self { + match value { + MedRecordAttribute::String(_) => DataType::String, + MedRecordAttribute::Int(_) => DataType::Int, + } + } +} + impl PartialEq for DataType { fn eq(&self, other: &Self) -> bool { match (self, other) { @@ -129,28 +145,52 @@ impl DataType { } } -pub trait Pow: Sized { - fn pow(self, exp: Self) -> Result; -} - -pub trait Mod: Sized { - fn r#mod(self, other: Self) -> Result; -} - pub trait StartsWith { fn starts_with(&self, other: &Self) -> bool; } +// TODO: Add tests +impl StartsWith for EdgeIndex { + fn starts_with(&self, other: &Self) -> bool { + self.to_string().starts_with(&other.to_string()) + } +} + pub trait EndsWith { fn ends_with(&self, other: &Self) -> bool; } +// TODO: Add tests +impl EndsWith for EdgeIndex { + fn ends_with(&self, other: &Self) -> bool { + self.to_string().ends_with(&other.to_string()) + } +} + pub trait Contains { fn contains(&self, other: &Self) -> bool; } -pub trait PartialNeq: PartialEq { - fn neq(&self, other: &Self) -> bool; +// TODO: Add tests +impl Contains for EdgeIndex { + fn contains(&self, other: &Self) -> bool { + self.to_string().contains(&other.to_string()) + } +} + +pub trait Pow: Sized { + fn pow(self, exp: Self) -> Result; +} + +pub trait Mod: Sized { + fn r#mod(self, other: Self) -> Result; +} + +// TODO: Add tests +impl Mod for EdgeIndex { + fn r#mod(self, other: Self) -> Result { + Ok(self % other) + } } pub trait Round { @@ -197,15 +237,6 @@ pub trait Slice { fn slice(self, range: Range) -> Self; } -impl PartialNeq for T -where - T: PartialOrd, -{ - fn neq(&self, other: &Self) -> bool { - self != other - } -} - #[cfg(test)] mod test { use super::{DataType, MedRecordValue}; diff --git a/crates/medmodels-core/src/medrecord/datatypes/value.rs b/crates/medmodels-core/src/medrecord/datatypes/value.rs index 792d879d..e4b28e5b 100644 --- a/crates/medmodels-core/src/medrecord/datatypes/value.rs +++ b/crates/medmodels-core/src/medrecord/datatypes/value.rs @@ -3,7 +3,7 @@ use super::{ Trim, TrimEnd, TrimStart, Uppercase, }; use crate::errors::MedRecordError; -use chrono::NaiveDateTime; +use chrono::{DateTime, NaiveDateTime}; use medmodels_utils::implement_from_for_wrapper; use serde::{Deserialize, Serialize}; use std::{ @@ -210,9 +210,17 @@ impl Add for MedRecordValue { (MedRecordValue::DateTime(value), MedRecordValue::Bool(rhs)) => Err( MedRecordError::AssertionError(format!("Cannot add {} to {}", rhs, value)), ), - (MedRecordValue::DateTime(value), MedRecordValue::DateTime(rhs)) => Err( - MedRecordError::AssertionError(format!("Cannot add {} to {}", rhs, value)), - ), + (MedRecordValue::DateTime(value), MedRecordValue::DateTime(rhs)) => { + Ok(DateTime::from_timestamp( + value.and_utc().timestamp() + rhs.and_utc().timestamp(), + 0, + ) + .ok_or(MedRecordError::AssertionError( + "Invalid timestamp".to_string(), + ))? + .naive_utc() + .into()) + } (MedRecordValue::DateTime(value), MedRecordValue::Null) => Err( MedRecordError::AssertionError(format!("Cannot add None to {}", value)), ), @@ -327,9 +335,17 @@ impl Sub for MedRecordValue { (MedRecordValue::DateTime(value), MedRecordValue::Bool(rhs)) => Err( MedRecordError::AssertionError(format!("Cannot subtract {} from {}", rhs, value)), ), - (MedRecordValue::DateTime(value), MedRecordValue::DateTime(rhs)) => Err( - MedRecordError::AssertionError(format!("Cannot subtract {} from {}", rhs, value)), - ), + (MedRecordValue::DateTime(value), MedRecordValue::DateTime(rhs)) => { + Ok(DateTime::from_timestamp( + value.and_utc().timestamp() - rhs.and_utc().timestamp(), + 0, + ) + .ok_or(MedRecordError::AssertionError( + "Invalid timestamp".to_string(), + ))? + .naive_utc() + .into()) + } (MedRecordValue::DateTime(value), MedRecordValue::Null) => Err( MedRecordError::AssertionError(format!("Cannot subtract None from {}", value)), ), @@ -621,9 +637,17 @@ impl Div for MedRecordValue { (MedRecordValue::DateTime(value), MedRecordValue::String(other)) => Err( MedRecordError::AssertionError(format!("Cannot divide {} by {}", value, other)), ), - (MedRecordValue::DateTime(value), MedRecordValue::Int(other)) => Err( - MedRecordError::AssertionError(format!("Cannot divide {} by {}", value, other)), - ), + (MedRecordValue::DateTime(value), MedRecordValue::Int(other)) => { + Ok(DateTime::from_timestamp( + (value.and_utc().timestamp() as f64 / other as f64).floor() as i64, + 0, + ) + .ok_or(MedRecordError::AssertionError( + "Invalid timestamp".to_string(), + ))? + .naive_utc() + .into()) + } (MedRecordValue::DateTime(value), MedRecordValue::Float(other)) => Err( MedRecordError::AssertionError(format!("Cannot divide {} by {}", value, other)), ), @@ -1183,7 +1207,7 @@ mod test { Uppercase, }, }; - use chrono::NaiveDateTime; + use chrono::{DateTime, NaiveDate, NaiveDateTime, NaiveTime}; #[test] fn test_default() { @@ -1669,9 +1693,23 @@ mod test { (MedRecordValue::DateTime(NaiveDateTime::MIN) + MedRecordValue::Bool(false)) .is_err_and(|e| matches!(e, MedRecordError::AssertionError(_))) ); - assert!((MedRecordValue::DateTime(NaiveDateTime::MIN) - + MedRecordValue::DateTime(NaiveDateTime::MIN)) - .is_err_and(|e| matches!(e, MedRecordError::AssertionError(_)))); + assert_eq!( + MedRecordValue::DateTime( + NaiveDate::from_ymd_opt(1970, 1, 4) + .unwrap() + .and_time(NaiveTime::MIN) + ), + (MedRecordValue::DateTime( + NaiveDate::from_ymd_opt(1970, 1, 2) + .unwrap() + .and_time(NaiveTime::MIN) + ) + MedRecordValue::DateTime( + NaiveDate::from_ymd_opt(1970, 1, 3) + .unwrap() + .and_time(NaiveTime::MIN) + )) + .unwrap() + ); assert!( (MedRecordValue::DateTime(NaiveDateTime::MIN) + MedRecordValue::Null) .is_err_and(|e| matches!(e, MedRecordError::AssertionError(_))) @@ -1794,9 +1832,12 @@ mod test { (MedRecordValue::DateTime(NaiveDateTime::MIN) - MedRecordValue::Bool(false)) .is_err_and(|e| matches!(e, MedRecordError::AssertionError(_))) ); - assert!((MedRecordValue::DateTime(NaiveDateTime::MIN) - - MedRecordValue::DateTime(NaiveDateTime::MIN)) - .is_err_and(|e| matches!(e, MedRecordError::AssertionError(_)))); + assert_eq!( + MedRecordValue::DateTime(DateTime::from_timestamp(0, 0).unwrap().naive_utc()), + (MedRecordValue::DateTime(NaiveDateTime::MAX) + - MedRecordValue::DateTime(NaiveDateTime::MAX)) + .unwrap() + ); assert!( (MedRecordValue::DateTime(NaiveDateTime::MIN) - MedRecordValue::Null) .is_err_and(|e| matches!(e, MedRecordError::AssertionError(_))) @@ -1951,15 +1992,15 @@ mod test { / MedRecordValue::String("value".to_string())) .is_err_and(|e| matches!(e, MedRecordError::AssertionError(_)))); assert!( - (MedRecordValue::String("value".to_string()) / MedRecordValue::Int(0)) + (MedRecordValue::String("value".to_string()) / MedRecordValue::Int(1)) .is_err_and(|e| matches!(e, MedRecordError::AssertionError(_))) ); assert!( - (MedRecordValue::String("value".to_string()) / MedRecordValue::Float(0_f64)) + (MedRecordValue::String("value".to_string()) / MedRecordValue::Float(1_f64)) .is_err_and(|e| matches!(e, MedRecordError::AssertionError(_))) ); assert!( - (MedRecordValue::String("value".to_string()) / MedRecordValue::Bool(false)) + (MedRecordValue::String("value".to_string()) / MedRecordValue::Bool(true)) .is_err_and(|e| matches!(e, MedRecordError::AssertionError(_))) ); assert!((MedRecordValue::String("value".to_string()) @@ -1982,7 +2023,7 @@ mod test { MedRecordValue::Float(1_f64), (MedRecordValue::Int(5) / MedRecordValue::Float(5_f64)).unwrap() ); - assert!((MedRecordValue::Int(0) / MedRecordValue::Bool(false)) + assert!((MedRecordValue::Int(0) / MedRecordValue::Bool(true)) .is_err_and(|e| matches!(e, MedRecordError::AssertionError(_)))); assert!( (MedRecordValue::Int(0) / MedRecordValue::DateTime(NaiveDateTime::MIN)) @@ -2003,7 +2044,7 @@ mod test { MedRecordValue::Float(1_f64), (MedRecordValue::Float(5_f64) / MedRecordValue::Float(5_f64)).unwrap() ); - assert!((MedRecordValue::Float(0_f64) / MedRecordValue::Bool(false)) + assert!((MedRecordValue::Float(0_f64) / MedRecordValue::Bool(true)) .is_err_and(|e| matches!(e, MedRecordError::AssertionError(_)))); assert!( (MedRecordValue::Float(0_f64) / MedRecordValue::DateTime(NaiveDateTime::MIN)) @@ -2016,11 +2057,11 @@ mod test { (MedRecordValue::Bool(false) / MedRecordValue::String("value".to_string())) .is_err_and(|e| matches!(e, MedRecordError::AssertionError(_))) ); - assert!((MedRecordValue::Bool(false) / MedRecordValue::Int(0)) + assert!((MedRecordValue::Bool(false) / MedRecordValue::Int(1)) .is_err_and(|e| matches!(e, MedRecordError::AssertionError(_)))); - assert!((MedRecordValue::Bool(false) / MedRecordValue::Float(0_f64)) + assert!((MedRecordValue::Bool(false) / MedRecordValue::Float(1_f64)) .is_err_and(|e| matches!(e, MedRecordError::AssertionError(_)))); - assert!((MedRecordValue::Bool(false) / MedRecordValue::Bool(false)) + assert!((MedRecordValue::Bool(false) / MedRecordValue::Bool(true)) .is_err_and(|e| matches!(e, MedRecordError::AssertionError(_)))); assert!( (MedRecordValue::Bool(false) / MedRecordValue::DateTime(NaiveDateTime::MIN)) @@ -2032,16 +2073,16 @@ mod test { assert!((MedRecordValue::DateTime(NaiveDateTime::MIN) / MedRecordValue::String("value".to_string())) .is_err_and(|e| matches!(e, MedRecordError::AssertionError(_)))); - assert!( - (MedRecordValue::DateTime(NaiveDateTime::MIN) / MedRecordValue::Int(0)) - .is_err_and(|e| matches!(e, MedRecordError::AssertionError(_))) + assert_eq!( + MedRecordValue::DateTime(NaiveDateTime::MIN), + (MedRecordValue::DateTime(NaiveDateTime::MIN) / MedRecordValue::Int(1)).unwrap() ); assert!( - (MedRecordValue::DateTime(NaiveDateTime::MIN) / MedRecordValue::Float(0_f64)) + (MedRecordValue::DateTime(NaiveDateTime::MIN) / MedRecordValue::Float(1_f64)) .is_err_and(|e| matches!(e, MedRecordError::AssertionError(_))) ); assert!( - (MedRecordValue::DateTime(NaiveDateTime::MIN) / MedRecordValue::Bool(false)) + (MedRecordValue::DateTime(NaiveDateTime::MIN) / MedRecordValue::Bool(true)) .is_err_and(|e| matches!(e, MedRecordError::AssertionError(_))) ); assert!((MedRecordValue::DateTime(NaiveDateTime::MIN) @@ -2056,11 +2097,11 @@ mod test { (MedRecordValue::Null / MedRecordValue::String("value".to_string())) .is_err_and(|e| matches!(e, MedRecordError::AssertionError(_))) ); - assert!((MedRecordValue::Null / MedRecordValue::Int(0)) + assert!((MedRecordValue::Null / MedRecordValue::Int(1)) .is_err_and(|e| matches!(e, MedRecordError::AssertionError(_)))); - assert!((MedRecordValue::Null / MedRecordValue::Float(0_f64)) + assert!((MedRecordValue::Null / MedRecordValue::Float(1_f64)) .is_err_and(|e| matches!(e, MedRecordError::AssertionError(_)))); - assert!((MedRecordValue::Null / MedRecordValue::Bool(false)) + assert!((MedRecordValue::Null / MedRecordValue::Bool(true)) .is_err_and(|e| matches!(e, MedRecordError::AssertionError(_)))); assert!( (MedRecordValue::Null / MedRecordValue::DateTime(NaiveDateTime::MIN)) diff --git a/crates/medmodels-core/src/medrecord/graph/mod.rs b/crates/medmodels-core/src/medrecord/graph/mod.rs index 8905b90b..9d3ebb4f 100644 --- a/crates/medmodels-core/src/medrecord/graph/mod.rs +++ b/crates/medmodels-core/src/medrecord/graph/mod.rs @@ -338,7 +338,7 @@ impl Graph { self.edges.contains_key(edge_index) } - pub fn neighbors( + pub fn neighbors_outgoing( &self, node_index: &NodeIndex, ) -> Result, GraphError> { @@ -360,6 +360,29 @@ impl Graph { })) } + // TODO: Add tests + pub fn neighbors_incoming( + &self, + node_index: &NodeIndex, + ) -> Result, GraphError> { + Ok(self + .nodes + .get(node_index) + .ok_or(GraphError::IndexError(format!( + "Cannot find node with index {}", + node_index + )))? + .incoming_edge_indices + .iter() + .map(|edge_index| { + &self + .edges + .get(edge_index) + .expect("Edge must exist") + .source_node_index + })) + } + pub fn neighbors_undirected( &self, node_index: &NodeIndex, @@ -890,7 +913,7 @@ mod test { fn test_neighbors() { let graph = create_graph(); - let neighbors = graph.neighbors(&"0".into()).unwrap(); + let neighbors = graph.neighbors_outgoing(&"0".into()).unwrap(); assert_eq!(2, neighbors.count()); } @@ -900,7 +923,7 @@ mod test { let graph = create_graph(); assert!(graph - .neighbors(&"50".into()) + .neighbors_outgoing(&"50".into()) .is_err_and(|e| matches!(e, GraphError::IndexError(_)))); } @@ -908,7 +931,7 @@ mod test { fn test_neighbors_undirected() { let graph = create_graph(); - let neighbors = graph.neighbors(&"2".into()).unwrap(); + let neighbors = graph.neighbors_outgoing(&"2".into()).unwrap(); assert_eq!(0, neighbors.count()); let neighbors = graph.neighbors_undirected(&"2".into()).unwrap(); diff --git a/crates/medmodels-core/src/medrecord/mod.rs b/crates/medmodels-core/src/medrecord/mod.rs index 388b865e..aaea6808 100644 --- a/crates/medmodels-core/src/medrecord/mod.rs +++ b/crates/medmodels-core/src/medrecord/mod.rs @@ -683,12 +683,22 @@ impl MedRecord { self.group_mapping.contains_group(group) } - pub fn neighbors( + pub fn neighbors_outgoing( &self, node_index: &NodeIndex, ) -> Result, MedRecordError> { self.graph - .neighbors(node_index) + .neighbors_outgoing(node_index) + .map_err(MedRecordError::from) + } + + // TODO: Add tests + pub fn neighbors_incoming( + &self, + node_index: &NodeIndex, + ) -> Result, MedRecordError> { + self.graph + .neighbors_incoming(node_index) .map_err(MedRecordError::from) } @@ -1850,7 +1860,7 @@ mod test { fn test_neighbors() { let medrecord = create_medrecord(); - let neighbors = medrecord.neighbors(&"0".into()).unwrap(); + let neighbors = medrecord.neighbors_outgoing(&"0".into()).unwrap(); assert_eq!(2, neighbors.count()); } @@ -1861,7 +1871,7 @@ mod test { // Querying neighbors of a non-existing node sohuld fail assert!(medrecord - .neighbors(&"0".into()) + .neighbors_outgoing(&"0".into()) .is_err_and(|e| matches!(e, MedRecordError::IndexError(_)))); } @@ -1869,7 +1879,7 @@ mod test { fn test_neighbors_undirected() { let medrecord = create_medrecord(); - let neighbors = medrecord.neighbors(&"2".into()).unwrap(); + let neighbors = medrecord.neighbors_outgoing(&"2".into()).unwrap(); assert_eq!(0, neighbors.count()); let neighbors = medrecord.neighbors_undirected(&"2".into()).unwrap(); @@ -1895,57 +1905,4 @@ mod test { assert_eq!(0, medrecord.edge_count()); assert_eq!(0, medrecord.group_count()); } - - #[test] - fn test_test() { - let nodes = vec![ - ("0".into(), HashMap::from([("time".into(), 0.into())])), - ("1".into(), HashMap::from([("time".into(), 1.into())])), - ("2".into(), HashMap::from([("time".into(), 2.into())])), - ("3".into(), HashMap::from([("time".into(), 3.into())])), - ]; - - let edges = vec![ - ( - "0".into(), - "1".into(), - HashMap::from([("time".into(), 0.into())]), - ), - ( - "0".into(), - "1".into(), - HashMap::from([("time".into(), 2.into())]), - ), - ( - "0".into(), - "1".into(), - HashMap::from([("time".into(), 3.into())]), - ), - ( - "0".into(), - "1".into(), - HashMap::from([("time".into(), 4.into())]), - ), - ( - "0".into(), - "2".into(), - HashMap::from([("time".into(), 5.into())]), - ), - ]; - - let mut medrecord = MedRecord::from_tuples(nodes, Some(edges), None).unwrap(); - - medrecord - .add_group("treatment".into(), Some(vec!["1".into()]), None) - .unwrap(); - medrecord - .add_group("outcome".into(), Some(vec!["2".into()]), None) - .unwrap(); - - let nodes = medrecord.select_edges(|edge| { - edge.attribute("time").less_than(2); - }); - - println!("\n{:?}", nodes.collect::>()); - } } diff --git a/crates/medmodels-core/src/medrecord/querying/attributes/mod.rs b/crates/medmodels-core/src/medrecord/querying/attributes/mod.rs new file mode 100644 index 00000000..d16fcabd --- /dev/null +++ b/crates/medmodels-core/src/medrecord/querying/attributes/mod.rs @@ -0,0 +1,132 @@ +mod operand; +mod operation; + +use super::{ + edges::{EdgeOperand, EdgeOperation}, + nodes::{NodeOperand, NodeOperation}, + BoxedIterator, +}; +use crate::{ + errors::MedRecordResult, + medrecord::{Attributes, EdgeIndex, MedRecordAttribute, NodeIndex}, + MedRecord, +}; +pub use operand::{AttributesTreeOperand, MultipleAttributesOperand}; +pub use operation::{AttributesTreeOperation, MultipleAttributesOperation}; +use std::fmt::Display; + +#[derive(Debug, Clone)] +pub enum SingleKind { + Max, + Min, + Count, + Sum, + First, + Last, +} + +#[derive(Debug, Clone)] +pub enum MultipleKind { + Max, + Min, + Count, + Sum, + First, + Last, +} + +#[derive(Debug, Clone)] +pub enum SingleComparisonKind { + GreaterThan, + GreaterThanOrEqualTo, + LessThan, + LessThanOrEqualTo, + EqualTo, + NotEqualTo, + StartsWith, + EndsWith, + Contains, +} + +#[derive(Debug, Clone)] +pub enum MultipleComparisonKind { + IsIn, + IsNotIn, +} + +#[derive(Debug, Clone)] +pub enum BinaryArithmeticKind { + Add, + Sub, + Mul, + Pow, + Mod, +} + +impl Display for BinaryArithmeticKind { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match self { + BinaryArithmeticKind::Add => write!(f, "add"), + BinaryArithmeticKind::Sub => write!(f, "sub"), + BinaryArithmeticKind::Mul => write!(f, "mul"), + BinaryArithmeticKind::Pow => write!(f, "pow"), + BinaryArithmeticKind::Mod => write!(f, "mod"), + } + } +} + +#[derive(Debug, Clone)] +pub enum UnaryArithmeticKind { + Abs, + Trim, + TrimStart, + TrimEnd, + Lowercase, + Uppercase, +} + +pub(crate) trait GetAttributes { + fn get_attributes<'a>(&'a self, medrecord: &'a MedRecord) -> MedRecordResult<&'a Attributes>; +} + +impl GetAttributes for NodeIndex { + fn get_attributes<'a>(&'a self, medrecord: &'a MedRecord) -> MedRecordResult<&'a Attributes> { + medrecord.node_attributes(self) + } +} + +impl GetAttributes for EdgeIndex { + fn get_attributes<'a>(&'a self, medrecord: &'a MedRecord) -> MedRecordResult<&'a Attributes> { + medrecord.edge_attributes(self) + } +} + +#[derive(Debug, Clone)] +pub enum Context { + NodeOperand(NodeOperand), + EdgeOperand(EdgeOperand), +} + +impl Context { + pub(crate) fn get_attributes<'a>( + &self, + medrecord: &'a MedRecord, + ) -> MedRecordResult>> { + Ok(match self { + Self::NodeOperand(node_operand) => { + let node_indices = node_operand.evaluate(medrecord)?; + + Box::new( + NodeOperation::get_attributes(medrecord, node_indices).map(|(_, value)| value), + ) + } + Self::EdgeOperand(edge_operand) => { + let edge_indices = edge_operand.evaluate(medrecord)?; + + Box::new( + EdgeOperation::get_attributes(medrecord, edge_indices).map(|(_, value)| value), + ) + } + }) + } +} diff --git a/crates/medmodels-core/src/medrecord/querying/attributes/operand.rs b/crates/medmodels-core/src/medrecord/querying/attributes/operand.rs new file mode 100644 index 00000000..83af4393 --- /dev/null +++ b/crates/medmodels-core/src/medrecord/querying/attributes/operand.rs @@ -0,0 +1,874 @@ +use super::{ + operation::{AttributesTreeOperation, MultipleAttributesOperation, SingleAttributeOperation}, + BinaryArithmeticKind, Context, GetAttributes, MultipleComparisonKind, MultipleKind, + SingleComparisonKind, SingleKind, UnaryArithmeticKind, +}; +use crate::{ + errors::MedRecordResult, + medrecord::{ + querying::{ + traits::{DeepClone, ReadWriteOrPanic}, + values::{self, MultipleValuesOperand}, + BoxedIterator, + }, + MedRecordAttribute, Wrapper, + }, + MedRecord, +}; +use std::{fmt::Display, hash::Hash}; + +macro_rules! implement_attributes_operation { + ($name:ident, $variant:ident) => { + pub fn $name(&mut self) -> Wrapper { + let operand = Wrapper::::new( + self.deep_clone(), + MultipleKind::$variant, + ); + + self.operations + .push(AttributesTreeOperation::AttributesOperation { + operand: operand.clone(), + }); + + operand + } + }; +} + +macro_rules! implement_attribute_operation { + ($name:ident, $variant:ident) => { + pub fn $name(&mut self) -> Wrapper { + let operand = + Wrapper::::new(self.deep_clone(), SingleKind::$variant); + + self.operations + .push(MultipleAttributesOperation::AttributeOperation { + operand: operand.clone(), + }); + + operand + } + }; +} + +macro_rules! implement_single_attribute_comparison_operation { + ($name:ident, $operation:ident, $kind:ident) => { + pub fn $name>(&mut self, attribute: V) { + self.operations + .push($operation::SingleAttributeComparisonOperation { + operand: attribute.into(), + kind: SingleComparisonKind::$kind, + }); + } + }; +} + +macro_rules! implement_binary_arithmetic_operation { + ($name:ident, $operation:ident, $kind:ident) => { + pub fn $name>(&mut self, attribute: V) { + self.operations.push($operation::BinaryArithmeticOpration { + operand: attribute.into(), + kind: BinaryArithmeticKind::$kind, + }); + } + }; +} + +macro_rules! implement_unary_arithmetic_operation { + ($name:ident, $operation:ident, $kind:ident) => { + pub fn $name(&mut self) { + self.operations.push($operation::UnaryArithmeticOperation { + kind: UnaryArithmeticKind::$kind, + }); + } + }; +} + +macro_rules! implement_assertion_operation { + ($name:ident, $operation:expr) => { + pub fn $name(&mut self) { + self.operations.push($operation); + } + }; +} + +macro_rules! implement_wrapper_operand { + ($name:ident) => { + pub fn $name(&self) { + self.0.write_or_panic().$name() + } + }; +} + +macro_rules! implement_wrapper_operand_with_return { + ($name:ident, $return_operand:ident) => { + pub fn $name(&self) -> Wrapper<$return_operand> { + self.0.write_or_panic().$name() + } + }; +} + +macro_rules! implement_wrapper_operand_with_argument { + ($name:ident, $attribute_type:ty) => { + pub fn $name(&self, attribute: $attribute_type) { + self.0.write_or_panic().$name(attribute) + } + }; +} + +#[derive(Debug, Clone)] +pub enum SingleAttributeComparisonOperand { + Operand(SingleAttributeOperand), + Attribute(MedRecordAttribute), +} + +impl DeepClone for SingleAttributeComparisonOperand { + fn deep_clone(&self) -> Self { + match self { + Self::Operand(operand) => Self::Operand(operand.deep_clone()), + Self::Attribute(attribute) => Self::Attribute(attribute.clone()), + } + } +} + +impl From> for SingleAttributeComparisonOperand { + fn from(value: Wrapper) -> Self { + Self::Operand(value.0.read_or_panic().deep_clone()) + } +} + +impl From<&Wrapper> for SingleAttributeComparisonOperand { + fn from(value: &Wrapper) -> Self { + Self::Operand(value.0.read_or_panic().deep_clone()) + } +} + +impl> From for SingleAttributeComparisonOperand { + fn from(value: V) -> Self { + Self::Attribute(value.into()) + } +} + +#[derive(Debug, Clone)] +pub enum MultipleAttributesComparisonOperand { + Operand(MultipleAttributesOperand), + Attributes(Vec), +} + +impl DeepClone for MultipleAttributesComparisonOperand { + fn deep_clone(&self) -> Self { + match self { + Self::Operand(operand) => Self::Operand(operand.deep_clone()), + Self::Attributes(attribute) => Self::Attributes(attribute.clone()), + } + } +} + +impl From> for MultipleAttributesComparisonOperand { + fn from(value: Wrapper) -> Self { + Self::Operand(value.0.read_or_panic().deep_clone()) + } +} + +impl From<&Wrapper> for MultipleAttributesComparisonOperand { + fn from(value: &Wrapper) -> Self { + Self::Operand(value.0.read_or_panic().deep_clone()) + } +} + +impl> From> for MultipleAttributesComparisonOperand { + fn from(value: Vec) -> Self { + Self::Attributes(value.into_iter().map(Into::into).collect()) + } +} + +impl + Clone, const N: usize> From<[V; N]> + for MultipleAttributesComparisonOperand +{ + fn from(value: [V; N]) -> Self { + value.to_vec().into() + } +} + +#[derive(Debug, Clone)] +pub struct AttributesTreeOperand { + pub(crate) context: Context, + operations: Vec, +} + +impl DeepClone for AttributesTreeOperand { + fn deep_clone(&self) -> Self { + Self { + context: self.context.clone(), + operations: self.operations.iter().map(DeepClone::deep_clone).collect(), + } + } +} + +impl AttributesTreeOperand { + pub(crate) fn new(context: Context) -> Self { + Self { + context, + operations: Vec::new(), + } + } + + pub(crate) fn evaluate<'a, T: 'a + Eq + Hash + GetAttributes + Display>( + &self, + medrecord: &'a MedRecord, + attributes: impl Iterator)> + 'a, + ) -> MedRecordResult)>> { + let attributes = Box::new(attributes) as BoxedIterator<(&'a T, Vec)>; + + self.operations + .iter() + .try_fold(attributes, |attribute_tuples, operation| { + operation.evaluate(medrecord, attribute_tuples) + }) + } + + implement_attributes_operation!(max, Max); + implement_attributes_operation!(min, Min); + implement_attributes_operation!(count, Count); + implement_attributes_operation!(sum, Sum); + implement_attributes_operation!(first, First); + implement_attributes_operation!(last, Last); + + implement_single_attribute_comparison_operation!( + greater_than, + AttributesTreeOperation, + GreaterThan + ); + implement_single_attribute_comparison_operation!( + greater_than_or_equal_to, + AttributesTreeOperation, + GreaterThanOrEqualTo + ); + implement_single_attribute_comparison_operation!(less_than, AttributesTreeOperation, LessThan); + implement_single_attribute_comparison_operation!( + less_than_or_equal_to, + AttributesTreeOperation, + LessThanOrEqualTo + ); + implement_single_attribute_comparison_operation!(equal_to, AttributesTreeOperation, EqualTo); + implement_single_attribute_comparison_operation!( + not_equal_to, + AttributesTreeOperation, + NotEqualTo + ); + implement_single_attribute_comparison_operation!( + starts_with, + AttributesTreeOperation, + StartsWith + ); + implement_single_attribute_comparison_operation!(ends_with, AttributesTreeOperation, EndsWith); + implement_single_attribute_comparison_operation!(contains, AttributesTreeOperation, Contains); + + pub fn is_in>(&mut self, attributes: V) { + self.operations.push( + AttributesTreeOperation::MultipleAttributesComparisonOperation { + operand: attributes.into(), + kind: MultipleComparisonKind::IsIn, + }, + ); + } + + pub fn is_not_in>(&mut self, attributes: V) { + self.operations.push( + AttributesTreeOperation::MultipleAttributesComparisonOperation { + operand: attributes.into(), + kind: MultipleComparisonKind::IsNotIn, + }, + ); + } + + implement_binary_arithmetic_operation!(add, AttributesTreeOperation, Add); + implement_binary_arithmetic_operation!(sub, AttributesTreeOperation, Sub); + implement_binary_arithmetic_operation!(mul, AttributesTreeOperation, Mul); + implement_binary_arithmetic_operation!(pow, AttributesTreeOperation, Pow); + implement_binary_arithmetic_operation!(r#mod, AttributesTreeOperation, Mod); + + implement_unary_arithmetic_operation!(abs, AttributesTreeOperation, Abs); + implement_unary_arithmetic_operation!(trim, AttributesTreeOperation, Trim); + implement_unary_arithmetic_operation!(trim_start, AttributesTreeOperation, TrimStart); + implement_unary_arithmetic_operation!(trim_end, AttributesTreeOperation, TrimEnd); + implement_unary_arithmetic_operation!(lowercase, AttributesTreeOperation, Lowercase); + implement_unary_arithmetic_operation!(uppercase, AttributesTreeOperation, Uppercase); + + pub fn slice(&mut self, start: usize, end: usize) { + self.operations + .push(AttributesTreeOperation::Slice(start..end)); + } + + implement_assertion_operation!(is_string, AttributesTreeOperation::IsString); + implement_assertion_operation!(is_int, AttributesTreeOperation::IsInt); + implement_assertion_operation!(is_max, AttributesTreeOperation::IsMax); + implement_assertion_operation!(is_min, AttributesTreeOperation::IsMin); + + pub fn either_or(&mut self, either_query: EQ, or_query: OQ) + where + EQ: FnOnce(&mut Wrapper), + OQ: FnOnce(&mut Wrapper), + { + let mut either_operand = Wrapper::::new(self.context.clone()); + let mut or_operand = Wrapper::::new(self.context.clone()); + + either_query(&mut either_operand); + or_query(&mut or_operand); + + self.operations.push(AttributesTreeOperation::EitherOr { + either: either_operand, + or: or_operand, + }); + } +} + +impl Wrapper { + pub(crate) fn new(context: Context) -> Self { + AttributesTreeOperand::new(context).into() + } + + pub(crate) fn evaluate<'a, T: 'a + Eq + Hash + GetAttributes + Display>( + &self, + medrecord: &'a MedRecord, + attributes: impl Iterator)> + 'a, + ) -> MedRecordResult)>> { + self.0.read_or_panic().evaluate(medrecord, attributes) + } + + implement_wrapper_operand_with_return!(max, MultipleAttributesOperand); + implement_wrapper_operand_with_return!(min, MultipleAttributesOperand); + implement_wrapper_operand_with_return!(count, MultipleAttributesOperand); + implement_wrapper_operand_with_return!(sum, MultipleAttributesOperand); + implement_wrapper_operand_with_return!(first, MultipleAttributesOperand); + implement_wrapper_operand_with_return!(last, MultipleAttributesOperand); + + implement_wrapper_operand_with_argument!( + greater_than, + impl Into + ); + implement_wrapper_operand_with_argument!( + greater_than_or_equal_to, + impl Into + ); + implement_wrapper_operand_with_argument!( + less_than, + impl Into + ); + implement_wrapper_operand_with_argument!( + less_than_or_equal_to, + impl Into + ); + implement_wrapper_operand_with_argument!(equal_to, impl Into); + implement_wrapper_operand_with_argument!( + not_equal_to, + impl Into + ); + implement_wrapper_operand_with_argument!( + starts_with, + impl Into + ); + implement_wrapper_operand_with_argument!( + ends_with, + impl Into + ); + implement_wrapper_operand_with_argument!(contains, impl Into); + implement_wrapper_operand_with_argument!(is_in, impl Into); + implement_wrapper_operand_with_argument!( + is_not_in, + impl Into + ); + implement_wrapper_operand_with_argument!(add, impl Into); + implement_wrapper_operand_with_argument!(sub, impl Into); + implement_wrapper_operand_with_argument!(mul, impl Into); + implement_wrapper_operand_with_argument!(pow, impl Into); + implement_wrapper_operand_with_argument!(r#mod, impl Into); + + implement_wrapper_operand!(abs); + implement_wrapper_operand!(trim); + implement_wrapper_operand!(trim_start); + implement_wrapper_operand!(trim_end); + implement_wrapper_operand!(lowercase); + implement_wrapper_operand!(uppercase); + + pub fn slice(&self, start: usize, end: usize) { + self.0.write_or_panic().slice(start, end) + } + + implement_wrapper_operand!(is_string); + implement_wrapper_operand!(is_int); + implement_wrapper_operand!(is_max); + implement_wrapper_operand!(is_min); + + pub fn either_or(&self, either_query: EQ, or_query: OQ) + where + EQ: FnOnce(&mut Wrapper), + OQ: FnOnce(&mut Wrapper), + { + self.0.write_or_panic().either_or(either_query, or_query); + } +} + +#[derive(Debug, Clone)] +pub struct MultipleAttributesOperand { + pub(crate) context: AttributesTreeOperand, + pub(crate) kind: MultipleKind, + operations: Vec, +} + +impl DeepClone for MultipleAttributesOperand { + fn deep_clone(&self) -> Self { + Self { + context: self.context.clone(), + kind: self.kind.clone(), + operations: self.operations.iter().map(DeepClone::deep_clone).collect(), + } + } +} + +impl MultipleAttributesOperand { + pub(crate) fn new(context: AttributesTreeOperand, kind: MultipleKind) -> Self { + Self { + context, + kind, + operations: Vec::new(), + } + } + + pub(crate) fn evaluate<'a, T: 'a + Eq + Hash + GetAttributes + Display>( + &self, + medrecord: &'a MedRecord, + attributes: impl Iterator + 'a, + ) -> MedRecordResult> { + let attributes = Box::new(attributes) as BoxedIterator<(&'a T, MedRecordAttribute)>; + + self.operations + .iter() + .try_fold(attributes, |attribute_tuples, operation| { + operation.evaluate(medrecord, attribute_tuples) + }) + } + + implement_attribute_operation!(max, Max); + implement_attribute_operation!(min, Min); + implement_attribute_operation!(count, Count); + implement_attribute_operation!(sum, Sum); + implement_attribute_operation!(first, First); + implement_attribute_operation!(last, Last); + + implement_single_attribute_comparison_operation!( + greater_than, + MultipleAttributesOperation, + GreaterThan + ); + implement_single_attribute_comparison_operation!( + greater_than_or_equal_to, + MultipleAttributesOperation, + GreaterThanOrEqualTo + ); + implement_single_attribute_comparison_operation!( + less_than, + MultipleAttributesOperation, + LessThan + ); + implement_single_attribute_comparison_operation!( + less_than_or_equal_to, + MultipleAttributesOperation, + LessThanOrEqualTo + ); + implement_single_attribute_comparison_operation!( + equal_to, + MultipleAttributesOperation, + EqualTo + ); + implement_single_attribute_comparison_operation!( + not_equal_to, + MultipleAttributesOperation, + NotEqualTo + ); + implement_single_attribute_comparison_operation!( + starts_with, + MultipleAttributesOperation, + StartsWith + ); + implement_single_attribute_comparison_operation!( + ends_with, + MultipleAttributesOperation, + EndsWith + ); + implement_single_attribute_comparison_operation!( + contains, + MultipleAttributesOperation, + Contains + ); + + pub fn is_in>(&mut self, attributes: V) { + self.operations.push( + MultipleAttributesOperation::MultipleAttributesComparisonOperation { + operand: attributes.into(), + kind: MultipleComparisonKind::IsIn, + }, + ); + } + + pub fn is_not_in>(&mut self, attributes: V) { + self.operations.push( + MultipleAttributesOperation::MultipleAttributesComparisonOperation { + operand: attributes.into(), + kind: MultipleComparisonKind::IsNotIn, + }, + ); + } + + implement_binary_arithmetic_operation!(add, MultipleAttributesOperation, Add); + implement_binary_arithmetic_operation!(sub, MultipleAttributesOperation, Sub); + implement_binary_arithmetic_operation!(mul, MultipleAttributesOperation, Mul); + implement_binary_arithmetic_operation!(pow, MultipleAttributesOperation, Pow); + implement_binary_arithmetic_operation!(r#mod, MultipleAttributesOperation, Mod); + + implement_unary_arithmetic_operation!(abs, MultipleAttributesOperation, Abs); + implement_unary_arithmetic_operation!(trim, MultipleAttributesOperation, Trim); + implement_unary_arithmetic_operation!(trim_start, MultipleAttributesOperation, TrimStart); + implement_unary_arithmetic_operation!(trim_end, MultipleAttributesOperation, TrimEnd); + implement_unary_arithmetic_operation!(lowercase, MultipleAttributesOperation, Lowercase); + implement_unary_arithmetic_operation!(uppercase, MultipleAttributesOperation, Uppercase); + + #[allow(clippy::wrong_self_convention)] + pub fn to_values(&mut self) -> Wrapper { + let operand = Wrapper::::new( + values::Context::MultipleAttributesOperand(self.deep_clone()), + "unused".into(), + ); + + self.operations.push(MultipleAttributesOperation::ToValues { + operand: operand.clone(), + }); + + operand + } + + pub fn slice(&mut self, start: usize, end: usize) { + self.operations + .push(MultipleAttributesOperation::Slice(start..end)); + } + + implement_assertion_operation!(is_string, MultipleAttributesOperation::IsString); + implement_assertion_operation!(is_int, MultipleAttributesOperation::IsInt); + implement_assertion_operation!(is_max, MultipleAttributesOperation::IsMax); + implement_assertion_operation!(is_min, MultipleAttributesOperation::IsMin); + + pub fn either_or(&mut self, either_query: EQ, or_query: OQ) + where + EQ: FnOnce(&mut Wrapper), + OQ: FnOnce(&mut Wrapper), + { + let mut either_operand = + Wrapper::::new(self.context.clone(), self.kind.clone()); + let mut or_operand = + Wrapper::::new(self.context.clone(), self.kind.clone()); + + either_query(&mut either_operand); + or_query(&mut or_operand); + + self.operations.push(MultipleAttributesOperation::EitherOr { + either: either_operand, + or: or_operand, + }); + } +} + +impl Wrapper { + pub(crate) fn new(context: AttributesTreeOperand, kind: MultipleKind) -> Self { + MultipleAttributesOperand::new(context, kind).into() + } + + pub(crate) fn evaluate<'a, T: 'a + Eq + Hash + GetAttributes + Display>( + &self, + medrecord: &'a MedRecord, + attributes: impl Iterator + 'a, + ) -> MedRecordResult> { + self.0.read_or_panic().evaluate(medrecord, attributes) + } + + implement_wrapper_operand_with_return!(max, SingleAttributeOperand); + implement_wrapper_operand_with_return!(min, SingleAttributeOperand); + implement_wrapper_operand_with_return!(count, SingleAttributeOperand); + implement_wrapper_operand_with_return!(sum, SingleAttributeOperand); + implement_wrapper_operand_with_return!(first, SingleAttributeOperand); + implement_wrapper_operand_with_return!(last, SingleAttributeOperand); + + implement_wrapper_operand_with_argument!( + greater_than, + impl Into + ); + implement_wrapper_operand_with_argument!( + greater_than_or_equal_to, + impl Into + ); + implement_wrapper_operand_with_argument!( + less_than, + impl Into + ); + implement_wrapper_operand_with_argument!( + less_than_or_equal_to, + impl Into + ); + implement_wrapper_operand_with_argument!(equal_to, impl Into); + implement_wrapper_operand_with_argument!( + not_equal_to, + impl Into + ); + implement_wrapper_operand_with_argument!( + starts_with, + impl Into + ); + implement_wrapper_operand_with_argument!( + ends_with, + impl Into + ); + implement_wrapper_operand_with_argument!(contains, impl Into); + implement_wrapper_operand_with_argument!(is_in, impl Into); + implement_wrapper_operand_with_argument!( + is_not_in, + impl Into + ); + implement_wrapper_operand_with_argument!(add, impl Into); + implement_wrapper_operand_with_argument!(sub, impl Into); + implement_wrapper_operand_with_argument!(mul, impl Into); + implement_wrapper_operand_with_argument!(pow, impl Into); + implement_wrapper_operand_with_argument!(r#mod, impl Into); + + implement_wrapper_operand!(abs); + implement_wrapper_operand!(trim); + implement_wrapper_operand!(trim_start); + implement_wrapper_operand!(trim_end); + implement_wrapper_operand!(lowercase); + implement_wrapper_operand!(uppercase); + + implement_wrapper_operand_with_return!(to_values, MultipleValuesOperand); + + pub fn slice(&self, start: usize, end: usize) { + self.0.write_or_panic().slice(start, end) + } + + implement_wrapper_operand!(is_string); + implement_wrapper_operand!(is_int); + implement_wrapper_operand!(is_max); + implement_wrapper_operand!(is_min); + + pub fn either_or(&self, either_query: EQ, or_query: OQ) + where + EQ: FnOnce(&mut Wrapper), + OQ: FnOnce(&mut Wrapper), + { + self.0.write_or_panic().either_or(either_query, or_query); + } +} + +#[derive(Debug, Clone)] +pub struct SingleAttributeOperand { + pub(crate) context: MultipleAttributesOperand, + pub(crate) kind: SingleKind, + operations: Vec, +} + +impl DeepClone for SingleAttributeOperand { + fn deep_clone(&self) -> Self { + Self { + context: self.context.deep_clone(), + kind: self.kind.clone(), + operations: self.operations.iter().map(DeepClone::deep_clone).collect(), + } + } +} + +impl SingleAttributeOperand { + pub(crate) fn new(context: MultipleAttributesOperand, kind: SingleKind) -> Self { + Self { + context, + kind, + operations: Vec::new(), + } + } + + pub(crate) fn evaluate( + &self, + medrecord: &MedRecord, + attribute: MedRecordAttribute, + ) -> MedRecordResult> { + self.operations + .iter() + .try_fold(Some(attribute), |attribute, operation| { + if let Some(attribute) = attribute { + operation.evaluate(medrecord, attribute) + } else { + Ok(None) + } + }) + } + + implement_single_attribute_comparison_operation!( + greater_than, + SingleAttributeOperation, + GreaterThan + ); + implement_single_attribute_comparison_operation!( + greater_than_or_equal_to, + SingleAttributeOperation, + GreaterThanOrEqualTo + ); + implement_single_attribute_comparison_operation!(less_than, SingleAttributeOperation, LessThan); + implement_single_attribute_comparison_operation!( + less_than_or_equal_to, + SingleAttributeOperation, + LessThanOrEqualTo + ); + implement_single_attribute_comparison_operation!(equal_to, SingleAttributeOperation, EqualTo); + implement_single_attribute_comparison_operation!( + not_equal_to, + SingleAttributeOperation, + NotEqualTo + ); + implement_single_attribute_comparison_operation!( + starts_with, + SingleAttributeOperation, + StartsWith + ); + implement_single_attribute_comparison_operation!(ends_with, SingleAttributeOperation, EndsWith); + implement_single_attribute_comparison_operation!(contains, SingleAttributeOperation, Contains); + + pub fn is_in>(&mut self, attributes: V) { + self.operations.push( + SingleAttributeOperation::MultipleAttributesComparisonOperation { + operand: attributes.into(), + kind: MultipleComparisonKind::IsIn, + }, + ); + } + + pub fn is_not_in>(&mut self, attributes: V) { + self.operations.push( + SingleAttributeOperation::MultipleAttributesComparisonOperation { + operand: attributes.into(), + kind: MultipleComparisonKind::IsNotIn, + }, + ); + } + + implement_binary_arithmetic_operation!(add, SingleAttributeOperation, Add); + implement_binary_arithmetic_operation!(sub, SingleAttributeOperation, Sub); + implement_binary_arithmetic_operation!(mul, SingleAttributeOperation, Mul); + implement_binary_arithmetic_operation!(pow, SingleAttributeOperation, Pow); + implement_binary_arithmetic_operation!(r#mod, SingleAttributeOperation, Mod); + + implement_unary_arithmetic_operation!(abs, SingleAttributeOperation, Abs); + implement_unary_arithmetic_operation!(trim, SingleAttributeOperation, Trim); + implement_unary_arithmetic_operation!(trim_start, SingleAttributeOperation, TrimStart); + implement_unary_arithmetic_operation!(trim_end, SingleAttributeOperation, TrimEnd); + implement_unary_arithmetic_operation!(lowercase, SingleAttributeOperation, Lowercase); + implement_unary_arithmetic_operation!(uppercase, SingleAttributeOperation, Uppercase); + + pub fn slice(&mut self, start: usize, end: usize) { + self.operations + .push(SingleAttributeOperation::Slice(start..end)); + } + + implement_assertion_operation!(is_string, SingleAttributeOperation::IsString); + implement_assertion_operation!(is_int, SingleAttributeOperation::IsInt); + + pub fn eiter_or(&mut self, either_query: EQ, or_query: OQ) + where + EQ: FnOnce(&mut Wrapper), + OQ: FnOnce(&mut Wrapper), + { + let mut either_operand = + Wrapper::::new(self.context.clone(), self.kind.clone()); + let mut or_operand = + Wrapper::::new(self.context.clone(), self.kind.clone()); + + either_query(&mut either_operand); + or_query(&mut or_operand); + + self.operations.push(SingleAttributeOperation::EitherOr { + either: either_operand, + or: or_operand, + }); + } +} + +impl Wrapper { + pub(crate) fn new(context: MultipleAttributesOperand, kind: SingleKind) -> Self { + SingleAttributeOperand::new(context, kind).into() + } + + pub(crate) fn evaluate( + &self, + medrecord: &MedRecord, + attribute: MedRecordAttribute, + ) -> MedRecordResult> { + self.0.read_or_panic().evaluate(medrecord, attribute) + } + + implement_wrapper_operand_with_argument!( + greater_than, + impl Into + ); + implement_wrapper_operand_with_argument!( + greater_than_or_equal_to, + impl Into + ); + implement_wrapper_operand_with_argument!( + less_than, + impl Into + ); + implement_wrapper_operand_with_argument!( + less_than_or_equal_to, + impl Into + ); + implement_wrapper_operand_with_argument!(equal_to, impl Into); + implement_wrapper_operand_with_argument!( + not_equal_to, + impl Into + ); + implement_wrapper_operand_with_argument!( + starts_with, + impl Into + ); + implement_wrapper_operand_with_argument!( + ends_with, + impl Into + ); + implement_wrapper_operand_with_argument!(contains, impl Into); + implement_wrapper_operand_with_argument!(is_in, impl Into); + implement_wrapper_operand_with_argument!( + is_not_in, + impl Into + ); + implement_wrapper_operand_with_argument!(add, impl Into); + implement_wrapper_operand_with_argument!(sub, impl Into); + implement_wrapper_operand_with_argument!(mul, impl Into); + implement_wrapper_operand_with_argument!(pow, impl Into); + implement_wrapper_operand_with_argument!(r#mod, impl Into); + + implement_wrapper_operand!(abs); + implement_wrapper_operand!(trim); + implement_wrapper_operand!(trim_start); + implement_wrapper_operand!(trim_end); + implement_wrapper_operand!(lowercase); + implement_wrapper_operand!(uppercase); + + pub fn slice(&self, start: usize, end: usize) { + self.0.write_or_panic().slice(start, end) + } + + implement_wrapper_operand!(is_string); + implement_wrapper_operand!(is_int); + + pub fn either_or(&self, either_query: EQ, or_query: OQ) + where + EQ: FnOnce(&mut Wrapper), + OQ: FnOnce(&mut Wrapper), + { + self.0.write_or_panic().eiter_or(either_query, or_query); + } +} diff --git a/crates/medmodels-core/src/medrecord/querying/attributes/operation.rs b/crates/medmodels-core/src/medrecord/querying/attributes/operation.rs new file mode 100644 index 00000000..71479dff --- /dev/null +++ b/crates/medmodels-core/src/medrecord/querying/attributes/operation.rs @@ -0,0 +1,1357 @@ +use super::{ + operand::{ + MultipleAttributesComparisonOperand, MultipleAttributesOperand, + SingleAttributeComparisonOperand, SingleAttributeOperand, + }, + AttributesTreeOperand, BinaryArithmeticKind, GetAttributes, MultipleComparisonKind, + SingleComparisonKind, UnaryArithmeticKind, +}; +use crate::{ + errors::{MedRecordError, MedRecordResult}, + medrecord::{ + datatypes::{ + Abs, Contains, EndsWith, Lowercase, Mod, Pow, Slice, StartsWith, Trim, TrimEnd, + TrimStart, Uppercase, + }, + querying::{ + attributes::{MultipleKind, SingleKind}, + traits::{DeepClone, ReadWriteOrPanic}, + values::MultipleValuesOperand, + BoxedIterator, + }, + DataType, MedRecordAttribute, MedRecordValue, Wrapper, + }, + MedRecord, +}; +use itertools::Itertools; +use std::{ + cmp::Ordering, + collections::HashMap, + fmt::Display, + hash::Hash, + ops::{Add, Mul, Range, Sub}, +}; + +macro_rules! get_multiple_operand_attributes { + ($kind:ident, $attributes:expr) => { + match $kind { + MultipleKind::Max => Box::new(AttributesTreeOperation::get_max($attributes)?), + MultipleKind::Min => Box::new(AttributesTreeOperation::get_min($attributes)?), + MultipleKind::Count => Box::new(AttributesTreeOperation::get_count($attributes)?), + MultipleKind::Sum => Box::new(AttributesTreeOperation::get_sum($attributes)?), + MultipleKind::First => Box::new(AttributesTreeOperation::get_first($attributes)?), + MultipleKind::Last => Box::new(AttributesTreeOperation::get_last($attributes)?), + } + }; +} + +macro_rules! get_single_operand_attribute { + ($kind:ident, $attributes:expr) => { + match $kind { + SingleKind::Max => MultipleAttributesOperation::get_max($attributes)?.1, + SingleKind::Min => MultipleAttributesOperation::get_min($attributes)?.1, + SingleKind::Count => MultipleAttributesOperation::get_count($attributes), + SingleKind::Sum => MultipleAttributesOperation::get_sum($attributes)?, + SingleKind::First => MultipleAttributesOperation::get_first($attributes)?, + SingleKind::Last => MultipleAttributesOperation::get_last($attributes)?, + } + }; +} + +macro_rules! get_single_attribute_comparison_operand_attribute { + ($operand:ident, $medrecord:ident) => { + match $operand { + SingleAttributeComparisonOperand::Operand(operand) => { + let context = &operand.context.context.context; + let kind = &operand.context.kind; + + let comparison_attributes = context + .get_attributes($medrecord)? + .map(|attribute| (&0, attribute)); + + let comparison_attributes: Box> = + get_multiple_operand_attributes!(kind, comparison_attributes); + + let kind = &operand.kind; + + get_single_operand_attribute!(kind, comparison_attributes) + } + SingleAttributeComparisonOperand::Attribute(attribute) => attribute.clone(), + } + }; +} + +#[derive(Debug, Clone)] +pub enum AttributesTreeOperation { + AttributesOperation { + operand: Wrapper, + }, + SingleAttributeComparisonOperation { + operand: SingleAttributeComparisonOperand, + kind: SingleComparisonKind, + }, + MultipleAttributesComparisonOperation { + operand: MultipleAttributesComparisonOperand, + kind: MultipleComparisonKind, + }, + BinaryArithmeticOpration { + operand: SingleAttributeComparisonOperand, + kind: BinaryArithmeticKind, + }, + UnaryArithmeticOperation { + kind: UnaryArithmeticKind, + }, + + Slice(Range), + + IsString, + IsInt, + + IsMax, + IsMin, + + EitherOr { + either: Wrapper, + or: Wrapper, + }, +} + +impl DeepClone for AttributesTreeOperation { + fn deep_clone(&self) -> Self { + match self { + Self::AttributesOperation { operand } => Self::AttributesOperation { + operand: operand.deep_clone(), + }, + Self::SingleAttributeComparisonOperation { operand, kind } => { + Self::SingleAttributeComparisonOperation { + operand: operand.deep_clone(), + kind: kind.clone(), + } + } + Self::MultipleAttributesComparisonOperation { operand, kind } => { + Self::MultipleAttributesComparisonOperation { + operand: operand.deep_clone(), + kind: kind.clone(), + } + } + Self::BinaryArithmeticOpration { operand, kind } => Self::BinaryArithmeticOpration { + operand: operand.deep_clone(), + kind: kind.clone(), + }, + Self::UnaryArithmeticOperation { kind } => { + Self::UnaryArithmeticOperation { kind: kind.clone() } + } + Self::Slice(range) => Self::Slice(range.clone()), + Self::IsString => Self::IsString, + Self::IsInt => Self::IsInt, + Self::IsMax => Self::IsMax, + Self::IsMin => Self::IsMin, + Self::EitherOr { either, or } => Self::EitherOr { + either: either.deep_clone(), + or: or.deep_clone(), + }, + } + } +} + +impl AttributesTreeOperation { + pub(crate) fn evaluate<'a, T: Eq + Hash + GetAttributes + Display>( + &self, + medrecord: &'a MedRecord, + attributes: impl Iterator)> + 'a, + ) -> MedRecordResult)>> { + match self { + Self::AttributesOperation { operand } => Ok(Box::new( + Self::evaluate_attributes_operation(medrecord, attributes, operand)?, + )), + Self::SingleAttributeComparisonOperation { operand, kind } => { + Self::evaluate_single_attribute_comparison_operation( + medrecord, attributes, operand, kind, + ) + } + Self::MultipleAttributesComparisonOperation { operand, kind } => { + Self::evaluate_multiple_attributes_comparison_operation( + medrecord, attributes, operand, kind, + ) + } + Self::BinaryArithmeticOpration { operand, kind } => { + Self::evaluate_binary_arithmetic_operation(medrecord, attributes, operand, kind) + } + Self::UnaryArithmeticOperation { kind } => Ok(Box::new( + Self::evaluate_unary_arithmetic_operation(attributes, kind.clone()), + )), + Self::Slice(range) => Ok(Box::new(Self::evaluate_slice(attributes, range.clone()))), + Self::IsString => Ok(Box::new(attributes.map(|(index, attribute)| { + ( + index, + attribute + .into_iter() + .filter(|attribute| matches!(attribute, MedRecordAttribute::String(_))) + .collect(), + ) + }))), + Self::IsInt => Ok(Box::new(attributes.map(|(index, attribute)| { + ( + index, + attribute + .into_iter() + .filter(|attribute| matches!(attribute, MedRecordAttribute::String(_))) + .collect(), + ) + }))), + Self::IsMax => { + let max_attributes = Self::get_max(attributes)?; + + Ok(Box::new( + max_attributes.map(|(index, attribute)| (index, vec![attribute])), + )) + } + Self::IsMin => { + let min_attributes = Self::get_min(attributes)?; + + Ok(Box::new( + min_attributes.map(|(index, attribute)| (index, vec![attribute])), + )) + } + Self::EitherOr { either, or } => { + Self::evaluate_either_or(medrecord, attributes, either, or) + } + } + } + + #[inline] + pub(crate) fn get_max<'a, T: 'a>( + attributes: impl Iterator)>, + ) -> MedRecordResult> { + Ok(attributes.map(|(index, attributes)| { + let mut attributes = attributes.into_iter(); + + let first_attribute = attributes.next().ok_or(MedRecordError::QueryError( + "No attributes to compare".to_string(), + ))?; + + let attribute = attributes.try_fold(first_attribute, |max, attribute| { + match attribute.partial_cmp(&max) { + Some(Ordering::Greater) => Ok(attribute), + None => { + let first_dtype = DataType::from(attribute); + let second_dtype = DataType::from(max); + + Err(MedRecordError::QueryError(format!( + "Cannot compare attributes of data types {} and {}. Consider narrowing down the attributes using .is_string() or .is_int()", + first_dtype, second_dtype + ))) + } + _ => Ok(max), + } + })?; + + Ok((index, attribute)) + }).collect::>>()?.into_iter()) + } + + #[inline] + pub(crate) fn get_min<'a, T: 'a>( + attributes: impl Iterator)>, + ) -> MedRecordResult> { + Ok(attributes.map(|(index, attributes)| { + let mut attributes = attributes.into_iter(); + + let first_attribute = attributes.next().ok_or(MedRecordError::QueryError( + "No attributes to compare".to_string(), + ))?; + + let attribute = attributes.try_fold(first_attribute, |max, attribute| { + match attribute.partial_cmp(&max) { + Some(Ordering::Less) => Ok(attribute), + None => { + let first_dtype = DataType::from(attribute); + let second_dtype = DataType::from(max); + + Err(MedRecordError::QueryError(format!( + "Cannot compare attributes of data types {} and {}. Consider narrowing down the attributes using .is_string() or .is_int()", + first_dtype, second_dtype + ))) + } + _ => Ok(max), + } + })?; + + Ok((index, attribute)) + }).collect::>>()?.into_iter()) + } + + #[inline] + pub(crate) fn get_count<'a, T: 'a>( + attributes: impl Iterator)>, + ) -> MedRecordResult> { + Ok(attributes + .map(|(index, attribute)| (index, MedRecordAttribute::Int(attribute.len() as i64)))) + } + + #[inline] + pub(crate) fn get_sum<'a, T: 'a>( + attributes: impl Iterator)>, + ) -> MedRecordResult> { + Ok(attributes.map(|(index, attributes)| { + let mut attributes = attributes.into_iter(); + + let first_attribute = attributes.next().ok_or(MedRecordError::QueryError( + "No attributes to compare".to_string(), + ))?; + + let attribute = attributes.try_fold(first_attribute, |sum, attribute| { + let first_dtype = DataType::from(&sum); + let second_dtype = DataType::from(&attribute); + + sum.add(attribute).map_err(|_| { + MedRecordError::QueryError(format!( + "Cannot add attributes of data types {} and {}. Consider narrowing down the attributes using .is_string() or .is_int()", + first_dtype, second_dtype + )) + }) + })?; + + Ok((index, attribute)) + }).collect::>>()?.into_iter()) + } + + #[inline] + pub(crate) fn get_first<'a, T: 'a>( + attributes: impl Iterator)>, + ) -> MedRecordResult> { + Ok(attributes + .map(|(index, attributes)| { + let first_attribute = + attributes + .into_iter() + .next() + .ok_or(MedRecordError::QueryError( + "No attributes to compare".to_string(), + ))?; + + Ok((index, first_attribute)) + }) + .collect::>>()? + .into_iter()) + } + + #[inline] + pub(crate) fn get_last<'a, T: 'a>( + attributes: impl Iterator)>, + ) -> MedRecordResult> { + Ok(attributes + .map(|(index, attributes)| { + let first_attribute = + attributes + .into_iter() + .last() + .ok_or(MedRecordError::QueryError( + "No attributes to compare".to_string(), + ))?; + + Ok((index, first_attribute)) + }) + .collect::>>()? + .into_iter()) + } + + #[inline] + fn evaluate_attributes_operation<'a, T: 'a + Eq + Hash + GetAttributes + Display>( + medrecord: &'a MedRecord, + attributes: impl Iterator)>, + operand: &Wrapper, + ) -> MedRecordResult)>> { + let kind = &operand.0.read_or_panic().kind; + + let attributes = attributes.collect::>(); + + let multiple_operand_attributes: Box> = + get_multiple_operand_attributes!(kind, attributes.clone().into_iter()); + + let result = operand.evaluate(medrecord, multiple_operand_attributes)?; + + let mut attributes = attributes.into_iter().collect::>(); + + Ok(result + .map(move |(index, _)| (index, attributes.remove(&index).expect("Index must exist")))) + } + + #[inline] + fn evaluate_single_attribute_comparison_operation<'a, T>( + medrecord: &'a MedRecord, + attributes: impl Iterator)> + 'a, + comparison_operand: &SingleAttributeComparisonOperand, + kind: &SingleComparisonKind, + ) -> MedRecordResult)>> { + let comparison_attribute = + get_single_attribute_comparison_operand_attribute!(comparison_operand, medrecord); + + match kind { + SingleComparisonKind::GreaterThan => { + Ok(Box::new(attributes.map(move |(index, attributes)| { + ( + index, + attributes + .into_iter() + .filter(|attribute| attribute > &comparison_attribute) + .collect(), + ) + }))) + } + SingleComparisonKind::GreaterThanOrEqualTo => { + Ok(Box::new(attributes.map(move |(index, attributes)| { + ( + index, + attributes + .into_iter() + .filter(|attribute| attribute >= &comparison_attribute) + .collect(), + ) + }))) + } + SingleComparisonKind::LessThan => { + Ok(Box::new(attributes.map(move |(index, attributes)| { + ( + index, + attributes + .into_iter() + .filter(|attribute| attribute < &comparison_attribute) + .collect(), + ) + }))) + } + SingleComparisonKind::LessThanOrEqualTo => { + Ok(Box::new(attributes.map(move |(index, attributes)| { + ( + index, + attributes + .into_iter() + .filter(|attribute| attribute <= &comparison_attribute) + .collect(), + ) + }))) + } + SingleComparisonKind::EqualTo => { + Ok(Box::new(attributes.map(move |(index, attributes)| { + ( + index, + attributes + .into_iter() + .filter(|attribute| attribute == &comparison_attribute) + .collect(), + ) + }))) + } + SingleComparisonKind::NotEqualTo => { + Ok(Box::new(attributes.map(move |(index, attributes)| { + ( + index, + attributes + .into_iter() + .filter(|attribute| attribute != &comparison_attribute) + .collect(), + ) + }))) + } + SingleComparisonKind::StartsWith => { + Ok(Box::new(attributes.map(move |(index, attributes)| { + ( + index, + attributes + .into_iter() + .filter(|attribute| attribute.starts_with(&comparison_attribute)) + .collect(), + ) + }))) + } + SingleComparisonKind::EndsWith => { + Ok(Box::new(attributes.map(move |(index, attributes)| { + ( + index, + attributes + .into_iter() + .filter(|attribute| attribute.ends_with(&comparison_attribute)) + .collect(), + ) + }))) + } + SingleComparisonKind::Contains => { + Ok(Box::new(attributes.map(move |(index, attributes)| { + ( + index, + attributes + .into_iter() + .filter(|attribute| attribute.contains(&comparison_attribute)) + .collect(), + ) + }))) + } + } + } + + #[inline] + fn evaluate_multiple_attributes_comparison_operation<'a, T>( + medrecord: &'a MedRecord, + attributes: impl Iterator)> + 'a, + comparison_operand: &MultipleAttributesComparisonOperand, + kind: &MultipleComparisonKind, + ) -> MedRecordResult)>> { + let comparison_attributes = match comparison_operand { + MultipleAttributesComparisonOperand::Operand(operand) => { + let context = &operand.context.context; + let kind = &operand.kind; + + let comparison_attributes = context + .get_attributes(medrecord)? + .map(|attribute| (&0, attribute)); + + let comparison_attributes: Box> = + get_multiple_operand_attributes!(kind, comparison_attributes); + + comparison_attributes + .map(|(_, attribute)| attribute) + .collect::>() + } + MultipleAttributesComparisonOperand::Attributes(attributes) => attributes.clone(), + }; + + match kind { + MultipleComparisonKind::IsIn => { + Ok(Box::new(attributes.map(move |(index, attributes)| { + ( + index, + attributes + .into_iter() + .filter(|attribute| comparison_attributes.contains(attribute)) + .collect(), + ) + }))) + } + MultipleComparisonKind::IsNotIn => { + Ok(Box::new(attributes.map(move |(index, attributes)| { + ( + index, + attributes + .into_iter() + .filter(|attribute| !comparison_attributes.contains(attribute)) + .collect(), + ) + }))) + } + } + } + + #[inline] + fn evaluate_binary_arithmetic_operation<'a, T: 'a>( + medrecord: &'a MedRecord, + attributes: impl Iterator)> + 'a, + operand: &SingleAttributeComparisonOperand, + kind: &BinaryArithmeticKind, + ) -> MedRecordResult)>> { + let arithmetic_attribute = + get_single_attribute_comparison_operand_attribute!(operand, medrecord); + + let attributes: Box< + dyn Iterator)>>, + > = match kind { + BinaryArithmeticKind::Add => Box::new(attributes.map(move |(index, attributes)| { + Ok(( + index, + attributes + .into_iter() + .map(|attribute| attribute.add(arithmetic_attribute.clone())) + .collect::>>()?, + )) + })), + BinaryArithmeticKind::Sub => Box::new(attributes.map(move |(index, attributes)| { + Ok(( + index, + attributes + .into_iter() + .map(|attribute| attribute.sub(arithmetic_attribute.clone())) + .collect::>>()?, + )) + })), + BinaryArithmeticKind::Mul => Box::new(attributes.map(move |(index, attributes)| { + Ok(( + index, + attributes + .into_iter() + .map(|attribute| attribute.mul(arithmetic_attribute.clone())) + .collect::>>()?, + )) + })), + BinaryArithmeticKind::Pow => Box::new(attributes.map(move |(index, attributes)| { + Ok(( + index, + attributes + .into_iter() + .map(|attribute| attribute.pow(arithmetic_attribute.clone())) + .collect::>>()?, + )) + })), + BinaryArithmeticKind::Mod => Box::new(attributes.map(move |(index, attributes)| { + Ok(( + index, + attributes + .into_iter() + .map(|attribute| attribute.r#mod(arithmetic_attribute.clone())) + .collect::>>()?, + )) + })), + }; + + Ok(Box::new( + attributes.collect::>>()?.into_iter(), + )) + } + + #[inline] + fn evaluate_unary_arithmetic_operation<'a, T: 'a>( + attributes: impl Iterator)>, + kind: UnaryArithmeticKind, + ) -> impl Iterator)> { + attributes.map(move |(index, attributes)| { + ( + index, + attributes + .into_iter() + .map(|attribute| match kind { + UnaryArithmeticKind::Abs => attribute.abs(), + UnaryArithmeticKind::Trim => attribute.trim(), + UnaryArithmeticKind::TrimStart => attribute.trim_start(), + UnaryArithmeticKind::TrimEnd => attribute.trim_end(), + UnaryArithmeticKind::Lowercase => attribute.lowercase(), + UnaryArithmeticKind::Uppercase => attribute.uppercase(), + }) + .collect(), + ) + }) + } + + #[inline] + fn evaluate_slice<'a, T: 'a>( + attributes: impl Iterator)>, + range: Range, + ) -> impl Iterator)> { + attributes.map(move |(index, attributes)| { + ( + index, + attributes + .into_iter() + .map(|attribute| attribute.slice(range.clone())) + .collect(), + ) + }) + } + + #[inline] + fn evaluate_either_or<'a, T: 'a + Eq + Hash + GetAttributes + Display>( + medrecord: &'a MedRecord, + attributes: impl Iterator)>, + either: &Wrapper, + or: &Wrapper, + ) -> MedRecordResult)>> { + let attributes = attributes.collect::>(); + + let either_attributes = either.evaluate(medrecord, attributes.clone().into_iter())?; + let or_attributes = or.evaluate(medrecord, attributes.into_iter())?; + + Ok(Box::new( + either_attributes + .chain(or_attributes) + .unique_by(|attribute| attribute.0), + )) + } +} + +#[derive(Debug, Clone)] +pub enum MultipleAttributesOperation { + AttributeOperation { + operand: Wrapper, + }, + SingleAttributeComparisonOperation { + operand: SingleAttributeComparisonOperand, + kind: SingleComparisonKind, + }, + MultipleAttributesComparisonOperation { + operand: MultipleAttributesComparisonOperand, + kind: MultipleComparisonKind, + }, + BinaryArithmeticOpration { + operand: SingleAttributeComparisonOperand, + kind: BinaryArithmeticKind, + }, + UnaryArithmeticOperation { + kind: UnaryArithmeticKind, + }, + + ToValues { + operand: Wrapper, + }, + + Slice(Range), + + IsString, + IsInt, + + IsMax, + IsMin, + + EitherOr { + either: Wrapper, + or: Wrapper, + }, +} + +impl DeepClone for MultipleAttributesOperation { + fn deep_clone(&self) -> Self { + match self { + Self::AttributeOperation { operand } => Self::AttributeOperation { + operand: operand.deep_clone(), + }, + Self::SingleAttributeComparisonOperation { operand, kind } => { + Self::SingleAttributeComparisonOperation { + operand: operand.deep_clone(), + kind: kind.clone(), + } + } + Self::MultipleAttributesComparisonOperation { operand, kind } => { + Self::MultipleAttributesComparisonOperation { + operand: operand.deep_clone(), + kind: kind.clone(), + } + } + Self::BinaryArithmeticOpration { operand, kind } => Self::BinaryArithmeticOpration { + operand: operand.deep_clone(), + kind: kind.clone(), + }, + Self::UnaryArithmeticOperation { kind } => { + Self::UnaryArithmeticOperation { kind: kind.clone() } + } + Self::ToValues { operand } => Self::ToValues { + operand: operand.deep_clone(), + }, + Self::Slice(range) => Self::Slice(range.clone()), + Self::IsString => Self::IsString, + Self::IsInt => Self::IsInt, + Self::IsMax => Self::IsMax, + Self::IsMin => Self::IsMin, + Self::EitherOr { either, or } => Self::EitherOr { + either: either.deep_clone(), + or: or.deep_clone(), + }, + } + } +} + +impl MultipleAttributesOperation { + pub(crate) fn evaluate<'a, T: Eq + Hash + GetAttributes + Display>( + &self, + medrecord: &'a MedRecord, + attributes: impl Iterator + 'a, + ) -> MedRecordResult> { + match self { + Self::AttributeOperation { operand } => { + Self::evaluate_attribute_operation(medrecord, attributes, operand) + } + Self::SingleAttributeComparisonOperation { operand, kind } => { + Self::evaluate_single_attribute_comparison_operation( + medrecord, attributes, operand, kind, + ) + } + Self::MultipleAttributesComparisonOperation { operand, kind } => { + Self::evaluate_multiple_attributes_comparison_operation( + medrecord, attributes, operand, kind, + ) + } + Self::BinaryArithmeticOpration { operand, kind } => Ok(Box::new( + Self::evaluate_binary_arithmetic_operation(medrecord, attributes, operand, kind)?, + )), + Self::UnaryArithmeticOperation { kind } => Ok(Box::new( + Self::evaluate_unary_arithmetic_operation(attributes, kind.clone()), + )), + Self::ToValues { operand } => Ok(Box::new(Self::evaluate_to_values( + medrecord, attributes, operand, + )?)), + Self::Slice(range) => Ok(Box::new(Self::evaluate_slice(attributes, range.clone()))), + Self::IsString => { + Ok(Box::new(attributes.filter(|(_, attribute)| { + matches!(attribute, MedRecordAttribute::String(_)) + }))) + } + Self::IsInt => { + Ok(Box::new(attributes.filter(|(_, attribute)| { + matches!(attribute, MedRecordAttribute::Int(_)) + }))) + } + Self::IsMax => { + let max_attribute = Self::get_max(attributes)?; + + Ok(Box::new(std::iter::once(max_attribute))) + } + Self::IsMin => { + let min_attribute = Self::get_min(attributes)?; + + Ok(Box::new(std::iter::once(min_attribute))) + } + Self::EitherOr { either, or } => { + Self::evaluate_either_or(medrecord, attributes, either, or) + } + } + } + + #[inline] + pub(crate) fn get_max<'a, T>( + mut attributes: impl Iterator, + ) -> MedRecordResult<(&'a T, MedRecordAttribute)> { + let max_attribute = attributes.next().ok_or(MedRecordError::QueryError( + "No attributes to compare".to_string(), + ))?; + + attributes.try_fold(max_attribute, |max_attribute, attribute| { + match attribute.1.partial_cmp(&max_attribute.1) { + Some(Ordering::Greater) => Ok(attribute), + None => { + let first_dtype = DataType::from(attribute.1); + let second_dtype = DataType::from(max_attribute.1); + + Err(MedRecordError::QueryError(format!( + "Cannot compare attributes of data types {} and {}. Consider narrowing down the attributes using .is_string() or .is_int()", + first_dtype, second_dtype + ))) + } + _ => Ok(max_attribute), + } + }) + } + + #[inline] + pub(crate) fn get_min<'a, T>( + mut attributes: impl Iterator, + ) -> MedRecordResult<(&'a T, MedRecordAttribute)> { + let min_attribute = attributes.next().ok_or(MedRecordError::QueryError( + "No attributes to compare".to_string(), + ))?; + + attributes.try_fold(min_attribute, |min_attribute, attribute| { + match attribute.1.partial_cmp(&min_attribute.1) { + Some(Ordering::Less) => Ok(attribute), + None => { + let first_dtype = DataType::from(attribute.1); + let second_dtype = DataType::from(min_attribute.1); + + Err(MedRecordError::QueryError(format!( + "Cannot compare attributes of data types {} and {}. Consider narrowing down the attributes using .is_string() or .is_int()", + first_dtype, second_dtype + ))) + } + _ => Ok(min_attribute), + } + }) + } + + #[inline] + pub(crate) fn get_count<'a, T: 'a>( + attributes: impl Iterator, + ) -> MedRecordAttribute { + MedRecordAttribute::Int(attributes.count() as i64) + } + + #[inline] + // 🥊💥 + pub(crate) fn get_sum<'a, T: 'a>( + mut attributes: impl Iterator, + ) -> MedRecordResult { + let first_attribute = attributes.next().ok_or(MedRecordError::QueryError( + "No attributes to compare".to_string(), + ))?; + + attributes.try_fold(first_attribute.1, |sum, (_, attribute)| { + let first_dtype = DataType::from(&sum); + let second_dtype = DataType::from(&attribute); + + sum.add(attribute).map_err(|_| { + MedRecordError::QueryError(format!( + "Cannot add attributes of data types {} and {}. Consider narrowing down the attributes using .is_string() or .is_int()", + first_dtype, second_dtype + )) + }) + }) + } + + #[inline] + pub(crate) fn get_first<'a, T: 'a>( + mut attributes: impl Iterator, + ) -> MedRecordResult { + attributes + .next() + .ok_or(MedRecordError::QueryError( + "No attributes to get the first".to_string(), + )) + .map(|(_, attribute)| attribute) + } + + #[inline] + pub(crate) fn get_last<'a, T: 'a>( + attributes: impl Iterator, + ) -> MedRecordResult { + attributes + .last() + .ok_or(MedRecordError::QueryError( + "No attributes to get the first".to_string(), + )) + .map(|(_, attribute)| attribute) + } + + #[inline] + fn evaluate_attribute_operation<'a, T>( + medrecord: &'a MedRecord, + attribtues: impl Iterator, + operand: &Wrapper, + ) -> MedRecordResult> { + let kind = &operand.0.read_or_panic().kind; + + let attributes = attribtues.collect::>(); + + let attribute = get_single_operand_attribute!(kind, attributes.clone().into_iter()); + + Ok(match operand.evaluate(medrecord, attribute)? { + Some(_) => Box::new(attributes.into_iter()), + None => Box::new(std::iter::empty()), + }) + } + + #[inline] + fn evaluate_single_attribute_comparison_operation<'a, T>( + medrecord: &'a MedRecord, + attributes: impl Iterator + 'a, + comparison_operand: &SingleAttributeComparisonOperand, + kind: &SingleComparisonKind, + ) -> MedRecordResult> { + let comparison_attribute = + get_single_attribute_comparison_operand_attribute!(comparison_operand, medrecord); + + match kind { + SingleComparisonKind::GreaterThan => { + Ok(Box::new(attributes.filter(move |(_, attribute)| { + attribute > &comparison_attribute + }))) + } + SingleComparisonKind::GreaterThanOrEqualTo => { + Ok(Box::new(attributes.filter(move |(_, attribute)| { + attribute >= &comparison_attribute + }))) + } + SingleComparisonKind::LessThan => { + Ok(Box::new(attributes.filter(move |(_, attribute)| { + attribute < &comparison_attribute + }))) + } + SingleComparisonKind::LessThanOrEqualTo => { + Ok(Box::new(attributes.filter(move |(_, attribute)| { + attribute <= &comparison_attribute + }))) + } + SingleComparisonKind::EqualTo => { + Ok(Box::new(attributes.filter(move |(_, attribute)| { + attribute == &comparison_attribute + }))) + } + SingleComparisonKind::NotEqualTo => { + Ok(Box::new(attributes.filter(move |(_, attribute)| { + attribute != &comparison_attribute + }))) + } + SingleComparisonKind::StartsWith => { + Ok(Box::new(attributes.filter(move |(_, attribute)| { + attribute.starts_with(&comparison_attribute) + }))) + } + SingleComparisonKind::EndsWith => { + Ok(Box::new(attributes.filter(move |(_, attribute)| { + attribute.ends_with(&comparison_attribute) + }))) + } + SingleComparisonKind::Contains => { + Ok(Box::new(attributes.filter(move |(_, attribute)| { + attribute.contains(&comparison_attribute) + }))) + } + } + } + + #[inline] + fn evaluate_multiple_attributes_comparison_operation<'a, T>( + medrecord: &'a MedRecord, + attributes: impl Iterator + 'a, + comparison_operand: &MultipleAttributesComparisonOperand, + kind: &MultipleComparisonKind, + ) -> MedRecordResult> { + let comparison_attributes = match comparison_operand { + MultipleAttributesComparisonOperand::Operand(operand) => { + let context = &operand.context.context; + let kind = &operand.kind; + + let comparison_attributes = context + .get_attributes(medrecord)? + .map(|attribute| (&0, attribute)); + + let attributes: Box> = + get_multiple_operand_attributes!(kind, comparison_attributes); + + attributes + .map(|(_, attribute)| attribute) + .collect::>() + } + MultipleAttributesComparisonOperand::Attributes(attributes) => attributes.clone(), + }; + + match kind { + MultipleComparisonKind::IsIn => { + Ok(Box::new(attributes.filter(move |(_, attribute)| { + comparison_attributes.contains(attribute) + }))) + } + MultipleComparisonKind::IsNotIn => { + Ok(Box::new(attributes.filter(move |(_, attribute)| { + !comparison_attributes.contains(attribute) + }))) + } + } + } + + #[inline] + fn evaluate_binary_arithmetic_operation<'a, T: 'a>( + medrecord: &'a MedRecord, + attributes: impl Iterator, + operand: &SingleAttributeComparisonOperand, + kind: &BinaryArithmeticKind, + ) -> MedRecordResult> { + let arithmetic_attribute = + get_single_attribute_comparison_operand_attribute!(operand, medrecord); + + let attributes = attributes + .map(move |(t, attribute)| { + match kind { + BinaryArithmeticKind::Add => attribute.add(arithmetic_attribute.clone()), + BinaryArithmeticKind::Sub => attribute.sub(arithmetic_attribute.clone()), + BinaryArithmeticKind::Mul => { + attribute.clone().mul(arithmetic_attribute.clone()) + } + BinaryArithmeticKind::Pow => { + attribute.clone().pow(arithmetic_attribute.clone()) + } + BinaryArithmeticKind::Mod => { + attribute.clone().r#mod(arithmetic_attribute.clone()) + } + } + .map_err(|_| { + MedRecordError::QueryError(format!( + "Failed arithmetic operation {}. Consider narrowing down the attributes using .is_int() or .is_float()", + kind, + )) + }).map(|result| (t, result)) + }); + + // TODO: This is a temporary solution. It should be optimized. + Ok(attributes.collect::>>()?.into_iter()) + } + + #[inline] + fn evaluate_unary_arithmetic_operation<'a, T: 'a>( + attributes: impl Iterator, + kind: UnaryArithmeticKind, + ) -> impl Iterator { + attributes.map(move |(t, attribute)| { + let attribute = match kind { + UnaryArithmeticKind::Abs => attribute.abs(), + UnaryArithmeticKind::Trim => attribute.trim(), + UnaryArithmeticKind::TrimStart => attribute.trim_start(), + UnaryArithmeticKind::TrimEnd => attribute.trim_end(), + UnaryArithmeticKind::Lowercase => attribute.lowercase(), + UnaryArithmeticKind::Uppercase => attribute.uppercase(), + }; + (t, attribute) + }) + } + + pub(crate) fn get_values<'a, T: 'a + Eq + Hash + GetAttributes + Display>( + medrecord: &'a MedRecord, + attributes: impl Iterator, + ) -> MedRecordResult> { + Ok(attributes + .map(|(index, attribute)| { + let value = index.get_attributes(medrecord)?.get(&attribute).ok_or( + MedRecordError::QueryError(format!( + "Cannot find attribute {} for index {}", + attribute, index + )), + )?; + + Ok((index, value.clone())) + }) + .collect::>>()? + .into_iter()) + } + + #[inline] + fn evaluate_to_values<'a, T: 'a + Eq + Hash + GetAttributes + Display>( + medrecord: &'a MedRecord, + attributes: impl Iterator, + operand: &Wrapper, + ) -> MedRecordResult> { + let attributes = attributes.collect::>(); + + let values = Self::get_values(medrecord, attributes.clone().into_iter())?; + + let mut attributes = attributes.into_iter().collect::>(); + + let values = operand.evaluate(medrecord, values.into_iter())?; + + Ok(values.map(move |(index, _)| { + ( + index, + attributes.remove(&index).expect("Attribute must exist"), + ) + })) + } + + #[inline] + fn evaluate_slice<'a, T: 'a>( + attributes: impl Iterator, + range: Range, + ) -> impl Iterator { + attributes.map(move |(t, attribute)| (t, attribute.slice(range.clone()))) + } + + #[inline] + fn evaluate_either_or<'a, T: 'a + Eq + Hash + GetAttributes + Display>( + medrecord: &'a MedRecord, + attributes: impl Iterator, + either: &Wrapper, + or: &Wrapper, + ) -> MedRecordResult> { + let attributes = attributes.collect::>(); + + let either_attributes = either.evaluate(medrecord, attributes.clone().into_iter())?; + let or_attributes = or.evaluate(medrecord, attributes.into_iter())?; + + Ok(Box::new( + either_attributes + .chain(or_attributes) + .unique_by(|attribute| attribute.0), + )) + } +} + +#[derive(Debug, Clone)] +pub enum SingleAttributeOperation { + SingleAttributeComparisonOperation { + operand: SingleAttributeComparisonOperand, + kind: SingleComparisonKind, + }, + MultipleAttributesComparisonOperation { + operand: MultipleAttributesComparisonOperand, + kind: MultipleComparisonKind, + }, + BinaryArithmeticOpration { + operand: SingleAttributeComparisonOperand, + kind: BinaryArithmeticKind, + }, + UnaryArithmeticOperation { + kind: UnaryArithmeticKind, + }, + + Slice(Range), + + IsString, + IsInt, + + EitherOr { + either: Wrapper, + or: Wrapper, + }, +} + +impl DeepClone for SingleAttributeOperation { + fn deep_clone(&self) -> Self { + match self { + Self::SingleAttributeComparisonOperation { operand, kind } => { + Self::SingleAttributeComparisonOperation { + operand: operand.deep_clone(), + kind: kind.clone(), + } + } + Self::MultipleAttributesComparisonOperation { operand, kind } => { + Self::MultipleAttributesComparisonOperation { + operand: operand.deep_clone(), + kind: kind.clone(), + } + } + Self::BinaryArithmeticOpration { operand, kind } => Self::BinaryArithmeticOpration { + operand: operand.deep_clone(), + kind: kind.clone(), + }, + Self::UnaryArithmeticOperation { kind } => { + Self::UnaryArithmeticOperation { kind: kind.clone() } + } + Self::Slice(range) => Self::Slice(range.clone()), + Self::IsString => Self::IsString, + Self::IsInt => Self::IsInt, + Self::EitherOr { either, or } => Self::EitherOr { + either: either.deep_clone(), + or: or.deep_clone(), + }, + } + } +} + +impl SingleAttributeOperation { + pub(crate) fn evaluate( + &self, + medrecord: &MedRecord, + attribute: MedRecordAttribute, + ) -> MedRecordResult> { + match self { + Self::SingleAttributeComparisonOperation { operand, kind } => { + Self::evaluate_single_attribute_comparison_operation( + medrecord, attribute, operand, kind, + ) + } + Self::MultipleAttributesComparisonOperation { operand, kind } => { + Self::evaluate_multiple_attribute_comparison_operation( + medrecord, attribute, operand, kind, + ) + } + Self::BinaryArithmeticOpration { operand, kind } => { + Self::evaluate_binary_arithmetic_operation(medrecord, attribute, operand, kind) + } + Self::UnaryArithmeticOperation { kind } => Ok(Some(match kind { + UnaryArithmeticKind::Abs => attribute.abs(), + UnaryArithmeticKind::Trim => attribute.trim(), + UnaryArithmeticKind::TrimStart => attribute.trim_start(), + UnaryArithmeticKind::TrimEnd => attribute.trim_end(), + UnaryArithmeticKind::Lowercase => attribute.lowercase(), + UnaryArithmeticKind::Uppercase => attribute.uppercase(), + })), + Self::Slice(range) => Ok(Some(attribute.slice(range.clone()))), + Self::IsString => Ok(match attribute { + MedRecordAttribute::String(_) => Some(attribute), + _ => None, + }), + Self::IsInt => Ok(match attribute { + MedRecordAttribute::Int(_) => Some(attribute), + _ => None, + }), + Self::EitherOr { either, or } => { + Self::evaluate_either_or(medrecord, attribute, either, or) + } + } + } + + #[inline] + fn evaluate_single_attribute_comparison_operation( + medrecord: &MedRecord, + attribute: MedRecordAttribute, + comparison_operand: &SingleAttributeComparisonOperand, + kind: &SingleComparisonKind, + ) -> MedRecordResult> { + let comparison_attribute = + get_single_attribute_comparison_operand_attribute!(comparison_operand, medrecord); + + let comparison_result = match kind { + SingleComparisonKind::GreaterThan => attribute > comparison_attribute, + SingleComparisonKind::GreaterThanOrEqualTo => attribute >= comparison_attribute, + SingleComparisonKind::LessThan => attribute < comparison_attribute, + SingleComparisonKind::LessThanOrEqualTo => attribute <= comparison_attribute, + SingleComparisonKind::EqualTo => attribute == comparison_attribute, + SingleComparisonKind::NotEqualTo => attribute != comparison_attribute, + SingleComparisonKind::StartsWith => attribute.starts_with(&comparison_attribute), + SingleComparisonKind::EndsWith => attribute.ends_with(&comparison_attribute), + SingleComparisonKind::Contains => attribute.contains(&comparison_attribute), + }; + + Ok(if comparison_result { + Some(attribute) + } else { + None + }) + } + + #[inline] + fn evaluate_multiple_attribute_comparison_operation( + medrecord: &MedRecord, + attribute: MedRecordAttribute, + comparison_operand: &MultipleAttributesComparisonOperand, + kind: &MultipleComparisonKind, + ) -> MedRecordResult> { + let comparison_attributes = match comparison_operand { + MultipleAttributesComparisonOperand::Operand(operand) => { + let context = &operand.context.context; + let kind = &operand.kind; + + let comparison_attributes = context + .get_attributes(medrecord)? + .map(|attribute| (&0, attribute)); + + let attributes: Box> = + get_multiple_operand_attributes!(kind, comparison_attributes); + + attributes + .map(|(_, attribute)| attribute) + .collect::>() + } + MultipleAttributesComparisonOperand::Attributes(attributes) => attributes.clone(), + }; + + let comparison_result = match kind { + MultipleComparisonKind::IsIn => comparison_attributes.contains(&attribute), + MultipleComparisonKind::IsNotIn => !comparison_attributes.contains(&attribute), + }; + + Ok(if comparison_result { + Some(attribute) + } else { + None + }) + } + + #[inline] + fn evaluate_binary_arithmetic_operation( + medrecord: &MedRecord, + attribute: MedRecordAttribute, + operand: &SingleAttributeComparisonOperand, + kind: &BinaryArithmeticKind, + ) -> MedRecordResult> { + let arithmetic_attribute = + get_single_attribute_comparison_operand_attribute!(operand, medrecord); + + match kind { + BinaryArithmeticKind::Add => attribute.add(arithmetic_attribute), + BinaryArithmeticKind::Sub => attribute.sub(arithmetic_attribute), + BinaryArithmeticKind::Mul => attribute.mul(arithmetic_attribute), + BinaryArithmeticKind::Pow => attribute.pow(arithmetic_attribute), + BinaryArithmeticKind::Mod => attribute.r#mod(arithmetic_attribute), + } + .map(Some) + } + + #[inline] + fn evaluate_either_or( + medrecord: &MedRecord, + attribute: MedRecordAttribute, + either: &Wrapper, + or: &Wrapper, + ) -> MedRecordResult> { + let either_result = either.evaluate(medrecord, attribute.clone())?; + let or_result = or.evaluate(medrecord, attribute)?; + + match (either_result, or_result) { + (Some(either_result), _) => Ok(Some(either_result)), + (None, Some(or_result)) => Ok(Some(or_result)), + _ => Ok(None), + } + } +} diff --git a/crates/medmodels-core/src/medrecord/querying/edges/mod.rs b/crates/medmodels-core/src/medrecord/querying/edges/mod.rs index 19909a12..1045e83e 100644 --- a/crates/medmodels-core/src/medrecord/querying/edges/mod.rs +++ b/crates/medmodels-core/src/medrecord/querying/edges/mod.rs @@ -1,8 +1,58 @@ mod operand; mod operation; mod selection; -mod values; pub use operand::EdgeOperand; pub use operation::EdgeOperation; pub use selection::EdgeSelection; +use std::fmt::Display; + +#[derive(Debug, Clone)] +pub enum SingleKind { + Max, + Min, + Count, + Sum, + First, + Last, +} + +#[derive(Debug, Clone)] +pub enum SingleComparisonKind { + GreaterThan, + GreaterThanOrEqualTo, + LessThan, + LessThanOrEqualTo, + EqualTo, + NotEqualTo, + StartsWith, + EndsWith, + Contains, +} + +#[derive(Debug, Clone)] +pub enum MultipleComparisonKind { + IsIn, + IsNotIn, +} + +#[derive(Debug, Clone)] +pub enum BinaryArithmeticKind { + Add, + Sub, + Mul, + Pow, + Mod, +} + +impl Display for BinaryArithmeticKind { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match self { + BinaryArithmeticKind::Add => write!(f, "add"), + BinaryArithmeticKind::Sub => write!(f, "sub"), + BinaryArithmeticKind::Mul => write!(f, "mul"), + BinaryArithmeticKind::Pow => write!(f, "pow"), + BinaryArithmeticKind::Mod => write!(f, "mod"), + } + } +} diff --git a/crates/medmodels-core/src/medrecord/querying/edges/operand.rs b/crates/medmodels-core/src/medrecord/querying/edges/operand.rs index 42608881..4b7b4f85 100644 --- a/crates/medmodels-core/src/medrecord/querying/edges/operand.rs +++ b/crates/medmodels-core/src/medrecord/querying/edges/operand.rs @@ -1,12 +1,17 @@ -use super::operation::EdgeOperation; +use super::{ + operation::{EdgeIndexOperation, EdgeIndicesOperation, EdgeOperation}, + BinaryArithmeticKind, MultipleComparisonKind, SingleComparisonKind, SingleKind, +}; use crate::{ errors::MedRecordResult, medrecord::{ querying::{ + attributes::{self, AttributesTreeOperand}, nodes::NodeOperand, traits::{DeepClone, ReadWriteOrPanic}, - values::{Context, MedRecordValuesOperand}, + values::{self, MultipleValuesOperand}, wrapper::Wrapper, + BoxedIterator, }, CardinalityWrapper, EdgeIndex, Group, MedRecordAttribute, }, @@ -42,8 +47,7 @@ impl EdgeOperand { &self, medrecord: &'a MedRecord, ) -> MedRecordResult> { - let edge_indices = - Box::new(medrecord.edge_indices()) as Box>; + let edge_indices = Box::new(medrecord.edge_indices()) as BoxedIterator<&'a EdgeIndex>; self.operations .iter() @@ -52,13 +56,35 @@ impl EdgeOperand { }) } - pub fn attribute(&mut self, attribute: MedRecordAttribute) -> Wrapper { - let operand = Wrapper::::new( - Context::EdgeOperand(self.deep_clone()), + pub fn attribute(&mut self, attribute: MedRecordAttribute) -> Wrapper { + let operand = Wrapper::::new( + values::Context::EdgeOperand(self.deep_clone()), attribute, ); - self.operations.push(EdgeOperation::Attribute { + self.operations.push(EdgeOperation::Values { + operand: operand.clone(), + }); + + operand + } + + pub fn attributes(&mut self) -> Wrapper { + let operand = Wrapper::::new(attributes::Context::EdgeOperand( + self.deep_clone(), + )); + + self.operations.push(EdgeOperation::Attributes { + operand: operand.clone(), + }); + + operand + } + + pub fn index(&mut self) -> Wrapper { + let operand = Wrapper::::new(self.deep_clone()); + + self.operations.push(EdgeOperation::Indices { operand: operand.clone(), }); @@ -102,6 +128,23 @@ impl EdgeOperand { operand } + + pub fn either_or(&mut self, either_query: EQ, or_query: OQ) + where + EQ: FnOnce(&mut Wrapper), + OQ: FnOnce(&mut Wrapper), + { + let mut either_operand = Wrapper::::new(); + let mut or_operand = Wrapper::::new(); + + either_query(&mut either_operand); + or_query(&mut or_operand); + + self.operations.push(EdgeOperation::EitherOr { + either: either_operand, + or: or_operand, + }); + } } impl Wrapper { @@ -110,19 +153,27 @@ impl Wrapper { } pub(crate) fn evaluate<'a>( - &'a self, + &self, medrecord: &'a MedRecord, ) -> MedRecordResult> { self.0.read_or_panic().evaluate(medrecord) } - pub fn attribute(&self, attribute: A) -> Wrapper + pub fn attribute(&self, attribute: A) -> Wrapper where A: Into, { self.0.write_or_panic().attribute(attribute.into()) } + pub fn attributes(&self) -> Wrapper { + self.0.write_or_panic().attributes() + } + + pub fn index(&self) -> Wrapper { + self.0.write_or_panic().index() + } + pub fn in_group(&mut self, group: G) where G: Into>, @@ -144,4 +195,461 @@ impl Wrapper { pub fn target_node(&self) -> Wrapper { self.0.write_or_panic().target_node() } + + pub fn either_or(&self, either_query: EQ, or_query: OQ) + where + EQ: FnOnce(&mut Wrapper), + OQ: FnOnce(&mut Wrapper), + { + self.0.write_or_panic().either_or(either_query, or_query); + } +} + +macro_rules! implement_value_operation { + ($name:ident, $variant:ident) => { + pub fn $name(&mut self) -> Wrapper { + let operand = Wrapper::::new(self.deep_clone(), SingleKind::$variant); + + self.operations + .push(EdgeIndicesOperation::EdgeIndexOperation { + operand: operand.clone(), + }); + + operand + } + }; +} + +macro_rules! implement_single_value_comparison_operation { + ($name:ident, $operation:ident, $kind:ident) => { + pub fn $name>(&mut self, value: V) { + self.operations + .push($operation::EdgeIndexComparisonOperation { + operand: value.into(), + kind: SingleComparisonKind::$kind, + }); + } + }; +} + +macro_rules! implement_binary_arithmetic_operation { + ($name:ident, $operation:ident, $kind:ident) => { + pub fn $name>(&mut self, value: V) { + self.operations.push($operation::BinaryArithmeticOpration { + operand: value.into(), + kind: BinaryArithmeticKind::$kind, + }); + } + }; +} + +macro_rules! implement_assertion_operation { + ($name:ident, $operation:expr) => { + pub fn $name(&mut self) { + self.operations.push($operation); + } + }; +} + +macro_rules! implement_wrapper_operand { + ($name:ident) => { + pub fn $name(&self) { + self.0.write_or_panic().$name() + } + }; +} + +macro_rules! implement_wrapper_operand_with_return { + ($name:ident, $return_operand:ident) => { + pub fn $name(&self) -> Wrapper<$return_operand> { + self.0.write_or_panic().$name() + } + }; +} + +macro_rules! implement_wrapper_operand_with_argument { + ($name:ident, $value_type:ty) => { + pub fn $name(&self, value: $value_type) { + self.0.write_or_panic().$name(value) + } + }; +} + +#[derive(Debug, Clone)] +pub enum EdgeIndexComparisonOperand { + Operand(EdgeIndexOperand), + Index(EdgeIndex), +} + +impl DeepClone for EdgeIndexComparisonOperand { + fn deep_clone(&self) -> Self { + match self { + Self::Operand(operand) => Self::Operand(operand.deep_clone()), + Self::Index(value) => Self::Index(*value), + } + } +} + +impl From> for EdgeIndexComparisonOperand { + fn from(value: Wrapper) -> Self { + Self::Operand(value.0.read_or_panic().deep_clone()) + } +} + +impl From<&Wrapper> for EdgeIndexComparisonOperand { + fn from(value: &Wrapper) -> Self { + Self::Operand(value.0.read_or_panic().deep_clone()) + } +} + +impl> From for EdgeIndexComparisonOperand { + fn from(value: V) -> Self { + Self::Index(value.into()) + } +} + +#[derive(Debug, Clone)] +pub enum EdgeIndicesComparisonOperand { + Operand(EdgeIndicesOperand), + Indices(Vec), +} + +impl DeepClone for EdgeIndicesComparisonOperand { + fn deep_clone(&self) -> Self { + match self { + Self::Operand(operand) => Self::Operand(operand.deep_clone()), + Self::Indices(value) => Self::Indices(value.clone()), + } + } +} + +impl From> for EdgeIndicesComparisonOperand { + fn from(value: Wrapper) -> Self { + Self::Operand(value.0.read_or_panic().deep_clone()) + } +} + +impl From<&Wrapper> for EdgeIndicesComparisonOperand { + fn from(value: &Wrapper) -> Self { + Self::Operand(value.0.read_or_panic().deep_clone()) + } +} + +impl> From> for EdgeIndicesComparisonOperand { + fn from(value: Vec) -> Self { + Self::Indices(value.into_iter().map(Into::into).collect()) + } +} + +impl + Clone, const N: usize> From<[V; N]> for EdgeIndicesComparisonOperand { + fn from(value: [V; N]) -> Self { + value.to_vec().into() + } +} + +#[derive(Debug, Clone)] +pub struct EdgeIndicesOperand { + pub(crate) context: EdgeOperand, + operations: Vec, +} + +impl DeepClone for EdgeIndicesOperand { + fn deep_clone(&self) -> Self { + Self { + context: self.context.clone(), + operations: self.operations.iter().map(DeepClone::deep_clone).collect(), + } + } +} + +impl EdgeIndicesOperand { + pub(crate) fn new(context: EdgeOperand) -> Self { + Self { + context, + operations: Vec::new(), + } + } + + pub(crate) fn evaluate<'a>( + &self, + medrecord: &'a MedRecord, + values: impl Iterator + 'a, + ) -> MedRecordResult + 'a> { + let values = Box::new(values) as BoxedIterator; + + self.operations + .iter() + .try_fold(values, |value_tuples, operation| { + operation.evaluate(medrecord, value_tuples) + }) + } + + implement_value_operation!(max, Max); + implement_value_operation!(min, Min); + implement_value_operation!(count, Count); + implement_value_operation!(sum, Sum); + implement_value_operation!(first, First); + implement_value_operation!(last, Last); + + implement_single_value_comparison_operation!(greater_than, EdgeIndicesOperation, GreaterThan); + implement_single_value_comparison_operation!( + greater_than_or_equal_to, + EdgeIndicesOperation, + GreaterThanOrEqualTo + ); + implement_single_value_comparison_operation!(less_than, EdgeIndicesOperation, LessThan); + implement_single_value_comparison_operation!( + less_than_or_equal_to, + EdgeIndicesOperation, + LessThanOrEqualTo + ); + implement_single_value_comparison_operation!(equal_to, EdgeIndicesOperation, EqualTo); + implement_single_value_comparison_operation!(not_equal_to, EdgeIndicesOperation, NotEqualTo); + implement_single_value_comparison_operation!(starts_with, EdgeIndicesOperation, StartsWith); + implement_single_value_comparison_operation!(ends_with, EdgeIndicesOperation, EndsWith); + implement_single_value_comparison_operation!(contains, EdgeIndicesOperation, Contains); + + pub fn is_in>(&mut self, values: V) { + self.operations + .push(EdgeIndicesOperation::EdgeIndicesComparisonOperation { + operand: values.into(), + kind: MultipleComparisonKind::IsIn, + }); + } + + pub fn is_not_in>(&mut self, values: V) { + self.operations + .push(EdgeIndicesOperation::EdgeIndicesComparisonOperation { + operand: values.into(), + kind: MultipleComparisonKind::IsNotIn, + }); + } + + implement_binary_arithmetic_operation!(add, EdgeIndicesOperation, Add); + implement_binary_arithmetic_operation!(sub, EdgeIndicesOperation, Sub); + implement_binary_arithmetic_operation!(mul, EdgeIndicesOperation, Mul); + implement_binary_arithmetic_operation!(pow, EdgeIndicesOperation, Pow); + implement_binary_arithmetic_operation!(r#mod, EdgeIndicesOperation, Mod); + + implement_assertion_operation!(is_max, EdgeIndicesOperation::IsMax); + implement_assertion_operation!(is_min, EdgeIndicesOperation::IsMin); + + pub fn either_or(&mut self, either_query: EQ, or_query: OQ) + where + EQ: FnOnce(&mut Wrapper), + OQ: FnOnce(&mut Wrapper), + { + let mut either_operand = Wrapper::::new(self.context.clone()); + let mut or_operand = Wrapper::::new(self.context.clone()); + + either_query(&mut either_operand); + or_query(&mut or_operand); + + self.operations.push(EdgeIndicesOperation::EitherOr { + either: either_operand, + or: or_operand, + }); + } +} + +impl Wrapper { + pub(crate) fn new(context: EdgeOperand) -> Self { + EdgeIndicesOperand::new(context).into() + } + + pub(crate) fn evaluate<'a>( + &self, + medrecord: &'a MedRecord, + values: impl Iterator + 'a, + ) -> MedRecordResult + 'a> { + self.0.read_or_panic().evaluate(medrecord, values) + } + + implement_wrapper_operand_with_return!(max, EdgeIndexOperand); + implement_wrapper_operand_with_return!(min, EdgeIndexOperand); + implement_wrapper_operand_with_return!(count, EdgeIndexOperand); + implement_wrapper_operand_with_return!(sum, EdgeIndexOperand); + implement_wrapper_operand_with_return!(first, EdgeIndexOperand); + implement_wrapper_operand_with_return!(last, EdgeIndexOperand); + + implement_wrapper_operand_with_argument!(greater_than, impl Into); + implement_wrapper_operand_with_argument!( + greater_than_or_equal_to, + impl Into + ); + implement_wrapper_operand_with_argument!(less_than, impl Into); + implement_wrapper_operand_with_argument!( + less_than_or_equal_to, + impl Into + ); + implement_wrapper_operand_with_argument!(equal_to, impl Into); + implement_wrapper_operand_with_argument!(not_equal_to, impl Into); + implement_wrapper_operand_with_argument!(starts_with, impl Into); + implement_wrapper_operand_with_argument!(ends_with, impl Into); + implement_wrapper_operand_with_argument!(contains, impl Into); + implement_wrapper_operand_with_argument!(is_in, impl Into); + implement_wrapper_operand_with_argument!(is_not_in, impl Into); + implement_wrapper_operand_with_argument!(add, impl Into); + implement_wrapper_operand_with_argument!(sub, impl Into); + implement_wrapper_operand_with_argument!(mul, impl Into); + implement_wrapper_operand_with_argument!(pow, impl Into); + implement_wrapper_operand_with_argument!(r#mod, impl Into); + + implement_wrapper_operand!(is_max); + implement_wrapper_operand!(is_min); + + pub fn either_or(&self, either_query: EQ, or_query: OQ) + where + EQ: FnOnce(&mut Wrapper), + OQ: FnOnce(&mut Wrapper), + { + self.0.write_or_panic().either_or(either_query, or_query); + } +} + +#[derive(Debug, Clone)] +pub struct EdgeIndexOperand { + pub(crate) context: EdgeIndicesOperand, + pub(crate) kind: SingleKind, + operations: Vec, +} + +impl DeepClone for EdgeIndexOperand { + fn deep_clone(&self) -> Self { + Self { + context: self.context.deep_clone(), + kind: self.kind.clone(), + operations: self.operations.iter().map(DeepClone::deep_clone).collect(), + } + } +} + +impl EdgeIndexOperand { + pub(crate) fn new(context: EdgeIndicesOperand, kind: SingleKind) -> Self { + Self { + context, + kind, + operations: Vec::new(), + } + } + + pub(crate) fn evaluate( + &self, + medrecord: &MedRecord, + value: EdgeIndex, + ) -> MedRecordResult> { + self.operations + .iter() + .try_fold(Some(value), |value, operation| { + if let Some(value) = value { + operation.evaluate(medrecord, value) + } else { + Ok(None) + } + }) + } + + implement_single_value_comparison_operation!(greater_than, EdgeIndexOperation, GreaterThan); + implement_single_value_comparison_operation!( + greater_than_or_equal_to, + EdgeIndexOperation, + GreaterThanOrEqualTo + ); + implement_single_value_comparison_operation!(less_than, EdgeIndexOperation, LessThan); + implement_single_value_comparison_operation!( + less_than_or_equal_to, + EdgeIndexOperation, + LessThanOrEqualTo + ); + implement_single_value_comparison_operation!(equal_to, EdgeIndexOperation, EqualTo); + implement_single_value_comparison_operation!(not_equal_to, EdgeIndexOperation, NotEqualTo); + implement_single_value_comparison_operation!(starts_with, EdgeIndexOperation, StartsWith); + implement_single_value_comparison_operation!(ends_with, EdgeIndexOperation, EndsWith); + implement_single_value_comparison_operation!(contains, EdgeIndexOperation, Contains); + + pub fn is_in>(&mut self, values: V) { + self.operations + .push(EdgeIndexOperation::EdgeIndicesComparisonOperation { + operand: values.into(), + kind: MultipleComparisonKind::IsIn, + }); + } + + pub fn is_not_in>(&mut self, values: V) { + self.operations + .push(EdgeIndexOperation::EdgeIndicesComparisonOperation { + operand: values.into(), + kind: MultipleComparisonKind::IsNotIn, + }); + } + + implement_binary_arithmetic_operation!(add, EdgeIndexOperation, Add); + implement_binary_arithmetic_operation!(sub, EdgeIndexOperation, Sub); + implement_binary_arithmetic_operation!(mul, EdgeIndexOperation, Mul); + implement_binary_arithmetic_operation!(pow, EdgeIndexOperation, Pow); + implement_binary_arithmetic_operation!(r#mod, EdgeIndexOperation, Mod); + + pub fn eiter_or(&mut self, either_query: EQ, or_query: OQ) + where + EQ: FnOnce(&mut Wrapper), + OQ: FnOnce(&mut Wrapper), + { + let mut either_operand = + Wrapper::::new(self.context.clone(), self.kind.clone()); + let mut or_operand = + Wrapper::::new(self.context.clone(), self.kind.clone()); + + either_query(&mut either_operand); + or_query(&mut or_operand); + + self.operations.push(EdgeIndexOperation::EitherOr { + either: either_operand, + or: or_operand, + }); + } +} + +impl Wrapper { + pub(crate) fn new(context: EdgeIndicesOperand, kind: SingleKind) -> Self { + EdgeIndexOperand::new(context, kind).into() + } + + pub(crate) fn evaluate( + &self, + medrecord: &MedRecord, + value: EdgeIndex, + ) -> MedRecordResult> { + self.0.read_or_panic().evaluate(medrecord, value) + } + + implement_wrapper_operand_with_argument!(greater_than, impl Into); + implement_wrapper_operand_with_argument!( + greater_than_or_equal_to, + impl Into + ); + implement_wrapper_operand_with_argument!(less_than, impl Into); + implement_wrapper_operand_with_argument!( + less_than_or_equal_to, + impl Into + ); + implement_wrapper_operand_with_argument!(equal_to, impl Into); + implement_wrapper_operand_with_argument!(not_equal_to, impl Into); + implement_wrapper_operand_with_argument!(starts_with, impl Into); + implement_wrapper_operand_with_argument!(ends_with, impl Into); + implement_wrapper_operand_with_argument!(contains, impl Into); + implement_wrapper_operand_with_argument!(is_in, impl Into); + implement_wrapper_operand_with_argument!(is_not_in, impl Into); + implement_wrapper_operand_with_argument!(add, impl Into); + implement_wrapper_operand_with_argument!(sub, impl Into); + implement_wrapper_operand_with_argument!(mul, impl Into); + implement_wrapper_operand_with_argument!(pow, impl Into); + implement_wrapper_operand_with_argument!(r#mod, impl Into); + + pub fn either_or(&self, either_query: EQ, or_query: OQ) + where + EQ: FnOnce(&mut Wrapper), + OQ: FnOnce(&mut Wrapper), + { + self.0.write_or_panic().eiter_or(either_query, or_query); + } } diff --git a/crates/medmodels-core/src/medrecord/querying/edges/operation.rs b/crates/medmodels-core/src/medrecord/querying/edges/operation.rs index b8f6940e..0d36db8f 100644 --- a/crates/medmodels-core/src/medrecord/querying/edges/operation.rs +++ b/crates/medmodels-core/src/medrecord/querying/edges/operation.rs @@ -1,22 +1,43 @@ +use super::{ + operand::{ + EdgeIndexComparisonOperand, EdgeIndexOperand, EdgeIndicesComparisonOperand, + EdgeIndicesOperand, + }, + BinaryArithmeticKind, EdgeOperand, MultipleComparisonKind, SingleComparisonKind, +}; use crate::{ - errors::MedRecordResult, + errors::{MedRecordError, MedRecordResult}, medrecord::{ + datatypes::{Contains, EndsWith, Mod, StartsWith}, querying::{ + attributes::AttributesTreeOperand, + edges::SingleKind, nodes::NodeOperand, traits::{DeepClone, ReadWriteOrPanic}, - values::MedRecordValuesOperand, + values::MultipleValuesOperand, wrapper::Wrapper, + BoxedIterator, }, CardinalityWrapper, EdgeIndex, Group, MedRecordAttribute, MedRecordValue, }, MedRecord, }; -use std::collections::HashSet; +use itertools::Itertools; +use std::{ + collections::HashSet, + ops::{Add, Mul, Sub}, +}; #[derive(Debug, Clone)] pub enum EdgeOperation { - Attribute { - operand: Wrapper, + Values { + operand: Wrapper, + }, + Attributes { + operand: Wrapper, + }, + Indices { + operand: Wrapper, }, InGroup { @@ -32,12 +53,23 @@ pub enum EdgeOperation { TargetNode { operand: Wrapper, }, + + EitherOr { + either: Wrapper, + or: Wrapper, + }, } impl DeepClone for EdgeOperation { fn deep_clone(&self) -> Self { match self { - Self::Attribute { operand } => Self::Attribute { + Self::Values { operand } => Self::Values { + operand: operand.deep_clone(), + }, + Self::Attributes { operand } => Self::Attributes { + operand: operand.deep_clone(), + }, + Self::Indices { operand } => Self::Indices { operand: operand.deep_clone(), }, Self::InGroup { group } => Self::InGroup { @@ -52,6 +84,10 @@ impl DeepClone for EdgeOperation { Self::TargetNode { operand } => Self::TargetNode { operand: operand.deep_clone(), }, + Self::EitherOr { either, or } => Self::EitherOr { + either: either.deep_clone(), + or: or.deep_clone(), + }, } } } @@ -61,9 +97,19 @@ impl EdgeOperation { &self, medrecord: &'a MedRecord, edge_indices: impl Iterator + 'a, - ) -> MedRecordResult + 'a>> { + ) -> MedRecordResult> { Ok(match self { - Self::Attribute { operand } => Box::new(Self::evaluate_attribute( + Self::Values { operand } => Box::new(Self::evaluate_values( + medrecord, + edge_indices, + operand.clone(), + )?), + Self::Attributes { operand } => Box::new(Self::evaluate_attributes( + medrecord, + edge_indices, + operand.clone(), + )?), + Self::Indices { operand } => Box::new(Self::evaluate_indices( medrecord, edge_indices, operand.clone(), @@ -88,31 +134,35 @@ impl EdgeOperation { edge_indices, operand, )?), + Self::EitherOr { either, or } => { + Box::new(Self::evaluate_either_or(medrecord, either, or)?) + } }) } #[inline] pub(crate) fn get_values<'a>( medrecord: &'a MedRecord, - edge_indices: impl Iterator + 'a, + edge_indices: impl Iterator, attribute: MedRecordAttribute, - ) -> impl Iterator + 'a { + ) -> impl Iterator { edge_indices.flat_map(move |edge_index| { Some(( edge_index, medrecord .edge_attributes(edge_index) .expect("Edge must exist") - .get(&attribute)?, + .get(&attribute)? + .clone(), )) }) } #[inline] - fn evaluate_attribute<'a>( + fn evaluate_values<'a>( medrecord: &'a MedRecord, edge_indices: impl Iterator + 'a, - operand: Wrapper, + operand: Wrapper, ) -> MedRecordResult> { let values = Self::get_values( medrecord, @@ -124,11 +174,58 @@ impl EdgeOperation { } #[inline] - fn evaluate_in_group<'a>( + pub(crate) fn get_attributes<'a>( + medrecord: &'a MedRecord, + edge_indices: impl Iterator, + ) -> impl Iterator)> { + edge_indices.map(move |edge_index| { + let attributes = medrecord + .edge_attributes(edge_index) + .expect("Edge must exist") + .keys() + .cloned(); + + (edge_index, attributes.collect()) + }) + } + + #[inline] + fn evaluate_attributes<'a>( medrecord: &'a MedRecord, edge_indices: impl Iterator + 'a, + operand: Wrapper, + ) -> MedRecordResult> { + let attributes = Self::get_attributes(medrecord, edge_indices); + + Ok(operand + .evaluate(medrecord, attributes)? + .map(|value| value.0)) + } + + #[inline] + fn evaluate_indices<'a>( + medrecord: &MedRecord, + edge_indices: impl Iterator, + operand: Wrapper, + ) -> MedRecordResult> { + // TODO: This is a temporary solution. It should be optimized. + let edge_indices = edge_indices.collect::>(); + + let result = operand + .evaluate(medrecord, edge_indices.clone().into_iter().cloned())? + .collect::>(); + + Ok(edge_indices + .into_iter() + .filter(move |index| result.contains(index))) + } + + #[inline] + fn evaluate_in_group<'a>( + medrecord: &'a MedRecord, + edge_indices: impl Iterator, group: CardinalityWrapper, - ) -> impl Iterator + 'a { + ) -> impl Iterator { edge_indices.filter(move |edge_index| { let groups_of_edge = medrecord .groups_of_edge(edge_index) @@ -148,9 +245,9 @@ impl EdgeOperation { #[inline] fn evaluate_has_attribute<'a>( medrecord: &'a MedRecord, - edge_indices: impl Iterator + 'a, + edge_indices: impl Iterator, attribute: CardinalityWrapper, - ) -> impl Iterator + 'a { + ) -> impl Iterator { edge_indices.filter(move |edge_index| { let attributes_of_edge = medrecord .edge_attributes(edge_index) @@ -201,4 +298,465 @@ impl EdgeOperation { node_indices.contains(edge_endpoints.1) })) } + + #[inline] + fn evaluate_either_or<'a>( + medrecord: &'a MedRecord, + either: &Wrapper, + or: &Wrapper, + ) -> MedRecordResult> { + let either_result = either.evaluate(medrecord)?; + let or_result = or.evaluate(medrecord)?; + + Ok(either_result.chain(or_result).unique()) + } +} + +macro_rules! get_edge_index { + ($kind:ident, $indices:expr) => { + match $kind { + SingleKind::Max => EdgeIndicesOperation::get_max($indices)?.clone(), + SingleKind::Min => EdgeIndicesOperation::get_min($indices)?.clone(), + SingleKind::Count => EdgeIndicesOperation::get_count($indices), + SingleKind::Sum => EdgeIndicesOperation::get_sum($indices), + SingleKind::First => EdgeIndicesOperation::get_first($indices)?, + SingleKind::Last => EdgeIndicesOperation::get_last($indices)?, + } + }; +} + +macro_rules! get_edge_index_comparison_operand_index { + ($operand:ident, $medrecord:ident) => { + match $operand { + EdgeIndexComparisonOperand::Operand(operand) => { + let context = &operand.context.context; + let kind = &operand.kind; + + // TODO: This is a temporary solution. It should be optimized. + let comparison_indices = context.evaluate($medrecord)?.cloned(); + + let comparison_index = get_edge_index!(kind, comparison_indices); + + comparison_index + } + EdgeIndexComparisonOperand::Index(index) => index.clone(), + } + }; +} + +#[derive(Debug, Clone)] +pub enum EdgeIndicesOperation { + EdgeIndexOperation { + operand: Wrapper, + }, + EdgeIndexComparisonOperation { + operand: EdgeIndexComparisonOperand, + kind: SingleComparisonKind, + }, + EdgeIndicesComparisonOperation { + operand: EdgeIndicesComparisonOperand, + kind: MultipleComparisonKind, + }, + BinaryArithmeticOpration { + operand: EdgeIndexComparisonOperand, + kind: BinaryArithmeticKind, + }, + + IsMax, + IsMin, + + EitherOr { + either: Wrapper, + or: Wrapper, + }, +} + +impl DeepClone for EdgeIndicesOperation { + fn deep_clone(&self) -> Self { + match self { + Self::EdgeIndexOperation { operand } => Self::EdgeIndexOperation { + operand: operand.deep_clone(), + }, + Self::EdgeIndexComparisonOperation { operand, kind } => { + Self::EdgeIndexComparisonOperation { + operand: operand.deep_clone(), + kind: kind.clone(), + } + } + Self::EdgeIndicesComparisonOperation { operand, kind } => { + Self::EdgeIndicesComparisonOperation { + operand: operand.deep_clone(), + kind: kind.clone(), + } + } + Self::BinaryArithmeticOpration { operand, kind } => Self::BinaryArithmeticOpration { + operand: operand.deep_clone(), + kind: kind.clone(), + }, + Self::IsMax => Self::IsMax, + Self::IsMin => Self::IsMin, + Self::EitherOr { either, or } => Self::EitherOr { + either: either.deep_clone(), + or: or.deep_clone(), + }, + } + } +} + +impl EdgeIndicesOperation { + pub(crate) fn evaluate<'a>( + &self, + medrecord: &'a MedRecord, + indices: impl Iterator + 'a, + ) -> MedRecordResult> { + match self { + Self::EdgeIndexOperation { operand } => { + Self::evaluate_edge_index_operation(medrecord, indices, operand) + } + Self::EdgeIndexComparisonOperation { operand, kind } => { + Self::evaluate_edge_index_comparison_operation(medrecord, indices, operand, kind) + } + Self::EdgeIndicesComparisonOperation { operand, kind } => { + Self::evaluate_edge_indices_comparison_operation(medrecord, indices, operand, kind) + } + Self::BinaryArithmeticOpration { operand, kind } => { + Ok(Box::new(Self::evaluate_binary_arithmetic_operation( + medrecord, + indices, + operand, + kind.clone(), + )?)) + } + Self::IsMax => { + let max_index = Self::get_max(indices)?; + + Ok(Box::new(std::iter::once(max_index))) + } + Self::IsMin => { + let min_index = Self::get_min(indices)?; + + Ok(Box::new(std::iter::once(min_index))) + } + Self::EitherOr { either, or } => { + Self::evaluate_either_or(medrecord, indices, either, or) + } + } + } + + #[inline] + pub(crate) fn get_max(indices: impl Iterator) -> MedRecordResult { + indices.max().ok_or(MedRecordError::QueryError( + "No indices to compare".to_string(), + )) + } + + #[inline] + pub(crate) fn get_min(indices: impl Iterator) -> MedRecordResult { + indices.min().ok_or(MedRecordError::QueryError( + "No indices to compare".to_string(), + )) + } + #[inline] + pub(crate) fn get_count(indices: impl Iterator) -> EdgeIndex { + indices.count() as EdgeIndex + } + + #[inline] + pub(crate) fn get_sum(indices: impl Iterator) -> EdgeIndex { + indices.sum() + } + + #[inline] + pub(crate) fn get_first( + mut indices: impl Iterator, + ) -> MedRecordResult { + indices.next().ok_or(MedRecordError::QueryError( + "No indices to get the first".to_string(), + )) + } + + #[inline] + pub(crate) fn get_last(indices: impl Iterator) -> MedRecordResult { + indices.last().ok_or(MedRecordError::QueryError( + "No indices to get the first".to_string(), + )) + } + + #[inline] + fn evaluate_edge_index_operation<'a>( + medrecord: &MedRecord, + indices: impl Iterator, + operand: &Wrapper, + ) -> MedRecordResult> { + let kind = &operand.0.read_or_panic().kind; + + let indices = indices.collect::>(); + + let index = get_edge_index!(kind, indices.clone().into_iter()); + + Ok(match operand.evaluate(medrecord, index)? { + Some(_) => Box::new(indices.into_iter()), + None => Box::new(std::iter::empty()), + }) + } + + #[inline] + fn evaluate_edge_index_comparison_operation<'a>( + medrecord: &MedRecord, + indices: impl Iterator + 'a, + comparison_operand: &EdgeIndexComparisonOperand, + kind: &SingleComparisonKind, + ) -> MedRecordResult> { + let comparison_index = + get_edge_index_comparison_operand_index!(comparison_operand, medrecord); + + match kind { + SingleComparisonKind::GreaterThan => Ok(Box::new( + indices.filter(move |index| index > &comparison_index), + )), + SingleComparisonKind::GreaterThanOrEqualTo => Ok(Box::new( + indices.filter(move |index| index >= &comparison_index), + )), + SingleComparisonKind::LessThan => Ok(Box::new( + indices.filter(move |index| index < &comparison_index), + )), + SingleComparisonKind::LessThanOrEqualTo => Ok(Box::new( + indices.filter(move |index| index <= &comparison_index), + )), + SingleComparisonKind::EqualTo => Ok(Box::new( + indices.filter(move |index| index == &comparison_index), + )), + SingleComparisonKind::NotEqualTo => Ok(Box::new( + indices.filter(move |index| index != &comparison_index), + )), + SingleComparisonKind::StartsWith => Ok(Box::new( + indices.filter(move |index| index.starts_with(&comparison_index)), + )), + SingleComparisonKind::EndsWith => Ok(Box::new( + indices.filter(move |index| index.ends_with(&comparison_index)), + )), + SingleComparisonKind::Contains => Ok(Box::new( + indices.filter(move |index| index.contains(&comparison_index)), + )), + } + } + + #[inline] + fn evaluate_edge_indices_comparison_operation<'a>( + medrecord: &MedRecord, + indices: impl Iterator + 'a, + comparison_operand: &EdgeIndicesComparisonOperand, + kind: &MultipleComparisonKind, + ) -> MedRecordResult> { + let comparison_indices = match comparison_operand { + EdgeIndicesComparisonOperand::Operand(operand) => { + let context = &operand.context; + + context.evaluate(medrecord)?.cloned().collect::>() + } + EdgeIndicesComparisonOperand::Indices(indices) => indices.clone(), + }; + + match kind { + MultipleComparisonKind::IsIn => Ok(Box::new( + indices.filter(move |index| comparison_indices.contains(index)), + )), + MultipleComparisonKind::IsNotIn => Ok(Box::new( + indices.filter(move |index| !comparison_indices.contains(index)), + )), + } + } + + #[inline] + fn evaluate_binary_arithmetic_operation( + medrecord: &MedRecord, + indices: impl Iterator, + operand: &EdgeIndexComparisonOperand, + kind: BinaryArithmeticKind, + ) -> MedRecordResult> { + let arithmetic_index = get_edge_index_comparison_operand_index!(operand, medrecord); + + Ok(indices + .map(move |index| match kind { + BinaryArithmeticKind::Add => Ok(index.add(arithmetic_index)), + BinaryArithmeticKind::Sub => Ok(index.sub(arithmetic_index)), + BinaryArithmeticKind::Mul => Ok(index.mul(arithmetic_index)), + BinaryArithmeticKind::Pow => Ok(index.pow(arithmetic_index)), + BinaryArithmeticKind::Mod => index.r#mod(arithmetic_index), + }) + .collect::>>()? + .into_iter()) + } + + #[inline] + fn evaluate_either_or<'a>( + medrecord: &'a MedRecord, + indices: impl Iterator, + either: &Wrapper, + or: &Wrapper, + ) -> MedRecordResult> { + let indices = indices.collect::>(); + + let either_indices = either.evaluate(medrecord, indices.clone().into_iter())?; + let or_indices = or.evaluate(medrecord, indices.into_iter())?; + + Ok(Box::new(either_indices.chain(or_indices).unique())) + } +} + +#[derive(Debug, Clone)] +pub enum EdgeIndexOperation { + EdgeIndexComparisonOperation { + operand: EdgeIndexComparisonOperand, + kind: SingleComparisonKind, + }, + EdgeIndicesComparisonOperation { + operand: EdgeIndicesComparisonOperand, + kind: MultipleComparisonKind, + }, + BinaryArithmeticOpration { + operand: EdgeIndexComparisonOperand, + kind: BinaryArithmeticKind, + }, + + EitherOr { + either: Wrapper, + or: Wrapper, + }, +} + +impl DeepClone for EdgeIndexOperation { + fn deep_clone(&self) -> Self { + match self { + Self::EdgeIndexComparisonOperation { operand, kind } => { + Self::EdgeIndexComparisonOperation { + operand: operand.deep_clone(), + kind: kind.clone(), + } + } + Self::EdgeIndicesComparisonOperation { operand, kind } => { + Self::EdgeIndicesComparisonOperation { + operand: operand.deep_clone(), + kind: kind.clone(), + } + } + Self::BinaryArithmeticOpration { operand, kind } => Self::BinaryArithmeticOpration { + operand: operand.deep_clone(), + kind: kind.clone(), + }, + Self::EitherOr { either, or } => Self::EitherOr { + either: either.deep_clone(), + or: or.deep_clone(), + }, + } + } +} + +impl EdgeIndexOperation { + pub(crate) fn evaluate( + &self, + medrecord: &MedRecord, + index: EdgeIndex, + ) -> MedRecordResult> { + match self { + Self::EdgeIndexComparisonOperation { operand, kind } => { + Self::evaluate_edge_index_comparison_operation(medrecord, index, operand, kind) + } + Self::EdgeIndicesComparisonOperation { operand, kind } => { + Self::evaluate_edge_indcies_comparison_operation(medrecord, index, operand, kind) + } + Self::BinaryArithmeticOpration { operand, kind } => { + Self::evaluate_binary_arithmetic_operation(medrecord, index, operand, kind) + } + Self::EitherOr { either, or } => Self::evaluate_either_or(medrecord, index, either, or), + } + } + + #[inline] + fn evaluate_edge_index_comparison_operation( + medrecord: &MedRecord, + index: EdgeIndex, + comparison_operand: &EdgeIndexComparisonOperand, + kind: &SingleComparisonKind, + ) -> MedRecordResult> { + let comparison_index = + get_edge_index_comparison_operand_index!(comparison_operand, medrecord); + + let comparison_result = match kind { + SingleComparisonKind::GreaterThan => index > comparison_index, + SingleComparisonKind::GreaterThanOrEqualTo => index >= comparison_index, + SingleComparisonKind::LessThan => index < comparison_index, + SingleComparisonKind::LessThanOrEqualTo => index <= comparison_index, + SingleComparisonKind::EqualTo => index == comparison_index, + SingleComparisonKind::NotEqualTo => index != comparison_index, + SingleComparisonKind::StartsWith => index.starts_with(&comparison_index), + SingleComparisonKind::EndsWith => index.ends_with(&comparison_index), + SingleComparisonKind::Contains => index.contains(&comparison_index), + }; + + Ok(if comparison_result { Some(index) } else { None }) + } + + #[inline] + fn evaluate_edge_indcies_comparison_operation( + medrecord: &MedRecord, + index: EdgeIndex, + comparison_operand: &EdgeIndicesComparisonOperand, + kind: &MultipleComparisonKind, + ) -> MedRecordResult> { + let comparison_indices = match comparison_operand { + EdgeIndicesComparisonOperand::Operand(operand) => { + let context = &operand.context; + + context.evaluate(medrecord)?.cloned().collect::>() + } + EdgeIndicesComparisonOperand::Indices(indices) => indices.clone(), + }; + + let comparison_result = match kind { + MultipleComparisonKind::IsIn => comparison_indices + .into_iter() + .any(|comparison_index| index == comparison_index), + MultipleComparisonKind::IsNotIn => comparison_indices + .into_iter() + .all(|comparison_index| index != comparison_index), + }; + + Ok(if comparison_result { Some(index) } else { None }) + } + + #[inline] + fn evaluate_binary_arithmetic_operation( + medrecord: &MedRecord, + index: EdgeIndex, + operand: &EdgeIndexComparisonOperand, + kind: &BinaryArithmeticKind, + ) -> MedRecordResult> { + let arithmetic_index = get_edge_index_comparison_operand_index!(operand, medrecord); + + Ok(Some(match kind { + BinaryArithmeticKind::Add => index.add(arithmetic_index), + BinaryArithmeticKind::Sub => index.sub(arithmetic_index), + BinaryArithmeticKind::Mul => index.mul(arithmetic_index), + BinaryArithmeticKind::Pow => index.pow(arithmetic_index), + BinaryArithmeticKind::Mod => index.r#mod(arithmetic_index)?, + })) + } + + #[inline] + fn evaluate_either_or( + medrecord: &MedRecord, + index: EdgeIndex, + either: &Wrapper, + or: &Wrapper, + ) -> MedRecordResult> { + let either_result = either.evaluate(medrecord, index)?; + let or_result = or.evaluate(medrecord, index)?; + + match (either_result, or_result) { + (Some(either_result), _) => Ok(Some(either_result)), + (None, Some(or_result)) => Ok(Some(or_result)), + _ => Ok(None), + } + } } diff --git a/crates/medmodels-core/src/medrecord/querying/edges/values.rs b/crates/medmodels-core/src/medrecord/querying/edges/values.rs deleted file mode 100644 index 8b137891..00000000 --- a/crates/medmodels-core/src/medrecord/querying/edges/values.rs +++ /dev/null @@ -1 +0,0 @@ - diff --git a/crates/medmodels-core/src/medrecord/querying/mod.rs b/crates/medmodels-core/src/medrecord/querying/mod.rs index 12255dce..94728fe4 100644 --- a/crates/medmodels-core/src/medrecord/querying/mod.rs +++ b/crates/medmodels-core/src/medrecord/querying/mod.rs @@ -1,5 +1,8 @@ +pub mod attributes; pub mod edges; pub mod nodes; mod traits; pub mod values; pub mod wrapper; + +pub(crate) type BoxedIterator<'a, T> = Box + 'a>; diff --git a/crates/medmodels-core/src/medrecord/querying/nodes/mod.rs b/crates/medmodels-core/src/medrecord/querying/nodes/mod.rs index 05f88e18..1041a7e9 100644 --- a/crates/medmodels-core/src/medrecord/querying/nodes/mod.rs +++ b/crates/medmodels-core/src/medrecord/querying/nodes/mod.rs @@ -1,8 +1,68 @@ mod operand; mod operation; mod selection; -mod values; pub use operand::NodeOperand; pub use operation::NodeOperation; pub use selection::NodeSelection; +use std::fmt::Display; + +#[derive(Debug, Clone)] +pub enum SingleKind { + Max, + Min, + Count, + Sum, + First, + Last, +} + +#[derive(Debug, Clone)] +pub enum SingleComparisonKind { + GreaterThan, + GreaterThanOrEqualTo, + LessThan, + LessThanOrEqualTo, + EqualTo, + NotEqualTo, + StartsWith, + EndsWith, + Contains, +} + +#[derive(Debug, Clone)] +pub enum MultipleComparisonKind { + IsIn, + IsNotIn, +} + +#[derive(Debug, Clone)] +pub enum BinaryArithmeticKind { + Add, + Sub, + Mul, + Pow, + Mod, +} + +impl Display for BinaryArithmeticKind { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match self { + BinaryArithmeticKind::Add => write!(f, "add"), + BinaryArithmeticKind::Sub => write!(f, "sub"), + BinaryArithmeticKind::Mul => write!(f, "mul"), + BinaryArithmeticKind::Pow => write!(f, "pow"), + BinaryArithmeticKind::Mod => write!(f, "mod"), + } + } +} + +#[derive(Debug, Clone)] +pub enum UnaryArithmeticKind { + Abs, + Trim, + TrimStart, + TrimEnd, + Lowercase, + Uppercase, +} diff --git a/crates/medmodels-core/src/medrecord/querying/nodes/operand.rs b/crates/medmodels-core/src/medrecord/querying/nodes/operand.rs index 643c8f7a..1800bc00 100644 --- a/crates/medmodels-core/src/medrecord/querying/nodes/operand.rs +++ b/crates/medmodels-core/src/medrecord/querying/nodes/operand.rs @@ -1,12 +1,18 @@ -use super::operation::NodeOperation; +use super::{ + operation::{EdgeDirection, NodeIndexOperation, NodeIndicesOperation, NodeOperation}, + BinaryArithmeticKind, MultipleComparisonKind, SingleComparisonKind, SingleKind, + UnaryArithmeticKind, +}; use crate::{ errors::MedRecordResult, medrecord::{ querying::{ + attributes::{self, AttributesTreeOperand}, edges::EdgeOperand, traits::{DeepClone, ReadWriteOrPanic}, - values::{Context, MedRecordValuesOperand}, + values::{self, MultipleValuesOperand}, wrapper::{CardinalityWrapper, Wrapper}, + BoxedIterator, }, Group, MedRecordAttribute, NodeIndex, }, @@ -41,9 +47,8 @@ impl NodeOperand { pub(crate) fn evaluate<'a>( &self, medrecord: &'a MedRecord, - ) -> MedRecordResult + 'a>> { - let node_indices = - Box::new(medrecord.node_indices()) as Box>; + ) -> MedRecordResult> { + let node_indices = Box::new(medrecord.node_indices()) as BoxedIterator<'a, &'a NodeIndex>; self.operations .iter() @@ -52,13 +57,35 @@ impl NodeOperand { }) } - pub fn attribute(&mut self, attribute: MedRecordAttribute) -> Wrapper { - let operand = Wrapper::::new( - Context::NodeOperand(self.deep_clone()), + pub fn attribute(&mut self, attribute: MedRecordAttribute) -> Wrapper { + let operand = Wrapper::::new( + values::Context::NodeOperand(self.deep_clone()), attribute, ); - self.operations.push(NodeOperation::Attribute { + self.operations.push(NodeOperation::Values { + operand: operand.clone(), + }); + + operand + } + + pub fn attributes(&mut self) -> Wrapper { + let operand = Wrapper::::new(attributes::Context::NodeOperand( + self.deep_clone(), + )); + + self.operations.push(NodeOperation::Attributes { + operand: operand.clone(), + }); + + operand + } + + pub fn index(&mut self) -> Wrapper { + let operand = Wrapper::::new(self.deep_clone()); + + self.operations.push(NodeOperation::Indices { operand: operand.clone(), }); @@ -102,6 +129,34 @@ impl NodeOperand { operand } + + pub fn neighbors(&mut self, direction: EdgeDirection) -> Wrapper { + let operand = Wrapper::::new(); + + self.operations.push(NodeOperation::Neighbors { + operand: operand.clone(), + direction, + }); + + operand + } + + pub fn either_or(&mut self, either_query: EQ, or_query: OQ) + where + EQ: FnOnce(&mut Wrapper), + OQ: FnOnce(&mut Wrapper), + { + let mut either_operand = Wrapper::::new(); + let mut or_operand = Wrapper::::new(); + + either_query(&mut either_operand); + or_query(&mut or_operand); + + self.operations.push(NodeOperation::EitherOr { + either: either_operand, + or: or_operand, + }); + } } impl Wrapper { @@ -112,14 +167,22 @@ impl Wrapper { pub(crate) fn evaluate<'a>( &self, medrecord: &'a MedRecord, - ) -> MedRecordResult + 'a>> { + ) -> MedRecordResult> { self.0.read_or_panic().evaluate(medrecord) } - pub fn attribute(&mut self, attribute: MedRecordAttribute) -> Wrapper { + pub fn attribute(&mut self, attribute: MedRecordAttribute) -> Wrapper { self.0.write_or_panic().attribute(attribute) } + pub fn attributes(&mut self) -> Wrapper { + self.0.write_or_panic().attributes() + } + + pub fn index(&mut self) -> Wrapper { + self.0.write_or_panic().index() + } + pub fn in_group(&mut self, group: G) where G: Into>, @@ -141,4 +204,529 @@ impl Wrapper { pub fn incoming_edges(&mut self) -> Wrapper { self.0.write_or_panic().incoming_edges() } + + pub fn neighbors(&mut self, direction: EdgeDirection) -> Wrapper { + self.0.write_or_panic().neighbors(direction) + } + + pub fn either_or(&mut self, either_query: EQ, or_query: OQ) + where + EQ: FnOnce(&mut Wrapper), + OQ: FnOnce(&mut Wrapper), + { + self.0.write_or_panic().either_or(either_query, or_query); + } +} + +macro_rules! implement_value_operation { + ($name:ident, $variant:ident) => { + pub fn $name(&mut self) -> Wrapper { + let operand = Wrapper::::new(self.deep_clone(), SingleKind::$variant); + + self.operations + .push(NodeIndicesOperation::NodeIndexOperation { + operand: operand.clone(), + }); + + operand + } + }; +} + +macro_rules! implement_single_value_comparison_operation { + ($name:ident, $operation:ident, $kind:ident) => { + pub fn $name>(&mut self, value: V) { + self.operations + .push($operation::NodeIndexComparisonOperation { + operand: value.into(), + kind: SingleComparisonKind::$kind, + }); + } + }; +} + +macro_rules! implement_binary_arithmetic_operation { + ($name:ident, $operation:ident, $kind:ident) => { + pub fn $name>(&mut self, value: V) { + self.operations.push($operation::BinaryArithmeticOpration { + operand: value.into(), + kind: BinaryArithmeticKind::$kind, + }); + } + }; +} + +macro_rules! implement_unary_arithmetic_operation { + ($name:ident, $operation:ident, $kind:ident) => { + pub fn $name(&mut self) { + self.operations.push($operation::UnaryArithmeticOperation { + kind: UnaryArithmeticKind::$kind, + }); + } + }; +} + +macro_rules! implement_assertion_operation { + ($name:ident, $operation:expr) => { + pub fn $name(&mut self) { + self.operations.push($operation); + } + }; +} + +macro_rules! implement_wrapper_operand { + ($name:ident) => { + pub fn $name(&self) { + self.0.write_or_panic().$name() + } + }; +} + +macro_rules! implement_wrapper_operand_with_return { + ($name:ident, $return_operand:ident) => { + pub fn $name(&self) -> Wrapper<$return_operand> { + self.0.write_or_panic().$name() + } + }; +} + +macro_rules! implement_wrapper_operand_with_argument { + ($name:ident, $value_type:ty) => { + pub fn $name(&self, value: $value_type) { + self.0.write_or_panic().$name(value) + } + }; +} + +#[derive(Debug, Clone)] +pub enum NodeIndexComparisonOperand { + Operand(NodeIndexOperand), + Index(NodeIndex), +} + +impl DeepClone for NodeIndexComparisonOperand { + fn deep_clone(&self) -> Self { + match self { + Self::Operand(operand) => Self::Operand(operand.deep_clone()), + Self::Index(value) => Self::Index(value.clone()), + } + } +} + +impl From> for NodeIndexComparisonOperand { + fn from(value: Wrapper) -> Self { + Self::Operand(value.0.read_or_panic().deep_clone()) + } +} + +impl From<&Wrapper> for NodeIndexComparisonOperand { + fn from(value: &Wrapper) -> Self { + Self::Operand(value.0.read_or_panic().deep_clone()) + } +} + +impl> From for NodeIndexComparisonOperand { + fn from(value: V) -> Self { + Self::Index(value.into()) + } +} + +#[derive(Debug, Clone)] +pub enum NodeIndicesComparisonOperand { + Operand(NodeIndicesOperand), + Indices(Vec), +} + +impl DeepClone for NodeIndicesComparisonOperand { + fn deep_clone(&self) -> Self { + match self { + Self::Operand(operand) => Self::Operand(operand.deep_clone()), + Self::Indices(value) => Self::Indices(value.clone()), + } + } +} + +impl From> for NodeIndicesComparisonOperand { + fn from(value: Wrapper) -> Self { + Self::Operand(value.0.read_or_panic().deep_clone()) + } +} + +impl From<&Wrapper> for NodeIndicesComparisonOperand { + fn from(value: &Wrapper) -> Self { + Self::Operand(value.0.read_or_panic().deep_clone()) + } +} + +impl> From> for NodeIndicesComparisonOperand { + fn from(value: Vec) -> Self { + Self::Indices(value.into_iter().map(Into::into).collect()) + } +} + +impl + Clone, const N: usize> From<[V; N]> for NodeIndicesComparisonOperand { + fn from(value: [V; N]) -> Self { + value.to_vec().into() + } +} + +#[derive(Debug, Clone)] +pub struct NodeIndicesOperand { + pub(crate) context: NodeOperand, + operations: Vec, +} + +impl DeepClone for NodeIndicesOperand { + fn deep_clone(&self) -> Self { + Self { + context: self.context.clone(), + operations: self.operations.iter().map(DeepClone::deep_clone).collect(), + } + } +} + +impl NodeIndicesOperand { + pub(crate) fn new(context: NodeOperand) -> Self { + Self { + context, + operations: Vec::new(), + } + } + + pub(crate) fn evaluate<'a>( + &self, + medrecord: &'a MedRecord, + values: impl Iterator + 'a, + ) -> MedRecordResult + 'a> { + let values = Box::new(values) as BoxedIterator; + + self.operations + .iter() + .try_fold(values, |value_tuples, operation| { + operation.evaluate(medrecord, value_tuples) + }) + } + + implement_value_operation!(max, Max); + implement_value_operation!(min, Min); + implement_value_operation!(count, Count); + implement_value_operation!(sum, Sum); + implement_value_operation!(first, First); + implement_value_operation!(last, Last); + + implement_single_value_comparison_operation!(greater_than, NodeIndicesOperation, GreaterThan); + implement_single_value_comparison_operation!( + greater_than_or_equal_to, + NodeIndicesOperation, + GreaterThanOrEqualTo + ); + implement_single_value_comparison_operation!(less_than, NodeIndicesOperation, LessThan); + implement_single_value_comparison_operation!( + less_than_or_equal_to, + NodeIndicesOperation, + LessThanOrEqualTo + ); + implement_single_value_comparison_operation!(equal_to, NodeIndicesOperation, EqualTo); + implement_single_value_comparison_operation!(not_equal_to, NodeIndicesOperation, NotEqualTo); + implement_single_value_comparison_operation!(starts_with, NodeIndicesOperation, StartsWith); + implement_single_value_comparison_operation!(ends_with, NodeIndicesOperation, EndsWith); + implement_single_value_comparison_operation!(contains, NodeIndicesOperation, Contains); + + pub fn is_in>(&mut self, values: V) { + self.operations + .push(NodeIndicesOperation::NodeIndicesComparisonOperation { + operand: values.into(), + kind: MultipleComparisonKind::IsIn, + }); + } + + pub fn is_not_in>(&mut self, values: V) { + self.operations + .push(NodeIndicesOperation::NodeIndicesComparisonOperation { + operand: values.into(), + kind: MultipleComparisonKind::IsNotIn, + }); + } + + implement_binary_arithmetic_operation!(add, NodeIndicesOperation, Add); + implement_binary_arithmetic_operation!(sub, NodeIndicesOperation, Sub); + implement_binary_arithmetic_operation!(mul, NodeIndicesOperation, Mul); + implement_binary_arithmetic_operation!(pow, NodeIndicesOperation, Pow); + implement_binary_arithmetic_operation!(r#mod, NodeIndicesOperation, Mod); + + implement_unary_arithmetic_operation!(abs, NodeIndicesOperation, Abs); + implement_unary_arithmetic_operation!(trim, NodeIndicesOperation, Trim); + implement_unary_arithmetic_operation!(trim_start, NodeIndicesOperation, TrimStart); + implement_unary_arithmetic_operation!(trim_end, NodeIndicesOperation, TrimEnd); + implement_unary_arithmetic_operation!(lowercase, NodeIndicesOperation, Lowercase); + implement_unary_arithmetic_operation!(uppercase, NodeIndicesOperation, Uppercase); + + pub fn slice(&mut self, start: usize, end: usize) { + self.operations + .push(NodeIndicesOperation::Slice(start..end)); + } + + implement_assertion_operation!(is_string, NodeIndicesOperation::IsString); + implement_assertion_operation!(is_int, NodeIndicesOperation::IsInt); + implement_assertion_operation!(is_max, NodeIndicesOperation::IsMax); + implement_assertion_operation!(is_min, NodeIndicesOperation::IsMin); + + pub fn either_or(&mut self, either_query: EQ, or_query: OQ) + where + EQ: FnOnce(&mut Wrapper), + OQ: FnOnce(&mut Wrapper), + { + let mut either_operand = Wrapper::::new(self.context.clone()); + let mut or_operand = Wrapper::::new(self.context.clone()); + + either_query(&mut either_operand); + or_query(&mut or_operand); + + self.operations.push(NodeIndicesOperation::EitherOr { + either: either_operand, + or: or_operand, + }); + } +} + +impl Wrapper { + pub(crate) fn new(context: NodeOperand) -> Self { + NodeIndicesOperand::new(context).into() + } + + pub(crate) fn evaluate<'a>( + &self, + medrecord: &'a MedRecord, + values: impl Iterator + 'a, + ) -> MedRecordResult + 'a> { + self.0.read_or_panic().evaluate(medrecord, values) + } + + implement_wrapper_operand_with_return!(max, NodeIndexOperand); + implement_wrapper_operand_with_return!(min, NodeIndexOperand); + implement_wrapper_operand_with_return!(count, NodeIndexOperand); + implement_wrapper_operand_with_return!(sum, NodeIndexOperand); + implement_wrapper_operand_with_return!(first, NodeIndexOperand); + implement_wrapper_operand_with_return!(last, NodeIndexOperand); + + implement_wrapper_operand_with_argument!(greater_than, impl Into); + implement_wrapper_operand_with_argument!( + greater_than_or_equal_to, + impl Into + ); + implement_wrapper_operand_with_argument!(less_than, impl Into); + implement_wrapper_operand_with_argument!( + less_than_or_equal_to, + impl Into + ); + implement_wrapper_operand_with_argument!(equal_to, impl Into); + implement_wrapper_operand_with_argument!(not_equal_to, impl Into); + implement_wrapper_operand_with_argument!(starts_with, impl Into); + implement_wrapper_operand_with_argument!(ends_with, impl Into); + implement_wrapper_operand_with_argument!(contains, impl Into); + implement_wrapper_operand_with_argument!(is_in, impl Into); + implement_wrapper_operand_with_argument!(is_not_in, impl Into); + implement_wrapper_operand_with_argument!(add, impl Into); + implement_wrapper_operand_with_argument!(sub, impl Into); + implement_wrapper_operand_with_argument!(mul, impl Into); + implement_wrapper_operand_with_argument!(pow, impl Into); + implement_wrapper_operand_with_argument!(r#mod, impl Into); + + implement_wrapper_operand!(abs); + implement_wrapper_operand!(trim); + implement_wrapper_operand!(trim_start); + implement_wrapper_operand!(trim_end); + implement_wrapper_operand!(lowercase); + implement_wrapper_operand!(uppercase); + + pub fn slice(&self, start: usize, end: usize) { + self.0.write_or_panic().slice(start, end) + } + + implement_wrapper_operand!(is_string); + implement_wrapper_operand!(is_int); + implement_wrapper_operand!(is_max); + implement_wrapper_operand!(is_min); + + pub fn either_or(&self, either_query: EQ, or_query: OQ) + where + EQ: FnOnce(&mut Wrapper), + OQ: FnOnce(&mut Wrapper), + { + self.0.write_or_panic().either_or(either_query, or_query); + } +} + +#[derive(Debug, Clone)] +pub struct NodeIndexOperand { + pub(crate) context: NodeIndicesOperand, + pub(crate) kind: SingleKind, + operations: Vec, +} + +impl DeepClone for NodeIndexOperand { + fn deep_clone(&self) -> Self { + Self { + context: self.context.deep_clone(), + kind: self.kind.clone(), + operations: self.operations.iter().map(DeepClone::deep_clone).collect(), + } + } +} + +impl NodeIndexOperand { + pub(crate) fn new(context: NodeIndicesOperand, kind: SingleKind) -> Self { + Self { + context, + kind, + operations: Vec::new(), + } + } + + pub(crate) fn evaluate( + &self, + medrecord: &MedRecord, + value: NodeIndex, + ) -> MedRecordResult> { + self.operations + .iter() + .try_fold(Some(value), |value, operation| { + if let Some(value) = value { + operation.evaluate(medrecord, value) + } else { + Ok(None) + } + }) + } + + implement_single_value_comparison_operation!(greater_than, NodeIndexOperation, GreaterThan); + implement_single_value_comparison_operation!( + greater_than_or_equal_to, + NodeIndexOperation, + GreaterThanOrEqualTo + ); + implement_single_value_comparison_operation!(less_than, NodeIndexOperation, LessThan); + implement_single_value_comparison_operation!( + less_than_or_equal_to, + NodeIndexOperation, + LessThanOrEqualTo + ); + implement_single_value_comparison_operation!(equal_to, NodeIndexOperation, EqualTo); + implement_single_value_comparison_operation!(not_equal_to, NodeIndexOperation, NotEqualTo); + implement_single_value_comparison_operation!(starts_with, NodeIndexOperation, StartsWith); + implement_single_value_comparison_operation!(ends_with, NodeIndexOperation, EndsWith); + implement_single_value_comparison_operation!(contains, NodeIndexOperation, Contains); + + pub fn is_in>(&mut self, values: V) { + self.operations + .push(NodeIndexOperation::NodeIndicesComparisonOperation { + operand: values.into(), + kind: MultipleComparisonKind::IsIn, + }); + } + + pub fn is_not_in>(&mut self, values: V) { + self.operations + .push(NodeIndexOperation::NodeIndicesComparisonOperation { + operand: values.into(), + kind: MultipleComparisonKind::IsNotIn, + }); + } + + implement_binary_arithmetic_operation!(add, NodeIndexOperation, Add); + implement_binary_arithmetic_operation!(sub, NodeIndexOperation, Sub); + implement_binary_arithmetic_operation!(mul, NodeIndexOperation, Mul); + implement_binary_arithmetic_operation!(pow, NodeIndexOperation, Pow); + implement_binary_arithmetic_operation!(r#mod, NodeIndexOperation, Mod); + + implement_unary_arithmetic_operation!(abs, NodeIndexOperation, Abs); + implement_unary_arithmetic_operation!(trim, NodeIndexOperation, Trim); + implement_unary_arithmetic_operation!(trim_start, NodeIndexOperation, TrimStart); + implement_unary_arithmetic_operation!(trim_end, NodeIndexOperation, TrimEnd); + implement_unary_arithmetic_operation!(lowercase, NodeIndexOperation, Lowercase); + implement_unary_arithmetic_operation!(uppercase, NodeIndexOperation, Uppercase); + + pub fn slice(&mut self, start: usize, end: usize) { + self.operations.push(NodeIndexOperation::Slice(start..end)); + } + + implement_assertion_operation!(is_string, NodeIndexOperation::IsString); + implement_assertion_operation!(is_int, NodeIndexOperation::IsInt); + pub fn eiter_or(&mut self, either_query: EQ, or_query: OQ) + where + EQ: FnOnce(&mut Wrapper), + OQ: FnOnce(&mut Wrapper), + { + let mut either_operand = + Wrapper::::new(self.context.clone(), self.kind.clone()); + let mut or_operand = + Wrapper::::new(self.context.clone(), self.kind.clone()); + + either_query(&mut either_operand); + or_query(&mut or_operand); + + self.operations.push(NodeIndexOperation::EitherOr { + either: either_operand, + or: or_operand, + }); + } +} + +impl Wrapper { + pub(crate) fn new(context: NodeIndicesOperand, kind: SingleKind) -> Self { + NodeIndexOperand::new(context, kind).into() + } + + pub(crate) fn evaluate( + &self, + medrecord: &MedRecord, + value: NodeIndex, + ) -> MedRecordResult> { + self.0.read_or_panic().evaluate(medrecord, value) + } + + implement_wrapper_operand_with_argument!(greater_than, impl Into); + implement_wrapper_operand_with_argument!( + greater_than_or_equal_to, + impl Into + ); + implement_wrapper_operand_with_argument!(less_than, impl Into); + implement_wrapper_operand_with_argument!( + less_than_or_equal_to, + impl Into + ); + implement_wrapper_operand_with_argument!(equal_to, impl Into); + implement_wrapper_operand_with_argument!(not_equal_to, impl Into); + implement_wrapper_operand_with_argument!(starts_with, impl Into); + implement_wrapper_operand_with_argument!(ends_with, impl Into); + implement_wrapper_operand_with_argument!(contains, impl Into); + implement_wrapper_operand_with_argument!(is_in, impl Into); + implement_wrapper_operand_with_argument!(is_not_in, impl Into); + implement_wrapper_operand_with_argument!(add, impl Into); + implement_wrapper_operand_with_argument!(sub, impl Into); + implement_wrapper_operand_with_argument!(mul, impl Into); + implement_wrapper_operand_with_argument!(pow, impl Into); + implement_wrapper_operand_with_argument!(r#mod, impl Into); + + implement_wrapper_operand!(abs); + implement_wrapper_operand!(trim); + implement_wrapper_operand!(trim_start); + implement_wrapper_operand!(trim_end); + implement_wrapper_operand!(lowercase); + implement_wrapper_operand!(uppercase); + + pub fn slice(&self, start: usize, end: usize) { + self.0.write_or_panic().slice(start, end) + } + + implement_wrapper_operand!(is_string); + implement_wrapper_operand!(is_int); + + pub fn either_or(&self, either_query: EQ, or_query: OQ) + where + EQ: FnOnce(&mut Wrapper), + OQ: FnOnce(&mut Wrapper), + { + self.0.write_or_panic().eiter_or(either_query, or_query); + } } diff --git a/crates/medmodels-core/src/medrecord/querying/nodes/operation.rs b/crates/medmodels-core/src/medrecord/querying/nodes/operation.rs index 99606b2f..90e6692c 100644 --- a/crates/medmodels-core/src/medrecord/querying/nodes/operation.rs +++ b/crates/medmodels-core/src/medrecord/querying/nodes/operation.rs @@ -1,21 +1,54 @@ +use super::{ + operand::{ + NodeIndexComparisonOperand, NodeIndexOperand, NodeIndicesComparisonOperand, + NodeIndicesOperand, + }, + BinaryArithmeticKind, MultipleComparisonKind, NodeOperand, SingleComparisonKind, SingleKind, + UnaryArithmeticKind, +}; use crate::{ - errors::MedRecordResult, + errors::{MedRecordError, MedRecordResult}, medrecord::{ + datatypes::{ + Abs, Contains, EndsWith, Lowercase, Mod, Pow, Slice, StartsWith, Trim, TrimEnd, + TrimStart, Uppercase, + }, querying::{ + attributes::AttributesTreeOperand, edges::EdgeOperand, traits::{DeepClone, ReadWriteOrPanic}, - values::MedRecordValuesOperand, + values::MultipleValuesOperand, wrapper::{CardinalityWrapper, Wrapper}, + BoxedIterator, }, - Group, MedRecord, MedRecordAttribute, MedRecordValue, NodeIndex, + DataType, Group, MedRecord, MedRecordAttribute, MedRecordValue, NodeIndex, }, }; +use itertools::Itertools; use roaring::RoaringBitmap; +use std::{ + cmp::Ordering, + collections::HashSet, + ops::{Add, Mul, Range, Sub}, +}; + +#[derive(Debug, Clone)] +pub enum EdgeDirection { + Incoming, + Outgoing, + Both, +} #[derive(Debug, Clone)] pub enum NodeOperation { - Attribute { - operand: Wrapper, + Values { + operand: Wrapper, + }, + Attributes { + operand: Wrapper, + }, + Indices { + operand: Wrapper, }, InGroup { @@ -31,12 +64,28 @@ pub enum NodeOperation { IncomingEdges { operand: Wrapper, }, + + Neighbors { + operand: Wrapper, + direction: EdgeDirection, + }, + + EitherOr { + either: Wrapper, + or: Wrapper, + }, } impl DeepClone for NodeOperation { fn deep_clone(&self) -> Self { match self { - Self::Attribute { operand } => Self::Attribute { + Self::Values { operand } => Self::Values { + operand: operand.deep_clone(), + }, + Self::Attributes { operand } => Self::Attributes { + operand: operand.deep_clone(), + }, + Self::Indices { operand } => Self::Indices { operand: operand.deep_clone(), }, Self::InGroup { group } => Self::InGroup { @@ -51,6 +100,17 @@ impl DeepClone for NodeOperation { Self::IncomingEdges { operand } => Self::IncomingEdges { operand: operand.deep_clone(), }, + Self::Neighbors { + operand, + direction: drection, + } => Self::Neighbors { + operand: operand.deep_clone(), + direction: drection.clone(), + }, + Self::EitherOr { either, or } => Self::EitherOr { + either: either.deep_clone(), + or: or.deep_clone(), + }, } } } @@ -60,9 +120,19 @@ impl NodeOperation { &self, medrecord: &'a MedRecord, node_indices: impl Iterator + 'a, - ) -> MedRecordResult + 'a>> { + ) -> MedRecordResult> { Ok(match self { - Self::Attribute { operand } => Box::new(Self::evaluate_attribute( + Self::Values { operand } => Box::new(Self::evaluate_values( + medrecord, + node_indices, + operand.clone(), + )?), + Self::Attributes { operand } => Box::new(Self::evaluate_attributes( + medrecord, + node_indices, + operand.clone(), + )?), + Self::Indices { operand } => Box::new(Self::evaluate_indices( medrecord, node_indices, operand.clone(), @@ -87,31 +157,48 @@ impl NodeOperation { node_indices, operand.clone(), )?), + Self::Neighbors { + operand, + direction: drection, + } => Box::new(Self::evaluate_neighbors( + medrecord, + node_indices, + operand.clone(), + drection.clone(), + )?), + Self::EitherOr { either, or } => { + // TODO: This is a temporary solution. It should be optimized. + let either_result = either.evaluate(medrecord)?.collect::>(); + let or_result = or.evaluate(medrecord)?.collect::>(); + + Box::new(either_result.into_iter().chain(or_result).unique()) + } }) } #[inline] pub(crate) fn get_values<'a>( medrecord: &'a MedRecord, - node_indices: impl Iterator + 'a, + node_indices: impl Iterator, attribute: MedRecordAttribute, - ) -> impl Iterator + 'a { + ) -> impl Iterator { node_indices.flat_map(move |node_index| { Some(( node_index, medrecord .node_attributes(node_index) .expect("Edge must exist") - .get(&attribute)?, + .get(&attribute)? + .clone(), )) }) } #[inline] - fn evaluate_attribute<'a>( + fn evaluate_values<'a>( medrecord: &'a MedRecord, node_indices: impl Iterator + 'a, - operand: Wrapper, + operand: Wrapper, ) -> MedRecordResult> { let values = Self::get_values( medrecord, @@ -123,11 +210,58 @@ impl NodeOperation { } #[inline] - fn evaluate_in_group<'a>( + pub(crate) fn get_attributes<'a>( + medrecord: &'a MedRecord, + node_indices: impl Iterator, + ) -> impl Iterator)> { + node_indices.map(move |node_index| { + let attributes = medrecord + .node_attributes(node_index) + .expect("Edge must exist") + .keys() + .cloned(); + + (node_index, attributes.collect()) + }) + } + + #[inline] + fn evaluate_attributes<'a>( medrecord: &'a MedRecord, node_indices: impl Iterator + 'a, + operand: Wrapper, + ) -> MedRecordResult> { + let attributes = Self::get_attributes(medrecord, node_indices); + + Ok(operand + .evaluate(medrecord, attributes)? + .map(|value| value.0)) + } + + #[inline] + fn evaluate_indices<'a>( + medrecord: &MedRecord, + edge_indices: impl Iterator, + operand: Wrapper, + ) -> MedRecordResult> { + // TODO: This is a temporary solution. It should be optimized. + let node_indices = edge_indices.collect::>(); + + let result = operand + .evaluate(medrecord, node_indices.clone().into_iter().cloned())? + .collect::>(); + + Ok(node_indices + .into_iter() + .filter(move |index| result.contains(index))) + } + + #[inline] + fn evaluate_in_group<'a>( + medrecord: &'a MedRecord, + node_indices: impl Iterator, group: CardinalityWrapper, - ) -> impl Iterator + 'a { + ) -> impl Iterator { node_indices.filter(move |node_index| { let groups_of_node = medrecord .groups_of_node(node_index) @@ -147,9 +281,9 @@ impl NodeOperation { #[inline] fn evaluate_has_attribute<'a>( medrecord: &'a MedRecord, - node_indices: impl Iterator + 'a, + node_indices: impl Iterator, attribute: CardinalityWrapper, - ) -> impl Iterator + 'a { + ) -> impl Iterator { node_indices.filter(move |node_index| { let attributes_of_node = medrecord .node_attributes(node_index) @@ -170,9 +304,9 @@ impl NodeOperation { #[inline] fn evaluate_outgoing_edges<'a>( medrecord: &'a MedRecord, - node_indices: impl Iterator + 'a, + node_indices: impl Iterator, operand: Wrapper, - ) -> MedRecordResult + 'a> { + ) -> MedRecordResult> { let edge_indices = operand.evaluate(medrecord)?.collect::(); Ok(node_indices.filter(move |node_index| { @@ -189,9 +323,9 @@ impl NodeOperation { #[inline] fn evaluate_incoming_edges<'a>( medrecord: &'a MedRecord, - node_indices: impl Iterator + 'a, + node_indices: impl Iterator, operand: Wrapper, - ) -> MedRecordResult + 'a> { + ) -> MedRecordResult> { let edge_indices = operand.evaluate(medrecord)?.collect::(); Ok(node_indices.filter(move |node_index| { @@ -204,4 +338,634 @@ impl NodeOperation { !incoming_edge_indices.is_disjoint(&edge_indices) })) } + + #[inline] + fn evaluate_neighbors<'a>( + medrecord: &'a MedRecord, + node_indices: impl Iterator, + operand: Wrapper, + direction: EdgeDirection, + ) -> MedRecordResult> { + let result = operand.evaluate(medrecord)?.collect::>(); + + Ok(node_indices.filter(move |node_index| { + let mut neighbors: Box> = match direction { + EdgeDirection::Incoming => Box::new( + medrecord + .neighbors_incoming(node_index) + .expect("Node must exist"), + ), + EdgeDirection::Outgoing => Box::new( + medrecord + .neighbors_outgoing(node_index) + .expect("Node must exist"), + ), + EdgeDirection::Both => Box::new( + medrecord + .neighbors_undirected(node_index) + .expect("Node must exist"), + ), + }; + + neighbors.any(|neighbor| result.contains(&neighbor)) + })) + } +} + +macro_rules! get_node_index { + ($kind:ident, $indices:expr) => { + match $kind { + SingleKind::Max => NodeIndicesOperation::get_max($indices)?.clone(), + SingleKind::Min => NodeIndicesOperation::get_min($indices)?.clone(), + SingleKind::Count => NodeIndicesOperation::get_count($indices), + SingleKind::Sum => NodeIndicesOperation::get_sum($indices)?, + SingleKind::First => NodeIndicesOperation::get_first($indices)?, + SingleKind::Last => NodeIndicesOperation::get_last($indices)?, + } + }; +} + +macro_rules! get_node_index_comparison_operand { + ($operand:ident, $medrecord:ident) => { + match $operand { + NodeIndexComparisonOperand::Operand(operand) => { + let context = &operand.context.context; + let kind = &operand.kind; + + // TODO: This is a temporary solution. It should be optimized. + let comparison_indices = context.evaluate($medrecord)?.cloned(); + + let comparison_index = get_node_index!(kind, comparison_indices); + + comparison_index + } + NodeIndexComparisonOperand::Index(index) => index.clone(), + } + }; +} + +#[derive(Debug, Clone)] +pub enum NodeIndicesOperation { + NodeIndexOperation { + operand: Wrapper, + }, + NodeIndexComparisonOperation { + operand: NodeIndexComparisonOperand, + kind: SingleComparisonKind, + }, + NodeIndicesComparisonOperation { + operand: NodeIndicesComparisonOperand, + kind: MultipleComparisonKind, + }, + BinaryArithmeticOpration { + operand: NodeIndexComparisonOperand, + kind: BinaryArithmeticKind, + }, + UnaryArithmeticOperation { + kind: UnaryArithmeticKind, + }, + + Slice(Range), + + IsString, + IsInt, + + IsMax, + IsMin, + + EitherOr { + either: Wrapper, + or: Wrapper, + }, +} + +impl DeepClone for NodeIndicesOperation { + fn deep_clone(&self) -> Self { + match self { + Self::NodeIndexOperation { operand } => Self::NodeIndexOperation { + operand: operand.deep_clone(), + }, + Self::NodeIndexComparisonOperation { operand, kind } => { + Self::NodeIndexComparisonOperation { + operand: operand.deep_clone(), + kind: kind.clone(), + } + } + Self::NodeIndicesComparisonOperation { operand, kind } => { + Self::NodeIndicesComparisonOperation { + operand: operand.deep_clone(), + kind: kind.clone(), + } + } + Self::BinaryArithmeticOpration { operand, kind } => Self::BinaryArithmeticOpration { + operand: operand.deep_clone(), + kind: kind.clone(), + }, + Self::UnaryArithmeticOperation { kind } => { + Self::UnaryArithmeticOperation { kind: kind.clone() } + } + Self::Slice(range) => Self::Slice(range.clone()), + Self::IsString => Self::IsString, + Self::IsInt => Self::IsInt, + Self::IsMax => Self::IsMax, + Self::IsMin => Self::IsMin, + Self::EitherOr { either, or } => Self::EitherOr { + either: either.deep_clone(), + or: or.deep_clone(), + }, + } + } +} + +impl NodeIndicesOperation { + pub(crate) fn evaluate<'a>( + &self, + medrecord: &'a MedRecord, + indices: impl Iterator + 'a, + ) -> MedRecordResult> { + match self { + Self::NodeIndexOperation { operand } => { + Self::evaluate_node_index_operation(medrecord, indices, operand) + } + Self::NodeIndexComparisonOperation { operand, kind } => { + Self::evaluate_node_index_comparison_operation(medrecord, indices, operand, kind) + } + Self::NodeIndicesComparisonOperation { operand, kind } => { + Self::evaluate_node_indices_comparison_operation(medrecord, indices, operand, kind) + } + Self::BinaryArithmeticOpration { operand, kind } => { + Ok(Box::new(Self::evaluate_binary_arithmetic_operation( + medrecord, + indices, + operand, + kind.clone(), + )?)) + } + Self::UnaryArithmeticOperation { kind } => Ok(Box::new( + Self::evaluate_unary_arithmetic_operation(indices, kind.clone()), + )), + Self::Slice(range) => Ok(Box::new(Self::evaluate_slice(indices, range.clone()))), + Self::IsString => { + Ok(Box::new(indices.filter(|index| { + matches!(index, MedRecordAttribute::String(_)) + }))) + } + Self::IsInt => { + Ok(Box::new(indices.filter(|index| { + matches!(index, MedRecordAttribute::Int(_)) + }))) + } + Self::IsMax => { + let max_index = Self::get_max(indices)?; + + Ok(Box::new(std::iter::once(max_index))) + } + Self::IsMin => { + let min_index = Self::get_min(indices)?; + + Ok(Box::new(std::iter::once(min_index))) + } + Self::EitherOr { either, or } => { + Self::evaluate_either_or(medrecord, indices, either, or) + } + } + } + + #[inline] + pub(crate) fn get_max( + mut indices: impl Iterator, + ) -> MedRecordResult { + let max_index = indices.next().ok_or(MedRecordError::QueryError( + "No indices to compare".to_string(), + ))?; + + indices.try_fold(max_index, |max_index, index| { + match index + .partial_cmp(&max_index) { + Some(Ordering::Greater) => Ok(index), + None => { + let first_dtype = DataType::from(index); + let second_dtype = DataType::from(max_index); + + Err(MedRecordError::QueryError(format!( + "Cannot compare indices of data types {} and {}. Consider narrowing down the indices using .is_string() or .is_int()", + first_dtype, second_dtype + ))) + } + _ => Ok(max_index), + } + }) + } + + #[inline] + pub(crate) fn get_min( + mut indices: impl Iterator, + ) -> MedRecordResult { + let min_index = indices.next().ok_or(MedRecordError::QueryError( + "No indices to compare".to_string(), + ))?; + + indices.try_fold(min_index, |min_index, index| { + match index.partial_cmp(&min_index) { + Some(Ordering::Less) => Ok(index), + None => { + let first_dtype = DataType::from(index); + let second_dtype = DataType::from(min_index); + + Err(MedRecordError::QueryError(format!( + "Cannot compare indices of data types {} and {}. Consider narrowing down the indices using .is_string() or .is_int()", + first_dtype, second_dtype + ))) + } + _ => Ok(min_index), + } + }) + } + #[inline] + pub(crate) fn get_count(indices: impl Iterator) -> NodeIndex { + MedRecordAttribute::Int(indices.count() as i64) + } + + #[inline] + // 🥊💥 + pub(crate) fn get_sum( + mut indices: impl Iterator, + ) -> MedRecordResult { + let first_value = indices + .next() + .ok_or(MedRecordError::QueryError("No indices to sum".to_string()))?; + + indices.try_fold(first_value, |sum, index| { + let first_dtype = DataType::from(&sum); + let second_dtype = DataType::from(&index); + + sum.add(index).map_err(|_| { + MedRecordError::QueryError(format!( + "Cannot add indices of data types {} and {}. Consider narrowing down the indices using .is_string() or .is_int()", + first_dtype, second_dtype + )) + }) + }) + } + + #[inline] + pub(crate) fn get_first( + mut indices: impl Iterator, + ) -> MedRecordResult { + indices.next().ok_or(MedRecordError::QueryError( + "No indices to get the first".to_string(), + )) + } + + #[inline] + pub(crate) fn get_last(indices: impl Iterator) -> MedRecordResult { + indices.last().ok_or(MedRecordError::QueryError( + "No indices to get the first".to_string(), + )) + } + + #[inline] + fn evaluate_node_index_operation<'a>( + medrecord: &MedRecord, + indices: impl Iterator, + operand: &Wrapper, + ) -> MedRecordResult> { + let kind = &operand.0.read_or_panic().kind; + + let indices = indices.collect::>(); + + let index = get_node_index!(kind, indices.clone().into_iter()); + + Ok(match operand.evaluate(medrecord, index)? { + Some(_) => Box::new(indices.into_iter()), + None => Box::new(std::iter::empty()), + }) + } + + #[inline] + fn evaluate_node_index_comparison_operation<'a>( + medrecord: &MedRecord, + indices: impl Iterator + 'a, + comparison_operand: &NodeIndexComparisonOperand, + kind: &SingleComparisonKind, + ) -> MedRecordResult> { + let comparison_index = get_node_index_comparison_operand!(comparison_operand, medrecord); + + match kind { + SingleComparisonKind::GreaterThan => Ok(Box::new( + indices.filter(move |index| index > &comparison_index), + )), + SingleComparisonKind::GreaterThanOrEqualTo => Ok(Box::new( + indices.filter(move |index| index >= &comparison_index), + )), + SingleComparisonKind::LessThan => Ok(Box::new( + indices.filter(move |index| index < &comparison_index), + )), + SingleComparisonKind::LessThanOrEqualTo => Ok(Box::new( + indices.filter(move |index| index <= &comparison_index), + )), + SingleComparisonKind::EqualTo => Ok(Box::new( + indices.filter(move |index| index == &comparison_index), + )), + SingleComparisonKind::NotEqualTo => Ok(Box::new( + indices.filter(move |index| index != &comparison_index), + )), + SingleComparisonKind::StartsWith => Ok(Box::new( + indices.filter(move |index| index.starts_with(&comparison_index)), + )), + SingleComparisonKind::EndsWith => Ok(Box::new( + indices.filter(move |index| index.ends_with(&comparison_index)), + )), + SingleComparisonKind::Contains => Ok(Box::new( + indices.filter(move |index| index.contains(&comparison_index)), + )), + } + } + + #[inline] + fn evaluate_node_indices_comparison_operation<'a>( + medrecord: &MedRecord, + indices: impl Iterator + 'a, + comparison_operand: &NodeIndicesComparisonOperand, + kind: &MultipleComparisonKind, + ) -> MedRecordResult> { + let comparison_indices = match comparison_operand { + NodeIndicesComparisonOperand::Operand(operand) => { + let context = &operand.context; + + context.evaluate(medrecord)?.cloned().collect::>() + } + NodeIndicesComparisonOperand::Indices(indices) => indices.clone(), + }; + + match kind { + MultipleComparisonKind::IsIn => Ok(Box::new( + indices.filter(move |index| comparison_indices.contains(index)), + )), + MultipleComparisonKind::IsNotIn => Ok(Box::new( + indices.filter(move |index| !comparison_indices.contains(index)), + )), + } + } + + #[inline] + fn evaluate_binary_arithmetic_operation( + medrecord: &MedRecord, + indices: impl Iterator, + operand: &NodeIndexComparisonOperand, + kind: BinaryArithmeticKind, + ) -> MedRecordResult> { + let arithmetic_index = get_node_index_comparison_operand!(operand, medrecord); + + let indices = indices + .map(move |index| { + match kind { + BinaryArithmeticKind::Add => index.add(arithmetic_index.clone()), + BinaryArithmeticKind::Sub => index.sub(arithmetic_index.clone()), + BinaryArithmeticKind::Mul => { + index.clone().mul(arithmetic_index.clone()) + } + BinaryArithmeticKind::Pow => { + index.clone().pow(arithmetic_index.clone()) + } + BinaryArithmeticKind::Mod => { + index.clone().r#mod(arithmetic_index.clone()) + } + } + .map_err(|_| { + MedRecordError::QueryError(format!( + "Failed arithmetic operation {}. Consider narrowing down the indices using .is_string() or .is_int()", + kind, + )) + }) + }); + + // TODO: This is a temporary solution. It should be optimized. + Ok(indices.collect::>>()?.into_iter()) + } + + #[inline] + fn evaluate_unary_arithmetic_operation( + indices: impl Iterator, + kind: UnaryArithmeticKind, + ) -> impl Iterator { + indices.map(move |index| match kind { + UnaryArithmeticKind::Abs => index.abs(), + UnaryArithmeticKind::Trim => index.trim(), + UnaryArithmeticKind::TrimStart => index.trim_start(), + UnaryArithmeticKind::TrimEnd => index.trim_end(), + UnaryArithmeticKind::Lowercase => index.lowercase(), + UnaryArithmeticKind::Uppercase => index.uppercase(), + }) + } + + #[inline] + fn evaluate_slice( + indices: impl Iterator, + range: Range, + ) -> impl Iterator { + indices.map(move |index| index.slice(range.clone())) + } + + #[inline] + fn evaluate_either_or<'a>( + medrecord: &'a MedRecord, + indices: impl Iterator, + either: &Wrapper, + or: &Wrapper, + ) -> MedRecordResult> { + let indices = indices.collect::>(); + + let either_indices = either.evaluate(medrecord, indices.clone().into_iter())?; + let or_indices = or.evaluate(medrecord, indices.into_iter())?; + + Ok(Box::new(either_indices.chain(or_indices).unique())) + } +} + +#[derive(Debug, Clone)] +pub enum NodeIndexOperation { + NodeIndexComparisonOperation { + operand: NodeIndexComparisonOperand, + kind: SingleComparisonKind, + }, + NodeIndicesComparisonOperation { + operand: NodeIndicesComparisonOperand, + kind: MultipleComparisonKind, + }, + BinaryArithmeticOpration { + operand: NodeIndexComparisonOperand, + kind: BinaryArithmeticKind, + }, + UnaryArithmeticOperation { + kind: UnaryArithmeticKind, + }, + + Slice(Range), + + IsString, + IsInt, + + EitherOr { + either: Wrapper, + or: Wrapper, + }, +} + +impl DeepClone for NodeIndexOperation { + fn deep_clone(&self) -> Self { + match self { + Self::NodeIndexComparisonOperation { operand, kind } => { + Self::NodeIndexComparisonOperation { + operand: operand.deep_clone(), + kind: kind.clone(), + } + } + Self::NodeIndicesComparisonOperation { operand, kind } => { + Self::NodeIndicesComparisonOperation { + operand: operand.deep_clone(), + kind: kind.clone(), + } + } + Self::BinaryArithmeticOpration { operand, kind } => Self::BinaryArithmeticOpration { + operand: operand.deep_clone(), + kind: kind.clone(), + }, + Self::UnaryArithmeticOperation { kind } => { + Self::UnaryArithmeticOperation { kind: kind.clone() } + } + Self::Slice(range) => Self::Slice(range.clone()), + Self::IsString => Self::IsString, + Self::IsInt => Self::IsInt, + Self::EitherOr { either, or } => Self::EitherOr { + either: either.deep_clone(), + or: or.deep_clone(), + }, + } + } +} + +impl NodeIndexOperation { + pub(crate) fn evaluate( + &self, + medrecord: &MedRecord, + index: NodeIndex, + ) -> MedRecordResult> { + match self { + Self::NodeIndexComparisonOperation { operand, kind } => { + Self::evaluate_node_index_comparison_operation(medrecord, index, operand, kind) + } + Self::NodeIndicesComparisonOperation { operand, kind } => { + Self::evaluate_node_indices_comparison_operation(medrecord, index, operand, kind) + } + Self::BinaryArithmeticOpration { operand, kind } => { + Self::evaluate_binary_arithmetic_operation(medrecord, index, operand, kind) + } + Self::UnaryArithmeticOperation { kind } => Ok(Some(match kind { + UnaryArithmeticKind::Abs => index.abs(), + UnaryArithmeticKind::Trim => index.trim(), + UnaryArithmeticKind::TrimStart => index.trim_start(), + UnaryArithmeticKind::TrimEnd => index.trim_end(), + UnaryArithmeticKind::Lowercase => index.lowercase(), + UnaryArithmeticKind::Uppercase => index.uppercase(), + })), + Self::Slice(range) => Ok(Some(index.slice(range.clone()))), + Self::IsString => Ok(match index { + MedRecordAttribute::String(_) => Some(index), + _ => None, + }), + Self::IsInt => Ok(match index { + MedRecordAttribute::Int(_) => Some(index), + _ => None, + }), + Self::EitherOr { either, or } => Self::evaluate_either_or(medrecord, index, either, or), + } + } + + #[inline] + fn evaluate_node_index_comparison_operation( + medrecord: &MedRecord, + index: NodeIndex, + comparison_operand: &NodeIndexComparisonOperand, + kind: &SingleComparisonKind, + ) -> MedRecordResult> { + let comparison_index = get_node_index_comparison_operand!(comparison_operand, medrecord); + + let comparison_result = match kind { + SingleComparisonKind::GreaterThan => index > comparison_index, + SingleComparisonKind::GreaterThanOrEqualTo => index >= comparison_index, + SingleComparisonKind::LessThan => index < comparison_index, + SingleComparisonKind::LessThanOrEqualTo => index <= comparison_index, + SingleComparisonKind::EqualTo => index == comparison_index, + SingleComparisonKind::NotEqualTo => index != comparison_index, + SingleComparisonKind::StartsWith => index.starts_with(&comparison_index), + SingleComparisonKind::EndsWith => index.ends_with(&comparison_index), + SingleComparisonKind::Contains => index.contains(&comparison_index), + }; + + Ok(if comparison_result { Some(index) } else { None }) + } + + #[inline] + fn evaluate_node_indices_comparison_operation( + medrecord: &MedRecord, + index: NodeIndex, + comparison_operand: &NodeIndicesComparisonOperand, + kind: &MultipleComparisonKind, + ) -> MedRecordResult> { + let comparison_indices = match comparison_operand { + NodeIndicesComparisonOperand::Operand(operand) => { + let context = &operand.context; + + context.evaluate(medrecord)?.cloned().collect::>() + } + NodeIndicesComparisonOperand::Indices(indices) => indices.clone(), + }; + + let comparison_result = match kind { + MultipleComparisonKind::IsIn => comparison_indices + .into_iter() + .any(|comparison_index| index == comparison_index), + MultipleComparisonKind::IsNotIn => comparison_indices + .into_iter() + .all(|comparison_index| index != comparison_index), + }; + + Ok(if comparison_result { Some(index) } else { None }) + } + + #[inline] + fn evaluate_binary_arithmetic_operation( + medrecord: &MedRecord, + index: NodeIndex, + operand: &NodeIndexComparisonOperand, + kind: &BinaryArithmeticKind, + ) -> MedRecordResult> { + let arithmetic_index = get_node_index_comparison_operand!(operand, medrecord); + + Ok(Some(match kind { + BinaryArithmeticKind::Add => index.add(arithmetic_index)?, + BinaryArithmeticKind::Sub => index.sub(arithmetic_index)?, + BinaryArithmeticKind::Mul => index.mul(arithmetic_index)?, + BinaryArithmeticKind::Pow => index.pow(arithmetic_index)?, + BinaryArithmeticKind::Mod => index.r#mod(arithmetic_index)?, + })) + } + + #[inline] + fn evaluate_either_or( + medrecord: &MedRecord, + index: NodeIndex, + either: &Wrapper, + or: &Wrapper, + ) -> MedRecordResult> { + let either_result = either.evaluate(medrecord, index.clone())?; + let or_result = or.evaluate(medrecord, index)?; + + match (either_result, or_result) { + (Some(either_result), _) => Ok(Some(either_result)), + (None, Some(or_result)) => Ok(Some(or_result)), + _ => Ok(None), + } + } } diff --git a/crates/medmodels-core/src/medrecord/querying/nodes/values.rs b/crates/medmodels-core/src/medrecord/querying/nodes/values.rs deleted file mode 100644 index 8b137891..00000000 --- a/crates/medmodels-core/src/medrecord/querying/nodes/values.rs +++ /dev/null @@ -1 +0,0 @@ - diff --git a/crates/medmodels-core/src/medrecord/querying/traits.rs b/crates/medmodels-core/src/medrecord/querying/traits.rs index 3a99dfde..4e8d33e8 100644 --- a/crates/medmodels-core/src/medrecord/querying/traits.rs +++ b/crates/medmodels-core/src/medrecord/querying/traits.rs @@ -1,6 +1,6 @@ use std::sync::{RwLock, RwLockReadGuard, RwLockWriteGuard}; -pub(crate) trait DeepClone { +pub trait DeepClone { fn deep_clone(&self) -> Self; } diff --git a/crates/medmodels-core/src/medrecord/querying/values/mod.rs b/crates/medmodels-core/src/medrecord/querying/values/mod.rs index 1062cd5d..bf2e2f4a 100644 --- a/crates/medmodels-core/src/medrecord/querying/values/mod.rs +++ b/crates/medmodels-core/src/medrecord/querying/values/mod.rs @@ -1,4 +1,185 @@ mod operand; mod operation; -pub use operand::{Context, MedRecordValuesOperand}; +use super::{ + attributes::{ + self, AttributesTreeOperation, MultipleAttributesOperand, MultipleAttributesOperation, + }, + edges::{EdgeOperand, EdgeOperation}, + nodes::{NodeOperand, NodeOperation}, + BoxedIterator, +}; +use crate::{ + errors::MedRecordResult, + medrecord::{MedRecordAttribute, MedRecordValue}, + MedRecord, +}; +pub use operand::MultipleValuesOperand; +use std::fmt::Display; + +macro_rules! get_attributes { + ($operand:ident, $medrecord:ident, $operation:ident, $multiple_attributes_operand:ident) => {{ + let indices = $operand.evaluate($medrecord)?; + + let attributes = $operation::get_attributes($medrecord, indices); + + let attributes = $multiple_attributes_operand + .context + .evaluate($medrecord, attributes)?; + + let attributes: Box> = + match $multiple_attributes_operand.kind { + attributes::MultipleKind::Max => { + Box::new(AttributesTreeOperation::get_max(attributes)?) + } + attributes::MultipleKind::Min => { + Box::new(AttributesTreeOperation::get_min(attributes)?) + } + attributes::MultipleKind::Count => { + Box::new(AttributesTreeOperation::get_count(attributes)?) + } + attributes::MultipleKind::Sum => { + Box::new(AttributesTreeOperation::get_sum(attributes)?) + } + attributes::MultipleKind::First => { + Box::new(AttributesTreeOperation::get_first(attributes)?) + } + attributes::MultipleKind::Last => { + Box::new(AttributesTreeOperation::get_last(attributes)?) + } + }; + + let attributes = $multiple_attributes_operand.evaluate($medrecord, attributes)?; + + Box::new( + MultipleAttributesOperation::get_values($medrecord, attributes)? + .map(|(_, value)| value), + ) + }}; +} + +#[derive(Debug, Clone)] +pub enum SingleKind { + Max, + Min, + Mean, + Median, + Mode, + Std, + Var, + Count, + Sum, + First, + Last, +} + +#[derive(Debug, Clone)] +pub enum SingleComparisonKind { + GreaterThan, + GreaterThanOrEqualTo, + LessThan, + LessThanOrEqualTo, + EqualTo, + NotEqualTo, + StartsWith, + EndsWith, + Contains, +} + +#[derive(Debug, Clone)] +pub enum MultipleComparisonKind { + IsIn, + IsNotIn, +} + +#[derive(Debug, Clone)] +pub enum BinaryArithmeticKind { + Add, + Sub, + Mul, + Div, + Pow, + Mod, +} + +impl Display for BinaryArithmeticKind { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match self { + BinaryArithmeticKind::Add => write!(f, "add"), + BinaryArithmeticKind::Sub => write!(f, "sub"), + BinaryArithmeticKind::Mul => write!(f, "mul"), + BinaryArithmeticKind::Div => write!(f, "div"), + BinaryArithmeticKind::Pow => write!(f, "pow"), + BinaryArithmeticKind::Mod => write!(f, "mod"), + } + } +} + +#[derive(Debug, Clone)] +pub enum UnaryArithmeticKind { + Round, + Ceil, + Floor, + Abs, + Sqrt, + Trim, + TrimStart, + TrimEnd, + Lowercase, + Uppercase, +} + +#[allow(clippy::enum_variant_names)] +#[derive(Debug, Clone)] +pub enum Context { + NodeOperand(NodeOperand), + EdgeOperand(EdgeOperand), + MultipleAttributesOperand(MultipleAttributesOperand), +} + +impl Context { + pub(crate) fn get_values<'a>( + &self, + medrecord: &'a MedRecord, + attribute: MedRecordAttribute, + ) -> MedRecordResult> { + Ok(match self { + Self::NodeOperand(node_operand) => { + let node_indices = node_operand.evaluate(medrecord)?; + + Box::new( + NodeOperation::get_values(medrecord, node_indices, attribute) + .map(|(_, value)| value), + ) + } + Self::EdgeOperand(edge_operand) => { + let edge_indices = edge_operand.evaluate(medrecord)?; + + Box::new( + EdgeOperation::get_values(medrecord, edge_indices, attribute) + .map(|(_, value)| value), + ) + } + Self::MultipleAttributesOperand(multiple_attributes_operand) => { + match &multiple_attributes_operand.context.context { + attributes::Context::NodeOperand(node_operand) => { + get_attributes!( + node_operand, + medrecord, + NodeOperation, + multiple_attributes_operand + ) + } + attributes::Context::EdgeOperand(edge_operand) => { + get_attributes!( + edge_operand, + medrecord, + EdgeOperation, + multiple_attributes_operand + ) + } + } + } + }) + } +} diff --git a/crates/medmodels-core/src/medrecord/querying/values/operand.rs b/crates/medmodels-core/src/medrecord/querying/values/operand.rs index 12a74e46..01796ed7 100644 --- a/crates/medmodels-core/src/medrecord/querying/values/operand.rs +++ b/crates/medmodels-core/src/medrecord/querying/values/operand.rs @@ -1,74 +1,170 @@ -use super::operation::{MedRecordValueOperation, MedRecordValuesOperation}; +use super::{ + operation::{MultipleValuesOperation, SingleValueOperation}, + BinaryArithmeticKind, Context, MultipleComparisonKind, SingleComparisonKind, SingleKind, + UnaryArithmeticKind, +}; use crate::{ errors::MedRecordResult, medrecord::{ querying::{ - edges::EdgeOperation, - nodes::NodeOperation, traits::{DeepClone, ReadWriteOrPanic}, + BoxedIterator, }, - EdgeOperand, MedRecordAttribute, MedRecordValue, NodeOperand, Wrapper, + MedRecordAttribute, MedRecordValue, Wrapper, }, MedRecord, }; +use std::hash::Hash; + +macro_rules! implement_value_operation { + ($name:ident, $variant:ident) => { + pub fn $name(&mut self) -> Wrapper { + let operand = + Wrapper::::new(self.deep_clone(), SingleKind::$variant); + + self.operations + .push(MultipleValuesOperation::ValueOperation { + operand: operand.clone(), + }); + + operand + } + }; +} + +macro_rules! implement_single_value_comparison_operation { + ($name:ident, $operation:ident, $kind:ident) => { + pub fn $name>(&mut self, value: V) { + self.operations + .push($operation::SingleValueComparisonOperation { + operand: value.into(), + kind: SingleComparisonKind::$kind, + }); + } + }; +} + +macro_rules! implement_binary_arithmetic_operation { + ($name:ident, $operation:ident, $kind:ident) => { + pub fn $name>(&mut self, value: V) { + self.operations.push($operation::BinaryArithmeticOpration { + operand: value.into(), + kind: BinaryArithmeticKind::$kind, + }); + } + }; +} + +macro_rules! implement_unary_arithmetic_operation { + ($name:ident, $operation:ident, $kind:ident) => { + pub fn $name(&mut self) { + self.operations.push($operation::UnaryArithmeticOperation { + kind: UnaryArithmeticKind::$kind, + }); + } + }; +} + +macro_rules! implement_assertion_operation { + ($name:ident, $operation:expr) => { + pub fn $name(&mut self) { + self.operations.push($operation); + } + }; +} + +macro_rules! implement_wrapper_operand { + ($name:ident) => { + pub fn $name(&self) { + self.0.write_or_panic().$name() + } + }; +} + +macro_rules! implement_wrapper_operand_with_return { + ($name:ident, $return_operand:ident) => { + pub fn $name(&self) -> Wrapper<$return_operand> { + self.0.write_or_panic().$name() + } + }; +} + +macro_rules! implement_wrapper_operand_with_argument { + ($name:ident, $value_type:ty) => { + pub fn $name(&self, value: $value_type) { + self.0.write_or_panic().$name(value) + } + }; +} #[derive(Debug, Clone)] -pub enum MedRecordValueComparisonOperand { - SingleOperand(MedRecordValueOperand), - SingleValue(MedRecordValue), - MultipleOperand(MedRecordValuesOperand), - MultipleValues(Vec), +pub enum SingleValueComparisonOperand { + Operand(SingleValueOperand), + Value(MedRecordValue), } -impl DeepClone for MedRecordValueComparisonOperand { +impl DeepClone for SingleValueComparisonOperand { fn deep_clone(&self) -> Self { match self { - Self::SingleOperand(operand) => Self::SingleOperand(operand.deep_clone()), - Self::SingleValue(value) => Self::SingleValue(value.clone()), - Self::MultipleOperand(operand) => Self::MultipleOperand(operand.deep_clone()), - Self::MultipleValues(values) => Self::MultipleValues(values.clone()), + Self::Operand(operand) => Self::Operand(operand.deep_clone()), + Self::Value(value) => Self::Value(value.clone()), } } } -impl From> for MedRecordValueComparisonOperand { - fn from(value: Wrapper) -> Self { - Self::SingleOperand(value.0.read_or_panic().deep_clone()) +impl From> for SingleValueComparisonOperand { + fn from(value: Wrapper) -> Self { + Self::Operand(value.0.read_or_panic().deep_clone()) } } -impl From<&Wrapper> for MedRecordValueComparisonOperand { - fn from(value: &Wrapper) -> Self { - Self::SingleOperand(value.0.read_or_panic().deep_clone()) +impl From<&Wrapper> for SingleValueComparisonOperand { + fn from(value: &Wrapper) -> Self { + Self::Operand(value.0.read_or_panic().deep_clone()) } } -impl> From for MedRecordValueComparisonOperand { +impl> From for SingleValueComparisonOperand { fn from(value: V) -> Self { - Self::SingleValue(value.into()) + Self::Value(value.into()) + } +} + +#[derive(Debug, Clone)] +pub enum MultipleValuesComparisonOperand { + Operand(MultipleValuesOperand), + Values(Vec), +} + +impl DeepClone for MultipleValuesComparisonOperand { + fn deep_clone(&self) -> Self { + match self { + Self::Operand(operand) => Self::Operand(operand.deep_clone()), + Self::Values(value) => Self::Values(value.clone()), + } } } -impl From> for MedRecordValueComparisonOperand { - fn from(value: Wrapper) -> Self { - Self::MultipleOperand(value.0.read_or_panic().deep_clone()) +impl From> for MultipleValuesComparisonOperand { + fn from(value: Wrapper) -> Self { + Self::Operand(value.0.read_or_panic().deep_clone()) } } -impl From<&Wrapper> for MedRecordValueComparisonOperand { - fn from(value: &Wrapper) -> Self { - Self::MultipleOperand(value.0.read_or_panic().deep_clone()) +impl From<&Wrapper> for MultipleValuesComparisonOperand { + fn from(value: &Wrapper) -> Self { + Self::Operand(value.0.read_or_panic().deep_clone()) } } -impl> From> for MedRecordValueComparisonOperand { +impl> From> for MultipleValuesComparisonOperand { fn from(value: Vec) -> Self { - Self::MultipleValues(value.into_iter().map(Into::into).collect()) + Self::Values(value.into_iter().map(Into::into).collect()) } } impl + Clone, const N: usize> From<[V; N]> - for MedRecordValueComparisonOperand + for MultipleValuesComparisonOperand { fn from(value: [V; N]) -> Self { value.to_vec().into() @@ -76,66 +172,23 @@ impl + Clone, const N: usize> From<[V; N]> } #[derive(Debug, Clone)] -pub enum ValueKind { - Max, - Min, -} - -#[derive(Debug, Clone)] -pub enum Context { - NodeOperand(NodeOperand), - EdgeOperand(EdgeOperand), -} - -impl Context { - pub(crate) fn get_values<'a>( - &self, - medrecord: &'a MedRecord, - attribute: MedRecordAttribute, - ) -> MedRecordResult + 'a>> { - Ok(match self { - Self::NodeOperand(node_operand) => { - let node_indices = node_operand.evaluate(medrecord)?; - - Box::new( - NodeOperation::get_values(medrecord, node_indices, attribute) - .map(|(_, value)| value), - ) - } - Self::EdgeOperand(edge_operand) => { - let edge_indices = edge_operand.evaluate(medrecord)?; - - Box::new( - EdgeOperation::get_values(medrecord, edge_indices, attribute) - .map(|(_, value)| value), - ) - } - }) - } -} - -#[derive(Debug, Clone)] -pub struct MedRecordValuesOperand { +pub struct MultipleValuesOperand { pub(crate) context: Context, pub(crate) attribute: MedRecordAttribute, - operations: Vec, + operations: Vec, } -impl DeepClone for MedRecordValuesOperand { +impl DeepClone for MultipleValuesOperand { fn deep_clone(&self) -> Self { Self { context: self.context.clone(), attribute: self.attribute.clone(), - operations: self - .operations - .iter() - .map(|operation| operation.deep_clone()) - .collect(), + operations: self.operations.iter().map(DeepClone::deep_clone).collect(), } } } -impl MedRecordValuesOperand { +impl MultipleValuesOperand { pub(crate) fn new(context: Context, attribute: MedRecordAttribute) -> Self { Self { context, @@ -144,98 +197,223 @@ impl MedRecordValuesOperand { } } - pub(crate) fn evaluate<'a, T: 'a>( + pub(crate) fn evaluate<'a, T: 'a + Eq + Hash>( &self, medrecord: &'a MedRecord, - values: impl Iterator + 'a, - ) -> MedRecordResult> { - let values = Box::new(values) as Box>; + values: impl Iterator + 'a, + ) -> MedRecordResult> { + let values = Box::new(values) as BoxedIterator<(&'a T, MedRecordValue)>; self.operations .iter() - .try_fold(values, |edge_indices, operation| { - operation.evaluate(medrecord, edge_indices) + .try_fold(values, |value_tuples, operation| { + operation.evaluate(medrecord, value_tuples) }) } - pub fn max(&mut self) -> Wrapper { - let operand = Wrapper::::new(self.deep_clone(), ValueKind::Max); - + implement_value_operation!(max, Max); + implement_value_operation!(min, Min); + implement_value_operation!(mean, Mean); + implement_value_operation!(median, Median); + implement_value_operation!(mode, Mode); + implement_value_operation!(std, Std); + implement_value_operation!(var, Var); + implement_value_operation!(count, Count); + implement_value_operation!(sum, Sum); + implement_value_operation!(first, First); + implement_value_operation!(last, Last); + + implement_single_value_comparison_operation!( + greater_than, + MultipleValuesOperation, + GreaterThan + ); + implement_single_value_comparison_operation!( + greater_than_or_equal_to, + MultipleValuesOperation, + GreaterThanOrEqualTo + ); + implement_single_value_comparison_operation!(less_than, MultipleValuesOperation, LessThan); + implement_single_value_comparison_operation!( + less_than_or_equal_to, + MultipleValuesOperation, + LessThanOrEqualTo + ); + implement_single_value_comparison_operation!(equal_to, MultipleValuesOperation, EqualTo); + implement_single_value_comparison_operation!(not_equal_to, MultipleValuesOperation, NotEqualTo); + implement_single_value_comparison_operation!(starts_with, MultipleValuesOperation, StartsWith); + implement_single_value_comparison_operation!(ends_with, MultipleValuesOperation, EndsWith); + implement_single_value_comparison_operation!(contains, MultipleValuesOperation, Contains); + + pub fn is_in>(&mut self, values: V) { self.operations - .push(MedRecordValuesOperation::ValueOperand { - operand: operand.clone(), + .push(MultipleValuesOperation::MultipleValuesComparisonOperation { + operand: values.into(), + kind: MultipleComparisonKind::IsIn, }); - - operand } - pub fn min(&mut self) -> Wrapper { - let operand = Wrapper::::new(self.deep_clone(), ValueKind::Min); - + pub fn is_not_in>(&mut self, values: V) { self.operations - .push(MedRecordValuesOperation::ValueOperand { - operand: operand.clone(), + .push(MultipleValuesOperation::MultipleValuesComparisonOperation { + operand: values.into(), + kind: MultipleComparisonKind::IsNotIn, }); + } - operand + implement_binary_arithmetic_operation!(add, MultipleValuesOperation, Add); + implement_binary_arithmetic_operation!(sub, MultipleValuesOperation, Sub); + implement_binary_arithmetic_operation!(mul, MultipleValuesOperation, Mul); + implement_binary_arithmetic_operation!(div, MultipleValuesOperation, Div); + implement_binary_arithmetic_operation!(pow, MultipleValuesOperation, Pow); + implement_binary_arithmetic_operation!(r#mod, MultipleValuesOperation, Mod); + + implement_unary_arithmetic_operation!(round, MultipleValuesOperation, Round); + implement_unary_arithmetic_operation!(ceil, MultipleValuesOperation, Ceil); + implement_unary_arithmetic_operation!(floor, MultipleValuesOperation, Floor); + implement_unary_arithmetic_operation!(abs, MultipleValuesOperation, Abs); + implement_unary_arithmetic_operation!(sqrt, MultipleValuesOperation, Sqrt); + implement_unary_arithmetic_operation!(trim, MultipleValuesOperation, Trim); + implement_unary_arithmetic_operation!(trim_start, MultipleValuesOperation, TrimStart); + implement_unary_arithmetic_operation!(trim_end, MultipleValuesOperation, TrimEnd); + implement_unary_arithmetic_operation!(lowercase, MultipleValuesOperation, Lowercase); + implement_unary_arithmetic_operation!(uppercase, MultipleValuesOperation, Uppercase); + + pub fn slice(&mut self, start: usize, end: usize) { + self.operations + .push(MultipleValuesOperation::Slice(start..end)); } - pub fn less_than>(&mut self, value: V) { - self.operations.push(MedRecordValuesOperation::LessThan { - value: value.into(), + implement_assertion_operation!(is_string, MultipleValuesOperation::IsString); + implement_assertion_operation!(is_int, MultipleValuesOperation::IsInt); + implement_assertion_operation!(is_float, MultipleValuesOperation::IsFloat); + implement_assertion_operation!(is_bool, MultipleValuesOperation::IsBool); + implement_assertion_operation!(is_datetime, MultipleValuesOperation::IsDateTime); + implement_assertion_operation!(is_null, MultipleValuesOperation::IsNull); + implement_assertion_operation!(is_max, MultipleValuesOperation::IsMax); + implement_assertion_operation!(is_min, MultipleValuesOperation::IsMin); + + pub fn either_or(&mut self, either_query: EQ, or_query: OQ) + where + EQ: FnOnce(&mut Wrapper), + OQ: FnOnce(&mut Wrapper), + { + let mut either_operand = + Wrapper::::new(self.context.clone(), self.attribute.clone()); + let mut or_operand = + Wrapper::::new(self.context.clone(), self.attribute.clone()); + + either_query(&mut either_operand); + or_query(&mut or_operand); + + self.operations.push(MultipleValuesOperation::EitherOr { + either: either_operand, + or: or_operand, }); } } -impl Wrapper { +impl Wrapper { pub(crate) fn new(context: Context, attribute: MedRecordAttribute) -> Self { - MedRecordValuesOperand::new(context, attribute).into() + MultipleValuesOperand::new(context, attribute).into() } - pub(crate) fn evaluate<'a, T: 'a>( + pub(crate) fn evaluate<'a, T: 'a + Eq + Hash>( &self, medrecord: &'a MedRecord, - values: impl Iterator + 'a, - ) -> MedRecordResult> { + values: impl Iterator + 'a, + ) -> MedRecordResult> { self.0.read_or_panic().evaluate(medrecord, values) } - pub fn max(&self) -> Wrapper { - self.0.write_or_panic().max() - } - - pub fn min(&self) -> Wrapper { - self.0.write_or_panic().min() + implement_wrapper_operand_with_return!(max, SingleValueOperand); + implement_wrapper_operand_with_return!(min, SingleValueOperand); + implement_wrapper_operand_with_return!(mean, SingleValueOperand); + implement_wrapper_operand_with_return!(median, SingleValueOperand); + implement_wrapper_operand_with_return!(mode, SingleValueOperand); + implement_wrapper_operand_with_return!(std, SingleValueOperand); + implement_wrapper_operand_with_return!(var, SingleValueOperand); + implement_wrapper_operand_with_return!(count, SingleValueOperand); + implement_wrapper_operand_with_return!(sum, SingleValueOperand); + implement_wrapper_operand_with_return!(first, SingleValueOperand); + implement_wrapper_operand_with_return!(last, SingleValueOperand); + + implement_wrapper_operand_with_argument!(greater_than, impl Into); + implement_wrapper_operand_with_argument!( + greater_than_or_equal_to, + impl Into + ); + implement_wrapper_operand_with_argument!(less_than, impl Into); + implement_wrapper_operand_with_argument!( + less_than_or_equal_to, + impl Into + ); + implement_wrapper_operand_with_argument!(equal_to, impl Into); + implement_wrapper_operand_with_argument!(not_equal_to, impl Into); + implement_wrapper_operand_with_argument!(starts_with, impl Into); + implement_wrapper_operand_with_argument!(ends_with, impl Into); + implement_wrapper_operand_with_argument!(contains, impl Into); + implement_wrapper_operand_with_argument!(is_in, impl Into); + implement_wrapper_operand_with_argument!(is_not_in, impl Into); + implement_wrapper_operand_with_argument!(add, impl Into); + implement_wrapper_operand_with_argument!(sub, impl Into); + implement_wrapper_operand_with_argument!(mul, impl Into); + implement_wrapper_operand_with_argument!(div, impl Into); + implement_wrapper_operand_with_argument!(pow, impl Into); + implement_wrapper_operand_with_argument!(r#mod, impl Into); + + implement_wrapper_operand!(round); + implement_wrapper_operand!(ceil); + implement_wrapper_operand!(floor); + implement_wrapper_operand!(abs); + implement_wrapper_operand!(sqrt); + implement_wrapper_operand!(trim); + implement_wrapper_operand!(trim_start); + implement_wrapper_operand!(trim_end); + implement_wrapper_operand!(lowercase); + implement_wrapper_operand!(uppercase); + + pub fn slice(&self, start: usize, end: usize) { + self.0.write_or_panic().slice(start, end) } - pub fn less_than>(&self, value: V) { - self.0.write_or_panic().less_than(value) + implement_wrapper_operand!(is_string); + implement_wrapper_operand!(is_int); + implement_wrapper_operand!(is_float); + implement_wrapper_operand!(is_bool); + implement_wrapper_operand!(is_datetime); + implement_wrapper_operand!(is_null); + implement_wrapper_operand!(is_max); + implement_wrapper_operand!(is_min); + + pub fn either_or(&self, either_query: EQ, or_query: OQ) + where + EQ: FnOnce(&mut Wrapper), + OQ: FnOnce(&mut Wrapper), + { + self.0.write_or_panic().either_or(either_query, or_query); } } #[derive(Debug, Clone)] -pub struct MedRecordValueOperand { - pub(crate) context: MedRecordValuesOperand, - pub(crate) kind: ValueKind, - operations: Vec, +pub struct SingleValueOperand { + pub(crate) context: MultipleValuesOperand, + pub(crate) kind: SingleKind, + operations: Vec, } -impl DeepClone for MedRecordValueOperand { +impl DeepClone for SingleValueOperand { fn deep_clone(&self) -> Self { Self { context: self.context.deep_clone(), kind: self.kind.clone(), - operations: self - .operations - .iter() - .map(|operation| operation.deep_clone()) - .collect(), + operations: self.operations.iter().map(DeepClone::deep_clone).collect(), } } } -impl MedRecordValueOperand { - pub(crate) fn new(context: MedRecordValuesOperand, kind: ValueKind) -> Self { +impl SingleValueOperand { + pub(crate) fn new(context: MultipleValuesOperand, kind: SingleKind) -> Self { Self { context, kind, @@ -243,39 +421,170 @@ impl MedRecordValueOperand { } } - pub(crate) fn evaluate<'a>( + pub(crate) fn evaluate( &self, - medrecord: &'a MedRecord, - value: &'a MedRecordValue, - ) -> MedRecordResult { - for operation in &self.operations { - operation.evaluate(medrecord, value)?; - } + medrecord: &MedRecord, + value: MedRecordValue, + ) -> MedRecordResult> { + self.operations + .iter() + .try_fold(Some(value), |value, operation| { + if let Some(value) = value { + operation.evaluate(medrecord, value) + } else { + Ok(None) + } + }) + } + + implement_single_value_comparison_operation!(greater_than, SingleValueOperation, GreaterThan); + implement_single_value_comparison_operation!( + greater_than_or_equal_to, + SingleValueOperation, + GreaterThanOrEqualTo + ); + implement_single_value_comparison_operation!(less_than, SingleValueOperation, LessThan); + implement_single_value_comparison_operation!( + less_than_or_equal_to, + SingleValueOperation, + LessThanOrEqualTo + ); + implement_single_value_comparison_operation!(equal_to, SingleValueOperation, EqualTo); + implement_single_value_comparison_operation!(not_equal_to, SingleValueOperation, NotEqualTo); + implement_single_value_comparison_operation!(starts_with, SingleValueOperation, StartsWith); + implement_single_value_comparison_operation!(ends_with, SingleValueOperation, EndsWith); + implement_single_value_comparison_operation!(contains, SingleValueOperation, Contains); + + pub fn is_in>(&mut self, values: V) { + self.operations + .push(SingleValueOperation::MultipleValuesComparisonOperation { + operand: values.into(), + kind: MultipleComparisonKind::IsIn, + }); + } - Ok(true) + pub fn is_not_in>(&mut self, values: V) { + self.operations + .push(SingleValueOperation::MultipleValuesComparisonOperation { + operand: values.into(), + kind: MultipleComparisonKind::IsNotIn, + }); } - pub fn less_than>(&mut self, value: V) { - self.operations.push(MedRecordValueOperation::LessThan { - value: value.into(), + implement_binary_arithmetic_operation!(add, SingleValueOperation, Add); + implement_binary_arithmetic_operation!(sub, SingleValueOperation, Sub); + implement_binary_arithmetic_operation!(mul, SingleValueOperation, Mul); + implement_binary_arithmetic_operation!(div, SingleValueOperation, Div); + implement_binary_arithmetic_operation!(pow, SingleValueOperation, Pow); + implement_binary_arithmetic_operation!(r#mod, SingleValueOperation, Mod); + + implement_unary_arithmetic_operation!(round, SingleValueOperation, Round); + implement_unary_arithmetic_operation!(ceil, SingleValueOperation, Ceil); + implement_unary_arithmetic_operation!(floor, SingleValueOperation, Floor); + implement_unary_arithmetic_operation!(abs, SingleValueOperation, Abs); + implement_unary_arithmetic_operation!(sqrt, SingleValueOperation, Sqrt); + implement_unary_arithmetic_operation!(trim, SingleValueOperation, Trim); + implement_unary_arithmetic_operation!(trim_start, SingleValueOperation, TrimStart); + implement_unary_arithmetic_operation!(trim_end, SingleValueOperation, TrimEnd); + implement_unary_arithmetic_operation!(lowercase, SingleValueOperation, Lowercase); + implement_unary_arithmetic_operation!(uppercase, SingleValueOperation, Uppercase); + + pub fn slice(&mut self, start: usize, end: usize) { + self.operations + .push(SingleValueOperation::Slice(start..end)); + } + + implement_assertion_operation!(is_string, SingleValueOperation::IsString); + implement_assertion_operation!(is_int, SingleValueOperation::IsInt); + implement_assertion_operation!(is_float, SingleValueOperation::IsFloat); + implement_assertion_operation!(is_bool, SingleValueOperation::IsBool); + implement_assertion_operation!(is_datetime, SingleValueOperation::IsDateTime); + implement_assertion_operation!(is_null, SingleValueOperation::IsNull); + + pub fn eiter_or(&mut self, either_query: EQ, or_query: OQ) + where + EQ: FnOnce(&mut Wrapper), + OQ: FnOnce(&mut Wrapper), + { + let mut either_operand = + Wrapper::::new(self.context.clone(), self.kind.clone()); + let mut or_operand = + Wrapper::::new(self.context.clone(), self.kind.clone()); + + either_query(&mut either_operand); + or_query(&mut or_operand); + + self.operations.push(SingleValueOperation::EitherOr { + either: either_operand, + or: or_operand, }); } } -impl Wrapper { - pub(crate) fn new(context: MedRecordValuesOperand, kind: ValueKind) -> Self { - MedRecordValueOperand::new(context, kind).into() +impl Wrapper { + pub(crate) fn new(context: MultipleValuesOperand, kind: SingleKind) -> Self { + SingleValueOperand::new(context, kind).into() } - pub(crate) fn evaluate<'a>( + pub(crate) fn evaluate( &self, - medrecord: &'a MedRecord, - value: &'a MedRecordValue, - ) -> MedRecordResult { + medrecord: &MedRecord, + value: MedRecordValue, + ) -> MedRecordResult> { self.0.read_or_panic().evaluate(medrecord, value) } - pub fn less_than>(&self, value: V) { - self.0.write_or_panic().less_than(value) + implement_wrapper_operand_with_argument!(greater_than, impl Into); + implement_wrapper_operand_with_argument!( + greater_than_or_equal_to, + impl Into + ); + implement_wrapper_operand_with_argument!(less_than, impl Into); + implement_wrapper_operand_with_argument!( + less_than_or_equal_to, + impl Into + ); + implement_wrapper_operand_with_argument!(equal_to, impl Into); + implement_wrapper_operand_with_argument!(not_equal_to, impl Into); + implement_wrapper_operand_with_argument!(starts_with, impl Into); + implement_wrapper_operand_with_argument!(ends_with, impl Into); + implement_wrapper_operand_with_argument!(contains, impl Into); + implement_wrapper_operand_with_argument!(is_in, impl Into); + implement_wrapper_operand_with_argument!(is_not_in, impl Into); + implement_wrapper_operand_with_argument!(add, impl Into); + implement_wrapper_operand_with_argument!(sub, impl Into); + implement_wrapper_operand_with_argument!(mul, impl Into); + implement_wrapper_operand_with_argument!(div, impl Into); + implement_wrapper_operand_with_argument!(pow, impl Into); + implement_wrapper_operand_with_argument!(r#mod, impl Into); + + implement_wrapper_operand!(round); + implement_wrapper_operand!(ceil); + implement_wrapper_operand!(floor); + implement_wrapper_operand!(abs); + implement_wrapper_operand!(sqrt); + implement_wrapper_operand!(trim); + implement_wrapper_operand!(trim_start); + implement_wrapper_operand!(trim_end); + implement_wrapper_operand!(lowercase); + implement_wrapper_operand!(uppercase); + + pub fn slice(&self, start: usize, end: usize) { + self.0.write_or_panic().slice(start, end) + } + + implement_wrapper_operand!(is_string); + implement_wrapper_operand!(is_int); + implement_wrapper_operand!(is_float); + implement_wrapper_operand!(is_bool); + implement_wrapper_operand!(is_datetime); + implement_wrapper_operand!(is_null); + + pub fn either_or(&self, either_query: EQ, or_query: OQ) + where + EQ: FnOnce(&mut Wrapper), + OQ: FnOnce(&mut Wrapper), + { + self.0.write_or_panic().eiter_or(either_query, or_query); } } diff --git a/crates/medmodels-core/src/medrecord/querying/values/operation.rs b/crates/medmodels-core/src/medrecord/querying/values/operation.rs index 1ff87629..2a559d99 100644 --- a/crates/medmodels-core/src/medrecord/querying/values/operation.rs +++ b/crates/medmodels-core/src/medrecord/querying/values/operation.rs @@ -1,63 +1,256 @@ -use std::cmp::Ordering; - -use super::operand::{MedRecordValueComparisonOperand, MedRecordValueOperand, ValueKind}; +use super::{ + operand::{ + MultipleValuesComparisonOperand, MultipleValuesOperand, SingleValueComparisonOperand, + SingleValueOperand, + }, + BinaryArithmeticKind, MultipleComparisonKind, SingleComparisonKind, SingleKind, + UnaryArithmeticKind, +}; use crate::{ errors::{MedRecordError, MedRecordResult}, medrecord::{ - querying::traits::{DeepClone, ReadWriteOrPanic}, + datatypes::{ + Abs, Ceil, Contains, EndsWith, Floor, Lowercase, Mod, Pow, Round, Slice, Sqrt, + StartsWith, Trim, TrimEnd, TrimStart, Uppercase, + }, + querying::{ + traits::{DeepClone, ReadWriteOrPanic}, + BoxedIterator, + }, DataType, MedRecordValue, Wrapper, }, MedRecord, }; +use itertools::Itertools; +use std::{ + cmp::Ordering, + hash::Hash, + ops::{Add, Div, Mul, Range, Sub}, +}; + +macro_rules! get_single_operand_value { + ($kind:ident, $values:expr) => { + match $kind { + SingleKind::Max => MultipleValuesOperation::get_max($values)?.1, + SingleKind::Min => MultipleValuesOperation::get_min($values)?.1, + SingleKind::Mean => MultipleValuesOperation::get_mean($values)?, + SingleKind::Median => MultipleValuesOperation::get_median($values)?, + SingleKind::Mode => MultipleValuesOperation::get_mode($values)?, + SingleKind::Std => MultipleValuesOperation::get_std($values)?, + SingleKind::Var => MultipleValuesOperation::get_var($values)?, + SingleKind::Count => MultipleValuesOperation::get_count($values), + SingleKind::Sum => MultipleValuesOperation::get_sum($values)?, + SingleKind::First => MultipleValuesOperation::get_first($values)?, + SingleKind::Last => MultipleValuesOperation::get_last($values)?, + } + }; +} + +macro_rules! get_single_value_comparison_operand_value { + ($operand:ident, $medrecord:ident) => { + match $operand { + SingleValueComparisonOperand::Operand(operand) => { + let context = &operand.context.context; + let attribute = operand.context.attribute.clone(); + let kind = &operand.kind; + + let comparison_values = context + .get_values($medrecord, attribute)? + .map(|value| (&0, value)); + + let comparison_value = get_single_operand_value!(kind, comparison_values); + + comparison_value + } + SingleValueComparisonOperand::Value(value) => value.clone(), + } + }; +} + +macro_rules! get_median { + ($values:ident, $variant:ident) => { + if $values.len() % 2 == 0 { + let middle = $values.len() / 2; + + let first = $values.get(middle - 1).unwrap(); + let second = $values.get(middle).unwrap(); + + let first = MedRecordValue::$variant(*first); + let second = MedRecordValue::$variant(*second); + + first.add(second).unwrap().div(MedRecordValue::Int(2)) + } else { + let middle = $values.len() / 2; + + Ok(MedRecordValue::$variant( + $values.get(middle).unwrap().clone(), + )) + } + }; +} #[derive(Debug, Clone)] -pub enum MedRecordValuesOperation { - ValueOperand { - operand: Wrapper, +pub enum MultipleValuesOperation { + ValueOperation { + operand: Wrapper, + }, + SingleValueComparisonOperation { + operand: SingleValueComparisonOperand, + kind: SingleComparisonKind, + }, + MultipleValuesComparisonOperation { + operand: MultipleValuesComparisonOperand, + kind: MultipleComparisonKind, + }, + BinaryArithmeticOpration { + operand: SingleValueComparisonOperand, + kind: BinaryArithmeticKind, + }, + UnaryArithmeticOperation { + kind: UnaryArithmeticKind, }, - LessThan { - value: MedRecordValueComparisonOperand, + Slice(Range), + + IsString, + IsInt, + IsFloat, + IsBool, + IsDateTime, + IsNull, + + IsMax, + IsMin, + + EitherOr { + either: Wrapper, + or: Wrapper, }, } -impl DeepClone for MedRecordValuesOperation { +impl DeepClone for MultipleValuesOperation { fn deep_clone(&self) -> Self { match self { - Self::ValueOperand { operand } => Self::ValueOperand { + Self::ValueOperation { operand } => Self::ValueOperation { + operand: operand.deep_clone(), + }, + Self::SingleValueComparisonOperation { operand, kind } => { + Self::SingleValueComparisonOperation { + operand: operand.deep_clone(), + kind: kind.clone(), + } + } + Self::MultipleValuesComparisonOperation { operand, kind } => { + Self::MultipleValuesComparisonOperation { + operand: operand.deep_clone(), + kind: kind.clone(), + } + } + Self::BinaryArithmeticOpration { operand, kind } => Self::BinaryArithmeticOpration { operand: operand.deep_clone(), + kind: kind.clone(), }, - Self::LessThan { value } => Self::LessThan { - value: value.deep_clone(), + Self::UnaryArithmeticOperation { kind } => { + Self::UnaryArithmeticOperation { kind: kind.clone() } + } + Self::Slice(range) => Self::Slice(range.clone()), + Self::IsString => Self::IsString, + Self::IsInt => Self::IsInt, + Self::IsFloat => Self::IsFloat, + Self::IsBool => Self::IsBool, + Self::IsDateTime => Self::IsDateTime, + Self::IsNull => Self::IsNull, + Self::IsMax => Self::IsMax, + Self::IsMin => Self::IsMin, + Self::EitherOr { either, or } => Self::EitherOr { + either: either.deep_clone(), + or: or.deep_clone(), }, } } } -impl MedRecordValuesOperation { - pub(crate) fn evaluate<'a, T: 'a>( +impl MultipleValuesOperation { + pub(crate) fn evaluate<'a, T: Eq + Hash>( &self, medrecord: &'a MedRecord, - values: impl Iterator + 'a, - ) -> MedRecordResult + 'a>> { + values: impl Iterator + 'a, + ) -> MedRecordResult> { match self { - Self::ValueOperand { operand } => { - Self::evaluate_value_operand(medrecord, values, operand) + Self::ValueOperation { operand } => { + Self::evaluate_value_operation(medrecord, values, operand) + } + Self::SingleValueComparisonOperation { operand, kind } => { + Self::evaluate_single_value_comparison_operation(medrecord, values, operand, kind) + } + Self::MultipleValuesComparisonOperation { operand, kind } => { + Self::evaluate_multiple_values_comparison_operation( + medrecord, values, operand, kind, + ) + } + Self::BinaryArithmeticOpration { operand, kind } => Ok(Box::new( + Self::evaluate_binary_arithmetic_operation(medrecord, values, operand, kind)?, + )), + Self::UnaryArithmeticOperation { kind } => Ok(Box::new( + Self::evaluate_unary_arithmetic_operation(values, kind.clone()), + )), + Self::Slice(range) => Ok(Box::new(Self::evaluate_slice(values, range.clone()))), + Self::IsString => { + Ok(Box::new(values.filter(|(_, value)| { + matches!(value, MedRecordValue::String(_)) + }))) + } + Self::IsInt => { + Ok(Box::new(values.filter(|(_, value)| { + matches!(value, MedRecordValue::Int(_)) + }))) + } + Self::IsFloat => { + Ok(Box::new(values.filter(|(_, value)| { + matches!(value, MedRecordValue::Float(_)) + }))) + } + Self::IsBool => { + Ok(Box::new(values.filter(|(_, value)| { + matches!(value, MedRecordValue::Bool(_)) + }))) + } + Self::IsDateTime => { + Ok(Box::new(values.filter(|(_, value)| { + matches!(value, MedRecordValue::DateTime(_)) + }))) + } + Self::IsNull => { + Ok(Box::new(values.filter(|(_, value)| { + matches!(value, MedRecordValue::Null) + }))) + } + Self::IsMax => { + let max_value = Self::get_max(values)?; + + Ok(Box::new(std::iter::once(max_value))) + } + Self::IsMin => { + let min_value = Self::get_min(values)?; + + Ok(Box::new(std::iter::once(min_value))) + } + Self::EitherOr { either, or } => { + Self::evaluate_either_or(medrecord, values, either, or) } - Self::LessThan { value } => Self::evaluate_less_than(medrecord, values, value.clone()), } } #[inline] - pub(crate) fn get_max<'a, T: 'a>( - mut values: impl Iterator, - ) -> MedRecordResult<(&'a T, &'a MedRecordValue)> { + pub(crate) fn get_max<'a, T>( + mut values: impl Iterator, + ) -> MedRecordResult<(&'a T, MedRecordValue)> { let max_value = values.next().ok_or(MedRecordError::QueryError( "No values to compare".to_string(), ))?; values.try_fold(max_value, |max_value, value| { - match value.1.partial_cmp(max_value.1) { + match value.1.partial_cmp(&max_value.1) { Some(Ordering::Greater) => Ok(value), None => { let first_dtype = DataType::from(value.1); @@ -74,15 +267,15 @@ impl MedRecordValuesOperation { } #[inline] - pub(crate) fn get_min<'a, T: 'a>( - mut values: impl Iterator, - ) -> MedRecordResult<(&'a T, &'a MedRecordValue)> { + pub(crate) fn get_min<'a, T>( + mut values: impl Iterator, + ) -> MedRecordResult<(&'a T, MedRecordValue)> { let min_value = values.next().ok_or(MedRecordError::QueryError( "No values to compare".to_string(), ))?; values.try_fold(min_value, |min_value, value| { - match value.1.partial_cmp(min_value.1) { + match value.1.partial_cmp(&min_value.1) { Some(Ordering::Less) => Ok(value), None => { let first_dtype = DataType::from(value.1); @@ -99,138 +292,643 @@ impl MedRecordValuesOperation { } #[inline] - fn evaluate_value_operand<'a, T: 'a>( + pub(crate) fn get_mean<'a, T: 'a>( + mut values: impl Iterator, + ) -> MedRecordResult { + let first_value = values.next().ok_or(MedRecordError::QueryError( + "No values to compare".to_string(), + ))?; + + let (sum, count) = values.try_fold((first_value.1, 1), |(sum, count), (_, value)| { + let first_dtype = DataType::from(&sum); + let second_dtype = DataType::from(&value); + + match sum.add(value) { + Ok(sum) => Ok((sum, count + 1)), + Err(_) => Err(MedRecordError::QueryError(format!( + "Cannot add values of data types {} and {}. Consider narrowing down the values using .is_int(), .is_float() or .is_datetime()", + first_dtype, second_dtype + ))), + } + })?; + + sum.div(MedRecordValue::Int(count as i64)) + } + + // TODO: This is a temporary solution. It should be optimized. + #[inline] + pub(crate) fn get_median<'a, T: 'a>( + mut values: impl Iterator, + ) -> MedRecordResult { + let first_value = values.next().ok_or(MedRecordError::QueryError( + "No values to compare".to_string(), + ))?; + + let first_data_type = DataType::from(&first_value.1); + + match first_value.1 { + MedRecordValue::Int(value) => { + let mut values = values.map(|(_, value)| { + let data_type = DataType::from(&value); + + match value { + MedRecordValue::Int(value) => Ok(value as f64), + MedRecordValue::Float(value) => Ok(value), + _ => Err(MedRecordError::QueryError(format!( + "Cannot calculate median of mixed data types {} and {}. Consider narrowing down the values using .is_int(), .is_float() or .is_datetime()", + first_data_type, data_type + ))), + } + }).collect::>>()?; + values.push(value as f64); + values.sort_unstable_by(|a, b| a.partial_cmp(b).unwrap()); + + get_median!(values, Float) + } + MedRecordValue::Float(value) => { + let mut values = values.map(|(_, value)| { + let data_type = DataType::from(&value); + + match value { + MedRecordValue::Int(value) => Ok(value as f64), + MedRecordValue::Float(value) => Ok(value), + _ => Err(MedRecordError::QueryError(format!( + "Cannot calculate median of mixed data types {} and {}. Consider narrowing down the values using .is_int(), .is_float() or .is_datetime()", + first_data_type, data_type + ))), + } + }).collect::>>()?; + values.push(value); + values.sort_unstable_by(|a, b| a.partial_cmp(b).unwrap()); + + get_median!(values, Float) + } + MedRecordValue::DateTime(value) => { + let mut values = values.map(|(_, value)| { + let data_type = DataType::from(&value); + + match value { + MedRecordValue::DateTime(naive_date_time) => Ok(naive_date_time), + _ => Err(MedRecordError::QueryError(format!( + "Cannot calculate median of mixed data types {} and {}. Consider narrowing down the values using .is_int(), .is_float() or .is_datetime()", + first_data_type, data_type + ))), + } + }).collect::>>()?; + values.push(value); + values.sort(); + + get_median!(values, DateTime) + } + _ => Err(MedRecordError::QueryError(format!( + "Cannot calculate median of data type {}", + first_data_type + )))?, + } + } + + // TODO: This is a temporary solution. It should be optimized. + #[inline] + pub(crate) fn get_mode<'a, T: 'a>( + values: impl Iterator, + ) -> MedRecordResult { + let values = values.map(|(_, value)| value).collect::>(); + + let most_common_value = values + .first() + .ok_or(MedRecordError::QueryError( + "No values to compare".to_string(), + ))? + .clone(); + let most_common_count = values + .iter() + .filter(|value| **value == most_common_value) + .count(); + + let (_, most_common_value) = values.clone().into_iter().fold( + (most_common_count, most_common_value), + |acc, value| { + let count = values.iter().filter(|v| **v == value).count(); + + if count > acc.0 { + (count, value) + } else { + acc + } + }, + ); + + Ok(most_common_value) + } + + #[inline] + // 👀 + pub(crate) fn get_std<'a, T: 'a>( + values: impl Iterator, + ) -> MedRecordResult { + let variance = Self::get_var(values)?; + + let MedRecordValue::Float(variance) = variance else { + unreachable!() + }; + + Ok(MedRecordValue::Float(variance.sqrt())) + } + + // TODO: This is a temporary solution. It should be optimized. + #[inline] + pub(crate) fn get_var<'a, T: 'a>( + values: impl Iterator, + ) -> MedRecordResult { + let values = values.collect::>(); + + let mean = Self::get_mean(values.clone().into_iter())?; + + let MedRecordValue::Float(mean) = mean else { + let data_type = DataType::from(mean); + + return Err(MedRecordError::QueryError( + format!("Cannot calculate variance of data type {}. Consider narrowing down the values using .is_int() or .is_float()", data_type), + )); + }; + + let values = values + .into_iter() + .map(|value| { + let data_type = DataType::from(&value.1); + + match value.1 { + MedRecordValue::Int(value) => Ok(value as f64), + MedRecordValue::Float(value) => Ok(value), + _ => Err(MedRecordError::QueryError( + format!("Cannot calculate variance of data type {}. Consider narrowing down the values using .is_int() or .is_float()", data_type), + )), + }}) + .collect::>>()?; + + let values_length = values.len(); + + let variance = values + .into_iter() + .map(|value| (value - mean).powi(2)) + .sum::() + / values_length as f64; + + Ok(MedRecordValue::Float(variance)) + } + + #[inline] + pub(crate) fn get_count<'a, T: 'a>( + values: impl Iterator, + ) -> MedRecordValue { + MedRecordValue::Int(values.count() as i64) + } + + #[inline] + // 🥊💥 + pub(crate) fn get_sum<'a, T: 'a>( + mut values: impl Iterator, + ) -> MedRecordResult { + let first_value = values.next().ok_or(MedRecordError::QueryError( + "No values to compare".to_string(), + ))?; + + values.try_fold(first_value.1, |sum, (_, value)| { + let first_dtype = DataType::from(&sum); + let second_dtype = DataType::from(&value); + + sum.add(value).map_err(|_| { + MedRecordError::QueryError(format!( + "Cannot add values of data types {} and {}. Consider narrowing down the values using .is_string(), .is_int(), .is_float(), .is_bool() or .is_datetime()", + first_dtype, second_dtype + )) + }) + }) + } + + #[inline] + pub(crate) fn get_first<'a, T: 'a>( + mut values: impl Iterator, + ) -> MedRecordResult { + values + .next() + .ok_or(MedRecordError::QueryError( + "No values to compare".to_string(), + )) + .map(|(_, value)| value) + } + + #[inline] + // 🥊💥 + pub(crate) fn get_last<'a, T: 'a>( + values: impl Iterator, + ) -> MedRecordResult { + values + .last() + .ok_or(MedRecordError::QueryError( + "No values to compare".to_string(), + )) + .map(|(_, value)| value) + } + + #[inline] + fn evaluate_value_operation<'a, T>( medrecord: &'a MedRecord, - values: impl Iterator, - operand: &Wrapper, - ) -> MedRecordResult + 'a>> { + values: impl Iterator, + operand: &Wrapper, + ) -> MedRecordResult> { let kind = &operand.0.read_or_panic().kind; - let value = match kind { - ValueKind::Max => Self::get_max(values), - ValueKind::Min => Self::get_min(values), - }?; + let values = values.collect::>(); + + let value = get_single_operand_value!(kind, values.clone().into_iter()); - Ok(match operand.evaluate(medrecord, value.1)? { - true => Box::new(std::iter::once(value)), - false => Box::new(std::iter::empty()), + Ok(match operand.evaluate(medrecord, value)? { + Some(_) => Box::new(values.into_iter()), + None => Box::new(std::iter::empty()), }) } #[inline] - fn evaluate_less_than<'a, T: 'a>( + fn evaluate_single_value_comparison_operation<'a, T>( medrecord: &'a MedRecord, - values: impl Iterator + 'a, - comparison: MedRecordValueComparisonOperand, - ) -> MedRecordResult + 'a>> { - match comparison { - MedRecordValueComparisonOperand::SingleOperand(comparison_operand) => { - let context = &comparison_operand.context.context; - let attribute = comparison_operand.context.attribute; - let kind = &comparison_operand.kind; + values: impl Iterator + 'a, + comparison_operand: &SingleValueComparisonOperand, + kind: &SingleComparisonKind, + ) -> MedRecordResult> { + let comparison_value = + get_single_value_comparison_operand_value!(comparison_operand, medrecord); - let comparison_values = context - .get_values(medrecord, attribute)? - .map(|value| (&0, value)); + match kind { + SingleComparisonKind::GreaterThan => Ok(Box::new( + values.filter(move |(_, value)| value > &comparison_value), + )), + SingleComparisonKind::GreaterThanOrEqualTo => Ok(Box::new( + values.filter(move |(_, value)| value >= &comparison_value), + )), + SingleComparisonKind::LessThan => Ok(Box::new( + values.filter(move |(_, value)| value < &comparison_value), + )), + SingleComparisonKind::LessThanOrEqualTo => Ok(Box::new( + values.filter(move |(_, value)| value <= &comparison_value), + )), + SingleComparisonKind::EqualTo => Ok(Box::new( + values.filter(move |(_, value)| value == &comparison_value), + )), + SingleComparisonKind::NotEqualTo => Ok(Box::new( + values.filter(move |(_, value)| value != &comparison_value), + )), + SingleComparisonKind::StartsWith => { + Ok(Box::new(values.filter(move |(_, value)| { + value.starts_with(&comparison_value) + }))) + } + SingleComparisonKind::EndsWith => { + Ok(Box::new(values.filter(move |(_, value)| { + value.ends_with(&comparison_value) + }))) + } + SingleComparisonKind::Contains => { + Ok(Box::new(values.filter(move |(_, value)| { + value.contains(&comparison_value) + }))) + } + } + } - let comparison_value = match kind { - ValueKind::Max => Self::get_max(comparison_values), - ValueKind::Min => Self::get_min(comparison_values), - }?; + #[inline] + fn evaluate_multiple_values_comparison_operation<'a, T>( + medrecord: &'a MedRecord, + values: impl Iterator + 'a, + comparison_operand: &MultipleValuesComparisonOperand, + kind: &MultipleComparisonKind, + ) -> MedRecordResult> { + let comparison_values = match comparison_operand { + MultipleValuesComparisonOperand::Operand(operand) => { + let context = &operand.context; + let attribute = operand.attribute.clone(); - Ok(Box::new( - values.filter(|value| value.1 < comparison_value.1), - )) + context + .get_values(medrecord, attribute)? + .collect::>() } - MedRecordValueComparisonOperand::SingleValue(comparison_value) => Ok(Box::new( - values.filter(move |value| value.1 < &comparison_value), - )), - MedRecordValueComparisonOperand::MultipleOperand(comparison_operand) => { - let context = &comparison_operand.context; - let attribute = comparison_operand.attribute; - - let mut comparison_values = context.get_values(medrecord, attribute)?; + MultipleValuesComparisonOperand::Values(values) => values.clone(), + }; - Ok(Box::new(values.filter(move |value| { - comparison_values.all(|comparison_value| value.1 < comparison_value) + match kind { + MultipleComparisonKind::IsIn => { + Ok(Box::new(values.filter(move |(_, value)| { + comparison_values.contains(value) }))) } - MedRecordValueComparisonOperand::MultipleValues(comparison_values) => { - Ok(Box::new(values.filter(move |value| { - comparison_values - .iter() - .all(|comparison_value| value.1 < comparison_value) + MultipleComparisonKind::IsNotIn => { + Ok(Box::new(values.filter(move |(_, value)| { + !comparison_values.contains(value) }))) } } } + + #[inline] + fn evaluate_binary_arithmetic_operation<'a, T: 'a>( + medrecord: &'a MedRecord, + values: impl Iterator, + operand: &SingleValueComparisonOperand, + kind: &BinaryArithmeticKind, + ) -> MedRecordResult> { + let arithmetic_value = get_single_value_comparison_operand_value!(operand, medrecord); + + let values = values + .map(move |(t, value)| { + match kind { + BinaryArithmeticKind::Add => value.add(arithmetic_value.clone()), + BinaryArithmeticKind::Sub => value.sub(arithmetic_value.clone()), + BinaryArithmeticKind::Mul => { + value.clone().mul(arithmetic_value.clone()) + } + BinaryArithmeticKind::Div => { + value.clone().div(arithmetic_value.clone()) + } + BinaryArithmeticKind::Pow => { + value.clone().pow(arithmetic_value.clone()) + } + BinaryArithmeticKind::Mod => { + value.clone().r#mod(arithmetic_value.clone()) + } + } + .map_err(|_| { + MedRecordError::QueryError(format!( + "Failed arithmetic operation {}. Consider narrowing down the values using .is_int() or .is_float()", + kind, + )) + }).map(|result| (t, result)) + }); + + // TODO: This is a temporary solution. It should be optimized. + Ok(values.collect::>>()?.into_iter()) + } + + #[inline] + fn evaluate_unary_arithmetic_operation<'a, T: 'a>( + values: impl Iterator, + kind: UnaryArithmeticKind, + ) -> impl Iterator { + values.map(move |(t, value)| { + let value = match kind { + UnaryArithmeticKind::Round => value.round(), + UnaryArithmeticKind::Ceil => value.ceil(), + UnaryArithmeticKind::Floor => value.floor(), + UnaryArithmeticKind::Abs => value.abs(), + UnaryArithmeticKind::Sqrt => value.sqrt(), + UnaryArithmeticKind::Trim => value.trim(), + UnaryArithmeticKind::TrimStart => value.trim_start(), + UnaryArithmeticKind::TrimEnd => value.trim_end(), + UnaryArithmeticKind::Lowercase => value.lowercase(), + UnaryArithmeticKind::Uppercase => value.uppercase(), + }; + (t, value) + }) + } + + #[inline] + fn evaluate_slice<'a, T: 'a>( + values: impl Iterator, + range: Range, + ) -> impl Iterator { + values.map(move |(t, value)| (t, value.slice(range.clone()))) + } + + #[inline] + fn evaluate_either_or<'a, T: 'a + Eq + Hash>( + medrecord: &'a MedRecord, + values: impl Iterator, + either: &Wrapper, + or: &Wrapper, + ) -> MedRecordResult> { + let values = values.collect::>(); + + let either_values = either.evaluate(medrecord, values.clone().into_iter())?; + let or_values = or.evaluate(medrecord, values.into_iter())?; + + Ok(Box::new( + either_values.chain(or_values).unique_by(|value| value.0), + )) + } } #[derive(Debug, Clone)] -pub enum MedRecordValueOperation { - LessThan { - value: MedRecordValueComparisonOperand, +pub enum SingleValueOperation { + SingleValueComparisonOperation { + operand: SingleValueComparisonOperand, + kind: SingleComparisonKind, + }, + MultipleValuesComparisonOperation { + operand: MultipleValuesComparisonOperand, + kind: MultipleComparisonKind, + }, + BinaryArithmeticOpration { + operand: SingleValueComparisonOperand, + kind: BinaryArithmeticKind, + }, + UnaryArithmeticOperation { + kind: UnaryArithmeticKind, + }, + + Slice(Range), + + IsString, + IsInt, + IsFloat, + IsBool, + IsDateTime, + IsNull, + + EitherOr { + either: Wrapper, + or: Wrapper, }, } -impl DeepClone for MedRecordValueOperation { +impl DeepClone for SingleValueOperation { fn deep_clone(&self) -> Self { - // TODO - self.clone() + match self { + Self::SingleValueComparisonOperation { operand, kind } => { + Self::SingleValueComparisonOperation { + operand: operand.deep_clone(), + kind: kind.clone(), + } + } + Self::MultipleValuesComparisonOperation { operand, kind } => { + Self::MultipleValuesComparisonOperation { + operand: operand.deep_clone(), + kind: kind.clone(), + } + } + Self::BinaryArithmeticOpration { operand, kind } => Self::BinaryArithmeticOpration { + operand: operand.deep_clone(), + kind: kind.clone(), + }, + Self::UnaryArithmeticOperation { kind } => { + Self::UnaryArithmeticOperation { kind: kind.clone() } + } + Self::Slice(range) => Self::Slice(range.clone()), + Self::IsString => Self::IsString, + Self::IsInt => Self::IsInt, + Self::IsFloat => Self::IsFloat, + Self::IsBool => Self::IsBool, + Self::IsDateTime => Self::IsDateTime, + Self::IsNull => Self::IsNull, + Self::EitherOr { either, or } => Self::EitherOr { + either: either.deep_clone(), + or: or.deep_clone(), + }, + } } } -impl MedRecordValueOperation { - pub(crate) fn evaluate<'a>( +impl SingleValueOperation { + pub(crate) fn evaluate( &self, - medrecord: &'a MedRecord, - value: &'a MedRecordValue, - ) -> MedRecordResult { + medrecord: &MedRecord, + value: MedRecordValue, + ) -> MedRecordResult> { match self { - Self::LessThan { value: operand } => { - Self::evaluate_less_than(medrecord, value, operand) + Self::SingleValueComparisonOperation { operand, kind } => { + Self::evaluate_single_value_comparison_operation(medrecord, value, operand, kind) + } + Self::MultipleValuesComparisonOperation { operand, kind } => { + Self::evaluate_multiple_values_comparison_operation(medrecord, value, operand, kind) } + Self::BinaryArithmeticOpration { operand, kind } => { + Self::evaluate_binary_arithmetic_operation(medrecord, value, operand, kind) + } + Self::UnaryArithmeticOperation { kind } => Ok(Some(match kind { + UnaryArithmeticKind::Round => value.round(), + UnaryArithmeticKind::Ceil => value.ceil(), + UnaryArithmeticKind::Floor => value.floor(), + UnaryArithmeticKind::Abs => value.abs(), + UnaryArithmeticKind::Sqrt => value.sqrt(), + UnaryArithmeticKind::Trim => value.trim(), + UnaryArithmeticKind::TrimStart => value.trim_start(), + UnaryArithmeticKind::TrimEnd => value.trim_end(), + UnaryArithmeticKind::Lowercase => value.lowercase(), + UnaryArithmeticKind::Uppercase => value.uppercase(), + })), + Self::Slice(range) => Ok(Some(value.slice(range.clone()))), + Self::IsString => Ok(match value { + MedRecordValue::String(_) => Some(value), + _ => None, + }), + Self::IsInt => Ok(match value { + MedRecordValue::Int(_) => Some(value), + _ => None, + }), + Self::IsFloat => Ok(match value { + MedRecordValue::Float(_) => Some(value), + _ => None, + }), + Self::IsBool => Ok(match value { + MedRecordValue::Bool(_) => Some(value), + _ => None, + }), + Self::IsDateTime => Ok(match value { + MedRecordValue::DateTime(_) => Some(value), + _ => None, + }), + Self::IsNull => Ok(match value { + MedRecordValue::Null => Some(value), + _ => None, + }), + Self::EitherOr { either, or } => Self::evaluate_either_or(medrecord, value, either, or), } } - fn evaluate_less_than( + #[inline] + fn evaluate_single_value_comparison_operation( medrecord: &MedRecord, - value: &MedRecordValue, - comparison_operand: &MedRecordValueComparisonOperand, - ) -> MedRecordResult { - match comparison_operand { - MedRecordValueComparisonOperand::SingleOperand(comparison_operand) => { - let context = &comparison_operand.context.context; - let attribute = comparison_operand.context.attribute.clone(); - let kind = &comparison_operand.kind; - - let values = context - .get_values(medrecord, attribute)? - .map(|value| (&0, value)); + value: MedRecordValue, + comparison_operand: &SingleValueComparisonOperand, + kind: &SingleComparisonKind, + ) -> MedRecordResult> { + let comparison_value = + get_single_value_comparison_operand_value!(comparison_operand, medrecord); - let comparison_value = match kind { - ValueKind::Max => MedRecordValuesOperation::get_max(values), - ValueKind::Min => MedRecordValuesOperation::get_min(values), - }?; + let comparison_result = match kind { + SingleComparisonKind::GreaterThan => value > comparison_value, + SingleComparisonKind::GreaterThanOrEqualTo => value >= comparison_value, + SingleComparisonKind::LessThan => value < comparison_value, + SingleComparisonKind::LessThanOrEqualTo => value <= comparison_value, + SingleComparisonKind::EqualTo => value == comparison_value, + SingleComparisonKind::NotEqualTo => value != comparison_value, + SingleComparisonKind::StartsWith => value.starts_with(&comparison_value), + SingleComparisonKind::EndsWith => value.ends_with(&comparison_value), + SingleComparisonKind::Contains => value.contains(&comparison_value), + }; - Ok(value < comparison_value.1) - } - MedRecordValueComparisonOperand::SingleValue(comparison_value) => { - Ok(value < comparison_value) - } - MedRecordValueComparisonOperand::MultipleOperand(comparison_operand) => { - let context = &comparison_operand.context; - let attribute = comparison_operand.attribute.clone(); + Ok(if comparison_result { Some(value) } else { None }) + } - let mut values = context.get_values(medrecord, attribute)?; + #[inline] + fn evaluate_multiple_values_comparison_operation( + medrecord: &MedRecord, + value: MedRecordValue, + comparison_operand: &MultipleValuesComparisonOperand, + kind: &MultipleComparisonKind, + ) -> MedRecordResult> { + let comparison_values = match comparison_operand { + MultipleValuesComparisonOperand::Operand(operand) => { + let context = &operand.context; + let attribute = operand.attribute.clone(); - Ok(values.all(|comparison_value| value < comparison_value)) - } - MedRecordValueComparisonOperand::MultipleValues(comparison_values) => { - Ok(comparison_values - .iter() - .all(|comparison_value| value < comparison_value)) + context + .get_values(medrecord, attribute)? + .collect::>() } + MultipleValuesComparisonOperand::Values(values) => values.clone(), + }; + + let comparison_result = match kind { + MultipleComparisonKind::IsIn => comparison_values.contains(&value), + MultipleComparisonKind::IsNotIn => !comparison_values.contains(&value), + }; + + Ok(if comparison_result { Some(value) } else { None }) + } + + #[inline] + fn evaluate_binary_arithmetic_operation( + medrecord: &MedRecord, + value: MedRecordValue, + operand: &SingleValueComparisonOperand, + kind: &BinaryArithmeticKind, + ) -> MedRecordResult> { + let arithmetic_value = get_single_value_comparison_operand_value!(operand, medrecord); + + match kind { + BinaryArithmeticKind::Add => value.add(arithmetic_value), + BinaryArithmeticKind::Sub => value.sub(arithmetic_value), + BinaryArithmeticKind::Mul => value.mul(arithmetic_value), + BinaryArithmeticKind::Div => value.div(arithmetic_value), + BinaryArithmeticKind::Pow => value.pow(arithmetic_value), + BinaryArithmeticKind::Mod => value.r#mod(arithmetic_value), + } + .map(Some) + } + + #[inline] + fn evaluate_either_or( + medrecord: &MedRecord, + value: MedRecordValue, + either: &Wrapper, + or: &Wrapper, + ) -> MedRecordResult> { + let either_result = either.evaluate(medrecord, value.clone())?; + let or_result = or.evaluate(medrecord, value)?; + + match (either_result, or_result) { + (Some(either_result), _) => Ok(Some(either_result)), + (None, Some(or_result)) => Ok(Some(or_result)), + _ => Ok(None), } } } diff --git a/rustmodels/src/medrecord/mod.rs b/rustmodels/src/medrecord/mod.rs index eff9caf9..baa4f2f6 100644 --- a/rustmodels/src/medrecord/mod.rs +++ b/rustmodels/src/medrecord/mod.rs @@ -637,7 +637,7 @@ impl PyMedRecord { .map(|node_index| { let neighbors = self .0 - .neighbors(&node_index) + .neighbors_outgoing(&node_index) .map_err(PyMedRecordError::from)? .map(|neighbor| neighbor.clone().into()) .collect();