diff --git a/.gitignore b/.gitignore index 293baf5a..263f9073 100644 --- a/.gitignore +++ b/.gitignore @@ -13,3 +13,4 @@ target .python-version docs/api/_autosummary/* .coverage +supply-chain diff --git a/crates/medmodels-core/src/medrecord/datatypes/mod.rs b/crates/medmodels-core/src/medrecord/datatypes/mod.rs index cb8a0932..c9f7110b 100644 --- a/crates/medmodels-core/src/medrecord/datatypes/mod.rs +++ b/crates/medmodels-core/src/medrecord/datatypes/mod.rs @@ -104,9 +104,9 @@ impl Display for DataType { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { match self { DataType::String => write!(f, "String"), - DataType::Int => write!(f, "Integer"), + DataType::Int => write!(f, "Int"), DataType::Float => write!(f, "Float"), - DataType::Bool => write!(f, "Boolean"), + DataType::Bool => write!(f, "Bool"), DataType::DateTime => write!(f, "DateTime"), DataType::Duration => write!(f, "Duration"), DataType::Null => write!(f, "Null"), @@ -376,14 +376,14 @@ mod test { #[test] fn test_display() { assert_eq!("String", format!("{}", DataType::String)); - assert_eq!("Integer", format!("{}", DataType::Int)); + assert_eq!("Int", format!("{}", DataType::Int)); assert_eq!("Float", format!("{}", DataType::Float)); - assert_eq!("Boolean", format!("{}", DataType::Bool)); + assert_eq!("Bool", format!("{}", DataType::Bool)); assert_eq!("DateTime", format!("{}", DataType::DateTime)); assert_eq!("Null", format!("{}", DataType::Null)); assert_eq!("Any", format!("{}", DataType::Any)); assert_eq!( - "Union[String, Integer]", + "Union[String, Int]", format!( "{}", DataType::Union((Box::new(DataType::String), Box::new(DataType::Int))) diff --git a/crates/medmodels-core/src/medrecord/example_dataset/mod.rs b/crates/medmodels-core/src/medrecord/example_dataset/mod.rs index 4b85f40d..7d3145e3 100644 --- a/crates/medmodels-core/src/medrecord/example_dataset/mod.rs +++ b/crates/medmodels-core/src/medrecord/example_dataset/mod.rs @@ -1,4 +1,8 @@ -use super::{datatypes::DataType, AttributeType, GroupSchema, MedRecordAttribute, Schema}; +use super::{ + datatypes::DataType, + schema::{GroupSchema, Schema}, + AttributeSchema, AttributeType, MedRecordAttribute, +}; use crate::MedRecord; use polars::{ io::SerReader, @@ -8,28 +12,32 @@ use std::{collections::HashMap, io::Cursor, sync::Arc}; macro_rules! simple_dataset_schema { () => { - Schema { - groups: HashMap::from([ + Schema::new_provided( + HashMap::from([ ( "diagnosis".into(), - GroupSchema { - nodes: HashMap::from([("description".into(), DataType::String.into())]), - edges: HashMap::new(), - strict: Some(true), - }, + GroupSchema::new( + AttributeSchema::from([( + "description".into(), + (DataType::String, AttributeType::Unstructured).into(), + )]), + AttributeSchema::default(), + ), ), ( "drug".into(), - GroupSchema { - nodes: HashMap::from([("description".into(), DataType::String.into())]), - edges: HashMap::new(), - strict: Some(true), - }, + GroupSchema::new( + AttributeSchema::from([( + "description".into(), + (DataType::String, AttributeType::Unstructured).into(), + )]), + AttributeSchema::default(), + ), ), ( "patient".into(), - GroupSchema { - nodes: HashMap::from([ + GroupSchema::new( + AttributeSchema::from([ ( "gender".into(), (DataType::String, AttributeType::Categorical).into(), @@ -39,23 +47,24 @@ macro_rules! simple_dataset_schema { (DataType::Int, AttributeType::Continuous).into(), ), ]), - edges: HashMap::new(), - strict: Some(true), - }, + AttributeSchema::default(), + ), ), ( "procedure".into(), - GroupSchema { - nodes: HashMap::from([("description".into(), DataType::String.into())]), - edges: HashMap::new(), - strict: Some(true), - }, + GroupSchema::new( + AttributeSchema::from([( + "description".into(), + (DataType::String, AttributeType::Unstructured).into(), + )]), + AttributeSchema::default(), + ), ), ( "patient_diagnosis".into(), - GroupSchema { - nodes: HashMap::new(), - edges: HashMap::from([ + GroupSchema::new( + AttributeSchema::default(), + AttributeSchema::from([ ( "time".into(), (DataType::DateTime, AttributeType::Temporal).into(), @@ -69,14 +78,13 @@ macro_rules! simple_dataset_schema { .into(), ), ]), - strict: Some(true), - }, + ), ), ( "patient_drug".into(), - GroupSchema { - nodes: HashMap::new(), - edges: HashMap::from([ + GroupSchema::new( + AttributeSchema::default(), + AttributeSchema::from([ ( "time".into(), (DataType::DateTime, AttributeType::Temporal).into(), @@ -90,14 +98,13 @@ macro_rules! simple_dataset_schema { (DataType::Float, AttributeType::Continuous).into(), ), ]), - strict: Some(true), - }, + ), ), ( "patient_procedure".into(), - GroupSchema { - nodes: HashMap::new(), - edges: HashMap::from([ + GroupSchema::new( + AttributeSchema::default(), + AttributeSchema::from([ ( "time".into(), (DataType::DateTime, AttributeType::Temporal).into(), @@ -107,40 +114,42 @@ macro_rules! simple_dataset_schema { (DataType::Float, AttributeType::Continuous).into(), ), ]), - strict: Some(true), - }, + ), ), ]), - default: None, - strict: Some(true), - } + GroupSchema::new(Default::default(), Default::default()), + ) }; } macro_rules! advanced_dataset_schema { () => { - Schema { - groups: HashMap::from([ + Schema::new_provided( + HashMap::from([ ( "diagnosis".into(), - GroupSchema { - nodes: HashMap::from([("description".into(), DataType::String.into())]), - edges: HashMap::new(), - strict: Some(true), - }, + GroupSchema::new( + AttributeSchema::from([( + "description".into(), + (DataType::String, AttributeType::Unstructured).into(), + )]), + AttributeSchema::default(), + ), ), ( "drug".into(), - GroupSchema { - nodes: HashMap::from([("description".into(), DataType::String.into())]), - edges: HashMap::new(), - strict: Some(true), - }, + GroupSchema::new( + AttributeSchema::from([( + "description".into(), + (DataType::String, AttributeType::Unstructured).into(), + )]), + AttributeSchema::default(), + ), ), ( "patient".into(), - GroupSchema { - nodes: HashMap::from([ + GroupSchema::new( + AttributeSchema::from([ ( "gender".into(), (DataType::String, AttributeType::Categorical).into(), @@ -150,31 +159,28 @@ macro_rules! advanced_dataset_schema { (DataType::Int, AttributeType::Continuous).into(), ), ]), - edges: HashMap::new(), - strict: Some(true), - }, + AttributeSchema::default(), + ), ), ( "procedure".into(), - GroupSchema { - nodes: HashMap::from([("description".into(), DataType::String.into())]), - edges: HashMap::new(), - strict: Some(true), - }, + GroupSchema::new( + AttributeSchema::from([( + "description".into(), + (DataType::String, AttributeType::Unstructured).into(), + )]), + AttributeSchema::default(), + ), ), ( "event".into(), - GroupSchema { - nodes: HashMap::new(), - edges: HashMap::new(), - strict: Some(true), - }, + GroupSchema::new(AttributeSchema::default(), AttributeSchema::default()), ), ( "patient_diagnosis".into(), - GroupSchema { - nodes: HashMap::new(), - edges: HashMap::from([ + GroupSchema::new( + AttributeSchema::default(), + AttributeSchema::from([ ( "time".into(), (DataType::DateTime, AttributeType::Temporal).into(), @@ -188,14 +194,13 @@ macro_rules! advanced_dataset_schema { .into(), ), ]), - strict: Some(true), - }, + ), ), ( "patient_drug".into(), - GroupSchema { - nodes: HashMap::new(), - edges: HashMap::from([ + GroupSchema::new( + AttributeSchema::default(), + AttributeSchema::from([ ( "time".into(), (DataType::DateTime, AttributeType::Temporal).into(), @@ -209,14 +214,13 @@ macro_rules! advanced_dataset_schema { (DataType::Float, AttributeType::Continuous).into(), ), ]), - strict: Some(true), - }, + ), ), ( "patient_procedure".into(), - GroupSchema { - nodes: HashMap::new(), - edges: HashMap::from([ + GroupSchema::new( + AttributeSchema::default(), + AttributeSchema::from([ ( "time".into(), (DataType::DateTime, AttributeType::Temporal).into(), @@ -226,24 +230,21 @@ macro_rules! advanced_dataset_schema { (DataType::Float, AttributeType::Continuous).into(), ), ]), - strict: Some(true), - }, + ), ), ( "patient_event".into(), - GroupSchema { - nodes: HashMap::new(), - edges: HashMap::from([( + GroupSchema::new( + AttributeSchema::default(), + AttributeSchema::from([( "time".into(), (DataType::DateTime, AttributeType::Temporal).into(), )]), - strict: Some(true), - }, + ), ), ]), - default: None, - strict: Some(true), - } + GroupSchema::new(Default::default(), Default::default()), + ) }; } @@ -460,7 +461,7 @@ impl MedRecord { ) .expect("Group can be added"); - medrecord.schema = simple_dataset_schema!(); + unsafe { medrecord.update_schema_unchecked(&mut simple_dataset_schema!()) }; medrecord } @@ -695,7 +696,7 @@ impl MedRecord { .add_group("patient_event".into(), None, Some(patient_event_ids)) .expect("Group can be added"); - medrecord.schema = advanced_dataset_schema!(); + unsafe { medrecord.update_schema_unchecked(&mut advanced_dataset_schema!()) }; medrecord } @@ -703,8 +704,14 @@ impl MedRecord { #[cfg(test)] mod test { - use super::{AttributeType, DataType, GroupSchema, Schema}; - use crate::MedRecord; + use super::{AttributeType, DataType}; + use crate::{ + medrecord::{ + schema::{GroupSchema, Schema}, + AttributeSchema, + }, + MedRecord, + }; use std::collections::HashMap; #[test] diff --git a/crates/medmodels-core/src/medrecord/group_mapping.rs b/crates/medmodels-core/src/medrecord/group_mapping.rs index fd0de078..4d847e00 100644 --- a/crates/medmodels-core/src/medrecord/group_mapping.rs +++ b/crates/medmodels-core/src/medrecord/group_mapping.rs @@ -7,10 +7,10 @@ pub type Group = MedRecordAttribute; #[derive(Debug, Serialize, Deserialize, Clone)] pub(super) struct GroupMapping { - nodes_in_group: MrHashMap>, - edges_in_group: MrHashMap>, - groups_of_node: MrHashMap>, - groups_of_edge: MrHashMap>, + pub(super) nodes_in_group: MrHashMap>, + pub(super) edges_in_group: MrHashMap>, + pub(super) groups_of_node: MrHashMap>, + pub(super) groups_of_edge: MrHashMap>, } impl GroupMapping { diff --git a/crates/medmodels-core/src/medrecord/mod.rs b/crates/medmodels-core/src/medrecord/mod.rs index 996b87c0..2e4d4ea8 100644 --- a/crates/medmodels-core/src/medrecord/mod.rs +++ b/crates/medmodels-core/src/medrecord/mod.rs @@ -30,7 +30,7 @@ pub use self::{ }, wrapper::{CardinalityWrapper, Wrapper}, }, - schema::{AttributeDataType, AttributeType, GroupSchema, Schema}, + schema::{AttributeDataType, AttributeSchema, AttributeType, GroupSchema, Schema, SchemaType}, }; use crate::errors::MedRecordError; use ::polars::frame::DataFrame; @@ -39,7 +39,11 @@ use group_mapping::GroupMapping; use polars::{dataframe_to_edges, dataframe_to_nodes}; use querying::{edges::EdgeSelection, nodes::NodeSelection}; use serde::{Deserialize, Serialize}; -use std::{fs, mem, path::Path}; +use std::{ + collections::{hash_map::Entry, HashMap}, + fs, mem, + path::Path, +}; pub struct NodeDataFrameInput { dataframe: DataFrame, @@ -141,7 +145,7 @@ impl MedRecord { Self { graph: Graph::new(), group_mapping: GroupMapping::new(), - schema: Schema::default(), + schema: Default::default(), } } @@ -241,78 +245,144 @@ impl MedRecord { }) } - pub fn update_schema(&mut self, schema: Schema) -> Result<(), MedRecordError> { - let mut old_schema = schema; - - mem::swap(&mut self.schema, &mut old_schema); - - let result = self - .graph - .nodes - .iter() - .map(|(node_index, node)| { - let groups_of_node = self - .groups_of_node(node_index) - .expect("groups of node must exist") - .collect::>(); - - if !groups_of_node.is_empty() { - for group in groups_of_node { - self.schema - .validate_node(node_index, &node.attributes, Some(group))?; + pub fn update_schema(&mut self, mut schema: Schema) -> Result<(), MedRecordError> { + let mut nodes_group_cache = HashMap::<&Group, usize>::new(); + let mut nodes_default_visited = false; + let mut edges_group_cache = HashMap::<&Group, usize>::new(); + let mut edges_default_visited = false; + + for (node_index, node) in self.graph.nodes.iter() { + let groups_of_node = self + .groups_of_node(node_index) + .expect("groups of node must exist") + .collect::>(); + + if !groups_of_node.is_empty() { + for group in groups_of_node { + match schema.schema_type() { + SchemaType::Inferred => match nodes_group_cache.entry(group) { + Entry::Occupied(entry) => { + schema.update_node( + &node.attributes, + Some(group), + *entry.get() == 0, + ); + } + Entry::Vacant(entry) => { + entry.insert( + self.group_mapping + .nodes_in_group + .get(group) + .map(|nodes| nodes.len()) + .unwrap_or(0), + ); + } + }, + SchemaType::Provided => { + schema.validate_node(node_index, &node.attributes, Some(group))? + } } - } else { - self.schema - .validate_node(node_index, &node.attributes, None)?; } + } else { + match schema.schema_type() { + SchemaType::Inferred => { + let nodes_in_groups = self.group_mapping.nodes_in_group.len(); - Ok(()) - }) - .collect::, MedRecordError>>(); + let nodes_not_in_groups = self.graph.node_count() - nodes_in_groups; - if let Err(error) = result { - self.schema = old_schema; + schema.update_node( + &node.attributes, + None, + nodes_not_in_groups == 0 || !nodes_default_visited, + ); - return Err(error); + nodes_default_visited = true; + } + SchemaType::Provided => { + schema.validate_node(node_index, &node.attributes, None)?; + } + } + } } - let result = self - .graph - .edges - .iter() - .map(|(edge_index, edge)| { - let groups_of_edge = self - .groups_of_edge(edge_index) - .expect("groups of edge must exist") - .collect::>(); - - if !groups_of_edge.is_empty() { - for group in groups_of_edge { - self.schema - .validate_edge(edge_index, &edge.attributes, Some(group))?; + for (edge_index, edge) in self.graph.edges.iter() { + let groups_of_edge = self + .groups_of_edge(edge_index) + .expect("groups of edge must exist") + .collect::>(); + + if !groups_of_edge.is_empty() { + for group in groups_of_edge { + match schema.schema_type() { + SchemaType::Inferred => match edges_group_cache.entry(group) { + Entry::Occupied(entry) => { + schema.update_edge( + &edge.attributes, + Some(group), + *entry.get() == 0, + ); + } + Entry::Vacant(entry) => { + entry.insert( + self.group_mapping + .edges_in_group + .get(group) + .map(|edges| edges.len()) + .unwrap_or(0), + ); + } + }, + SchemaType::Provided => { + schema.validate_edge(edge_index, &edge.attributes, Some(group))?; + } } - } else { - self.schema - .validate_edge(edge_index, &edge.attributes, None)?; } + } else { + match schema.schema_type() { + SchemaType::Inferred => { + let edges_in_groups = self.group_mapping.edges_in_group.len(); - Ok(()) - }) - .collect::, MedRecordError>>(); + let edges_not_in_groups = self.graph.edge_count() - edges_in_groups; - if let Err(error) = result { - self.schema = old_schema; + schema.update_edge( + &edge.attributes, + None, + edges_not_in_groups == 0 || !edges_default_visited, + ); - return Err(error); + edges_default_visited = true; + } + SchemaType::Provided => { + schema.validate_edge(edge_index, &edge.attributes, None)?; + } + } + } } + mem::swap(&mut self.schema, &mut schema); + Ok(()) } - pub fn get_schema(&self) -> &Schema { + /// # Safety + /// + /// This function should only be used if the data has been validated against the schema. + pub unsafe fn update_schema_unchecked(&mut self, schema: &mut Schema) { + mem::swap(&mut self.schema, schema); + } + + pub fn schema(&self) -> &Schema { &self.schema } + pub fn freeze_schema(&mut self) { + self.schema.freeze(); + } + + pub fn unfreeze_schema(&mut self) { + self.schema.unfreeze(); + } + pub fn node_indices(&self) -> impl Iterator { self.graph.node_indices() } @@ -401,7 +471,19 @@ impl MedRecord { node_index: NodeIndex, attributes: Attributes, ) -> Result<(), MedRecordError> { - self.schema.validate_node(&node_index, &attributes, None)?; + match self.schema.schema_type() { + SchemaType::Inferred => { + let nodes_in_groups = self.group_mapping.nodes_in_group.len(); + + let nodes_not_in_groups = self.graph.node_count() - nodes_in_groups; + + self.schema + .update_node(&attributes, None, nodes_not_in_groups == 0); + } + SchemaType::Provided => { + self.schema.validate_node(&node_index, &attributes, None)?; + } + } self.graph .add_node(node_index, attributes) @@ -454,14 +536,28 @@ impl MedRecord { .add_edge(source_node_index, target_node_index, attributes.to_owned()) .map_err(MedRecordError::from)?; - match self.schema.validate_edge(&edge_index, &attributes, None) { - Ok(()) => Ok(edge_index), - Err(e) => { - self.graph - .remove_edge(&edge_index) - .expect("Edge must exist"); + match self.schema.schema_type() { + SchemaType::Inferred => { + let edges_in_groups = self.group_mapping.edges_in_group.len(); + + let edges_not_in_groups = self.graph.edge_count() - edges_in_groups; - Err(e.into()) + self.schema + .update_edge(&attributes, None, edges_not_in_groups == 0); + + Ok(edge_index) + } + SchemaType::Provided => { + match self.schema.validate_edge(&edge_index, &attributes, None) { + Ok(()) => Ok(edge_index), + Err(e) => { + self.graph + .remove_edge(&edge_index) + .expect("Edge must exist"); + + Err(e.into()) + } + } } } } @@ -587,8 +683,23 @@ impl MedRecord { ) -> Result<(), MedRecordError> { let node_attributes = self.graph.node_attributes(&node_index)?; - self.schema - .validate_node(&node_index, node_attributes, Some(&group))?; + match self.schema.schema_type() { + SchemaType::Inferred => { + let nodes_in_group = self + .group_mapping + .nodes_in_group + .get(&group) + .map(|nodes| nodes.len()) + .unwrap_or(0); + + self.schema + .update_node(node_attributes, Some(&group), nodes_in_group == 0); + } + SchemaType::Provided => { + self.schema + .validate_node(&node_index, node_attributes, Some(&group))?; + } + } self.group_mapping.add_node_to_group(group, node_index) } @@ -600,8 +711,23 @@ impl MedRecord { ) -> Result<(), MedRecordError> { let edge_attributes = self.graph.edge_attributes(&edge_index)?; - self.schema - .validate_edge(&edge_index, edge_attributes, Some(&group))?; + match self.schema.schema_type() { + SchemaType::Inferred => { + let edges_in_group = self + .group_mapping + .edges_in_group + .get(&group) + .map(|edges| edges.len()) + .unwrap_or(0); + + self.schema + .update_edge(edge_attributes, Some(&group), edges_in_group == 0); + } + SchemaType::Provided => { + self.schema + .validate_edge(&edge_index, edge_attributes, Some(&group))?; + } + } self.group_mapping.add_edge_to_group(group, edge_index) } @@ -762,10 +888,11 @@ impl Default for MedRecord { #[cfg(test)] mod test { - use super::{ - Attributes, DataType, GroupSchema, MedRecord, MedRecordAttribute, NodeIndex, Schema, + use super::{Attributes, DataType, MedRecord, MedRecordAttribute, NodeIndex}; + use crate::{ + errors::MedRecordError, + medrecord::schema::{AttributeSchema, GroupSchema, Schema}, }; - use crate::errors::MedRecordError; use polars::prelude::{DataFrame, NamedFrom, PolarsError, Series}; use std::{collections::HashMap, fs}; @@ -834,116 +961,6 @@ mod test { MedRecord::from_tuples(nodes, Some(edges), None).unwrap() } - #[test] - fn test_schema() { - let schema = Schema { - groups: HashMap::from([( - "group".into(), - GroupSchema { - nodes: HashMap::from([("attribute2".into(), DataType::Int.into())]), - edges: HashMap::from([("attribute2".into(), DataType::Int.into())]), - strict: None, - }, - )]), - default: Some(GroupSchema { - nodes: HashMap::from([("attribute".into(), DataType::Int.into())]), - edges: HashMap::from([("attribute".into(), DataType::Int.into())]), - strict: None, - }), - strict: None, - }; - - let mut medrecord = MedRecord::with_schema(schema.clone()); - medrecord.add_group("group".into(), None, None).unwrap(); - - assert_eq!(schema, *medrecord.get_schema()); - - assert!(medrecord - .add_node("0".into(), HashMap::from([("attribute".into(), 1.into())])) - .is_ok()); - - assert!(medrecord - .add_node( - "1".into(), - HashMap::from([("attribute".into(), "1".into())]) - ) - .is_err_and(|e| matches!(e, MedRecordError::SchemaError(_)))); - - medrecord - .add_node( - "1".into(), - HashMap::from([ - ("attribute".into(), 1.into()), - ("attribute2".into(), 1.into()), - ]), - ) - .unwrap(); - - assert!(medrecord - .add_node_to_group("group".into(), "1".into()) - .is_ok()); - - medrecord - .add_node( - "2".into(), - HashMap::from([ - ("attribute".into(), 1.into()), - ("attribute2".into(), "1".into()), - ]), - ) - .unwrap(); - - assert!(medrecord - .add_node_to_group("group".into(), "2".into()) - .is_err_and(|e| { matches!(e, MedRecordError::SchemaError(_)) })); - - assert!(medrecord - .add_edge( - "0".into(), - "1".into(), - HashMap::from([("attribute".into(), 1.into())]) - ) - .is_ok()); - - assert!(medrecord - .add_edge( - "0".into(), - "1".into(), - HashMap::from([("attribute".into(), "1".into())]) - ) - .is_err_and(|e| matches!(e, MedRecordError::SchemaError(_)))); - - let edge_index = medrecord - .add_edge( - "0".into(), - "1".into(), - HashMap::from([ - ("attribute".into(), 1.into()), - ("attribute2".into(), 1.into()), - ]), - ) - .unwrap(); - - assert!(medrecord - .add_edge_to_group("group".into(), edge_index) - .is_ok()); - - let edge_index = medrecord - .add_edge( - "0".into(), - "1".into(), - HashMap::from([ - ("attribute".into(), 1.into()), - ("attribute2".into(), "1".into()), - ]), - ) - .unwrap(); - - assert!(medrecord - .add_edge_to_group("group".into(), edge_index) - .is_err_and(|e| { matches!(e, MedRecordError::SchemaError(_)) })); - } - #[test] fn test_from_tuples() { let medrecord = create_medrecord(); @@ -1037,19 +1054,17 @@ mod test { ) .unwrap(); - let schema = Schema { - groups: Default::default(), - default: Some(GroupSchema { - nodes: HashMap::from([("attribute".into(), DataType::Int.into())]), - edges: HashMap::from([("attribute".into(), DataType::Int.into())]), - strict: None, - }), - strict: None, - }; + let schema = Schema::new_provided( + Default::default(), + GroupSchema::new( + AttributeSchema::from([("attribute".into(), DataType::Int.into())]), + AttributeSchema::from([("attribute".into(), DataType::Int.into())]), + ), + ); assert!(medrecord.update_schema(schema.clone()).is_ok()); - assert_eq!(schema, *medrecord.get_schema()); + assert_eq!(schema, *medrecord.schema()); } #[test] @@ -1060,21 +1075,21 @@ mod test { .add_node("0".into(), HashMap::from([("attribute2".into(), 1.into())])) .unwrap(); - let schema = Schema { - groups: Default::default(), - default: Some(GroupSchema { - nodes: HashMap::from([("attribute".into(), DataType::Int.into())]), - edges: HashMap::from([("attribute".into(), DataType::Int.into())]), - strict: None, - }), - strict: None, - }; + let schema = Schema::new_provided( + Default::default(), + GroupSchema::new( + AttributeSchema::from([("attribute".into(), DataType::Int.into())]), + AttributeSchema::from([("attribute".into(), DataType::Int.into())]), + ), + ); + + let previous_schema = medrecord.schema().clone(); assert!(medrecord .update_schema(schema.clone()) .is_err_and(|e| { matches!(e, MedRecordError::SchemaError(_)) })); - assert_eq!(Schema::default(), *medrecord.get_schema()); + assert_eq!(previous_schema, *medrecord.schema()); let mut medrecord = MedRecord::new(); @@ -1092,11 +1107,13 @@ mod test { ) .unwrap(); + let previous_schema = medrecord.schema().clone(); + assert!(medrecord .update_schema(schema.clone()) .is_err_and(|e| { matches!(e, MedRecordError::SchemaError(_)) })); - assert_eq!(Schema::default(), *medrecord.get_schema()); + assert_eq!(previous_schema, *medrecord.schema()); } #[test] diff --git a/crates/medmodels-core/src/medrecord/schema.rs b/crates/medmodels-core/src/medrecord/schema.rs index 8015870e..9d0f4839 100644 --- a/crates/medmodels-core/src/medrecord/schema.rs +++ b/crates/medmodels-core/src/medrecord/schema.rs @@ -1,38 +1,145 @@ -use super::{Attributes, EdgeIndex, NodeIndex}; +use super::{Attributes, EdgeIndex, Group, MedRecord, NodeIndex}; use crate::{ errors::GraphError, - medrecord::{datatypes::DataType, Group, MedRecordAttribute}, + medrecord::{datatypes::DataType, MedRecordAttribute}, }; +use medmodels_utils::aliases::MrHashMap; use serde::{Deserialize, Serialize}; -use std::collections::{HashMap, HashSet}; +use std::{ + collections::{hash_map::Entry, HashMap}, + ops::Deref, +}; -#[derive(Serialize, Deserialize, Debug, Clone, PartialEq)] +#[derive(Serialize, Deserialize, Debug, Clone, Copy, PartialEq)] pub enum AttributeType { Categorical, Continuous, Temporal, + Unstructured, +} + +impl AttributeType { + pub fn infer(data_type: &DataType) -> Self { + match data_type { + DataType::String => Self::Unstructured, + DataType::Int => Self::Continuous, + DataType::Float => Self::Continuous, + DataType::Bool => Self::Categorical, + DataType::DateTime => Self::Temporal, + DataType::Duration => Self::Continuous, + DataType::Null => Self::Unstructured, + DataType::Any => Self::Unstructured, + DataType::Union((first_dataype, second_dataype)) => { + Self::infer(first_dataype).merge(&Self::infer(second_dataype)) + } + DataType::Option(dataype) => Self::infer(dataype), + } + } + + fn merge(&self, other: &Self) -> Self { + match (self, other) { + (Self::Categorical, Self::Categorical) => Self::Categorical, + (Self::Continuous, Self::Continuous) => Self::Continuous, + (Self::Temporal, Self::Temporal) => Self::Temporal, + _ => Self::Unstructured, + } + } +} + +impl DataType { + fn merge(&self, other: &Self) -> Self { + if self.evaluate(other) { + self.clone() + } else { + match (self, other) { + (Self::Null, _) => Self::Option(Box::new(other.clone())), + (_, Self::Null) => Self::Option(Box::new(self.clone())), + (_, Self::Any) => Self::Any, + (Self::Any, _) => Self::Any, + (Self::Option(option1), Self::Option(option2)) => { + Self::Option(Box::new(option1.merge(option2))) + } + (Self::Option(option), _) => Self::Option(Box::new(option.merge(other))), + (_, Self::Option(option)) => Self::Option(Box::new(self.merge(option))), + _ => Self::Union((Box::new(self.clone()), Box::new(other.clone()))), + } + } + } } #[derive(Serialize, Deserialize, Debug, Clone, PartialEq)] pub struct AttributeDataType { - pub data_type: DataType, - pub attribute_type: Option, + data_type: DataType, + attribute_type: AttributeType, } impl AttributeDataType { - pub fn new(data_type: DataType, attribute_type: Option) -> Self { - Self { + fn validate(data_type: &DataType, attribute_type: &AttributeType) -> Result<(), GraphError> { + match (attribute_type, data_type) { + (AttributeType::Categorical, _) => Ok(()), + (AttributeType::Unstructured, _) => Ok(()), + + (_, DataType::Option(option)) => Self::validate(option, attribute_type), + (_, DataType::Union((first_datatype, second_datatype))) => { + Self::validate(first_datatype, attribute_type)?; + Self::validate(second_datatype, attribute_type) + } + + (AttributeType::Continuous, DataType::Int | DataType::Float | DataType::Null) => Ok(()), + (AttributeType::Continuous, _) => Err(GraphError::SchemaError( + "Continuous attribute must be of (sub-)type Int or Float.".to_string(), + )), + + (AttributeType::Temporal, DataType::DateTime | DataType::Duration | DataType::Null) => { + Ok(()) + } + (AttributeType::Temporal, _) => Err(GraphError::SchemaError( + "Temporal attribute must be of (sub-)type DateTime or Duration.".to_string(), + )), + } + } + + pub fn new(data_type: DataType, attribute_type: AttributeType) -> Result { + Self::validate(&data_type, &attribute_type)?; + + Ok(Self { data_type, attribute_type, + }) + } + + pub fn data_type(&self) -> &DataType { + &self.data_type + } + + pub fn attribute_type(&self) -> &AttributeType { + &self.attribute_type + } + + fn merge(&mut self, other: &Self) { + match (self.data_type.clone(), other.data_type.clone()) { + (DataType::Null, _) => { + self.data_type = self.data_type.merge(&other.data_type); + self.attribute_type = other.attribute_type; + } + (_, DataType::Null) => { + self.data_type = self.data_type.merge(&other.data_type); + } + _ => { + self.data_type = self.data_type.merge(&other.data_type); + self.attribute_type = self.attribute_type.merge(&other.attribute_type); + } } } } impl From for AttributeDataType { fn from(value: DataType) -> Self { + let attribute_type = AttributeType::infer(&value); + Self { data_type: value, - attribute_type: None, + attribute_type, } } } @@ -41,169 +148,378 @@ impl From<(DataType, AttributeType)> for AttributeDataType { fn from(value: (DataType, AttributeType)) -> Self { Self { data_type: value.0, - attribute_type: Some(value.1), + attribute_type: value.1, } } } -type AttributeSchema = HashMap; +enum AttributeSchemaKind<'a> { + Node(&'a NodeIndex), + Edge(&'a EdgeIndex), +} -#[derive(Serialize, Deserialize, Debug, Clone, PartialEq)] -pub struct GroupSchema { - pub nodes: AttributeSchema, - pub edges: AttributeSchema, - pub strict: Option, +impl AttributeSchemaKind<'_> { + fn error_message(&self, key: &MedRecordAttribute, data_type: &DataType) -> String { + match self { + Self::Node(index) => format!( + "Attribute {} of type {} not found on node with index {}", + key, data_type, index + ), + Self::Edge(index) => format!( + "Attribute {} of type {} not found on edge with index {}", + key, data_type, index + ), + } + } + + fn error_message_expected( + &self, + key: &MedRecordAttribute, + data_type: &DataType, + expected_data_type: &DataType, + ) -> String { + match self { + Self::Node(index) => format!( + "Attribute {} of node with index {} is of type {}. Expected {}.", + key, index, data_type, expected_data_type + ), + Self::Edge(index) => format!( + "Attribute {} of node with index {} is of type {}. Expected {}.", + key, index, data_type, expected_data_type + ), + } + } + + fn error_message_too_many(&self, attributes: Vec) -> String { + match self { + Self::Node(index) => format!( + "Attributes [{}] of node with index {} do not exist in schema.", + attributes.join(", "), + index + ), + Self::Edge(index) => format!( + "Attributes [{}] of edge with index {} do not exist in schema.", + attributes.join(", "), + index + ), + } + } } -#[derive(Serialize, Deserialize, Debug, Clone, PartialEq)] -pub struct Schema { - pub groups: HashMap, - pub default: Option, - pub strict: Option, +type AttributeSchemaMapping = HashMap; + +#[derive(Serialize, Deserialize, Debug, Clone, PartialEq, Default)] +pub struct AttributeSchema(AttributeSchemaMapping); + +impl Deref for AttributeSchema { + type Target = AttributeSchemaMapping; + + fn deref(&self) -> &Self::Target { + &self.0 + } } -impl GroupSchema { - pub fn validate_node<'a>( +impl From for AttributeSchema +where + T: Into, +{ + fn from(value: T) -> Self { + Self(value.into()) + } +} + +impl AttributeSchema { + pub fn new(mapping: HashMap) -> Self { + Self(mapping) + } + + fn validate( &self, - index: &'a NodeIndex, - attributes: &'a Attributes, - strict: bool, + attributes: &Attributes, + kind: AttributeSchemaKind, ) -> Result<(), GraphError> { - for (key, schema) in &self.nodes { - let value = attributes.get(key).ok_or(GraphError::SchemaError(format!( - "Attribute {} of type {} not found on node with index {}", - key, schema.data_type, index - )))?; + for (key, schema) in &self.0 { + let value = attributes.get(key).ok_or(GraphError::SchemaError( + kind.error_message(key, &schema.data_type), + ))?; let data_type = DataType::from(value); if !schema.data_type.evaluate(&data_type) { - return Err(GraphError::SchemaError(format!( - "Attribute {} of node with index {} is of type {}. Expected {}.", - key, index, data_type, schema.data_type + return Err(GraphError::SchemaError(kind.error_message_expected( + key, + &data_type, + &schema.data_type, ))); } } - if self.strict.unwrap_or(strict) { - let attributes = attributes.keys().collect::>(); - let schema_attributes = self.nodes.keys().collect::>(); - let attributes_not_in_schema = attributes - .difference(&schema_attributes) - .map(|attribute| attribute.to_string()) - .collect::>(); - - match attributes_not_in_schema.len() { - 0 => (), - 1 => { - let attribute_not_in_schema = attributes_not_in_schema - .first() - .expect("Attribute must exist."); - - return Err(GraphError::SchemaError(format!( - "Attribute {} of node with index {} does not exist in strict schema.", - attribute_not_in_schema, index - ))); - } - _ => { - return Err(GraphError::SchemaError(format!( - "Attributes {} of node with index {} do not exist in strict schema.", - attributes_not_in_schema.join(", "), - index - ))); - } + let attributes_not_in_schema = attributes + .keys() + .filter(|attribute| !self.0.contains_key(*attribute)) + .map(|attribute| attribute.to_string()) + .collect::>(); + + match attributes_not_in_schema.len() { + 0 => (), + _ => { + return Err(GraphError::SchemaError( + kind.error_message_too_many(attributes_not_in_schema), + )); } } Ok(()) } - pub fn validate_edge<'a>( - &self, - index: &'a EdgeIndex, - attributes: &'a Attributes, - strict: bool, - ) -> Result<(), GraphError> { - for (key, schema) in &self.edges { - let value = attributes.get(key).ok_or(GraphError::SchemaError(format!( - "Attribute {} of type {} not found on edge with index {}", - key, schema.data_type, index - )))?; + fn update(&mut self, attributes: &Attributes, empty: bool) { + for (attribute, data_type) in self.0.iter_mut() { + if !attributes.contains_key(attribute) { + data_type.data_type = data_type.data_type.merge(&DataType::Null); + } + } + for (attribute, value) in attributes { let data_type = DataType::from(value); + let attribute_type = AttributeType::infer(&data_type); - if !schema.data_type.evaluate(&data_type) { - return Err(GraphError::SchemaError(format!( - "Attribute {} of edge with index {} is of type {}. Expected {}.", - key, index, data_type, schema.data_type - ))); - } - } + let mut attribute_data_type = AttributeDataType::new(data_type, attribute_type) + .expect("AttributeType was infered from DataType."); - if self.strict.unwrap_or(strict) { - let attributes = attributes.keys().collect::>(); - let schema_attributes = self.edges.keys().collect::>(); - let attributes_not_in_schema = attributes - .difference(&schema_attributes) - .map(|attribute| attribute.to_string()) - .collect::>(); - - match attributes_not_in_schema.len() { - 0 => (), - 1 => { - let attribute_not_in_schema = attributes_not_in_schema - .first() - .expect("Attribute must exist."); - - return Err(GraphError::SchemaError(format!( - "Attribute {} of edge with index {} does not exist in strict schema.", - attribute_not_in_schema, index - ))); + match self.0.entry(attribute.clone()) { + Entry::Occupied(entry) => { + entry.into_mut().merge(&attribute_data_type); } - _ => { - return Err(GraphError::SchemaError(format!( - "Attributes {} of edge with index {} do not exist in strict schema.", - attributes_not_in_schema.join(", "), - index - ))); + Entry::Vacant(entry) => { + if !empty { + attribute_data_type.data_type = + attribute_data_type.data_type.merge(&DataType::Null); + } + + entry.insert(attribute_data_type); } } } + } - Ok(()) + fn infer(attributes: Vec<&Attributes>) -> Self { + let mut schema = Self::default(); + + let mut empty = true; + + for attributes in attributes { + schema.update(attributes, empty); + + empty = false; + } + + schema } } -impl Schema { +#[derive(Serialize, Deserialize, Debug, Clone, PartialEq, Default)] +pub struct GroupSchema { + nodes: AttributeSchema, + edges: AttributeSchema, +} + +impl GroupSchema { + pub fn new(nodes: AttributeSchema, edges: AttributeSchema) -> Self { + Self { nodes, edges } + } + + pub fn nodes(&self) -> &AttributeSchemaMapping { + &self.nodes.0 + } + + pub fn edges(&self) -> &AttributeSchemaMapping { + &self.edges.0 + } + pub fn validate_node<'a>( &self, index: &'a NodeIndex, attributes: &'a Attributes, - group: Option<&'a Group>, ) -> Result<(), GraphError> { - let group_schema = group.and_then(|group| self.groups.get(group)); + self.nodes + .validate(attributes, AttributeSchemaKind::Node(index)) + } + + pub fn validate_edge<'a>( + &self, + index: &'a EdgeIndex, + attributes: &'a Attributes, + ) -> Result<(), GraphError> { + self.edges + .validate(attributes, AttributeSchemaKind::Edge(index)) + } - match (group_schema, &self.default, self.strict) { - (Some(group_schema), _, Some(true)) => { - group_schema.validate_node(index, attributes, true)?; + pub(crate) fn infer(nodes: Vec<&Attributes>, edges: Vec<&Attributes>) -> Self { + Self { + nodes: AttributeSchema::infer(nodes), + edges: AttributeSchema::infer(edges), + } + } - Ok(()) + pub(crate) fn update_node(&mut self, attributes: &Attributes, empty: bool) { + self.nodes.update(attributes, empty); + } + + pub(crate) fn update_edge(&mut self, attributes: &Attributes, empty: bool) { + self.edges.update(attributes, empty); + } +} + +#[derive(Serialize, Deserialize, Debug, Clone, PartialEq)] +pub enum SchemaType { + Inferred, + Provided, +} + +impl Default for SchemaType { + fn default() -> Self { + Self::Inferred + } +} + +#[derive(Serialize, Deserialize, Debug, Clone, PartialEq, Default)] +pub struct Schema { + groups: HashMap, + default: GroupSchema, + schema_type: SchemaType, +} + +impl Schema { + pub fn new_inferred(groups: HashMap, default: GroupSchema) -> Self { + Self { + groups, + default, + schema_type: SchemaType::Inferred, + } + } + + pub fn new_provided(groups: HashMap, default: GroupSchema) -> Self { + Self { + groups, + default, + schema_type: SchemaType::Provided, + } + } + + pub fn infer(medrecord: &MedRecord) -> Self { + let mut group_mapping = medrecord + .groups() + .map(|group| (group, (Vec::new(), Vec::new()))) + .collect::>(); + + let mut default_group = (Vec::new(), Vec::new()); + + for node_index in medrecord.node_indices() { + let mut groups_of_node = medrecord + .groups_of_node(node_index) + .expect("Node must exist.") + .peekable(); + + if groups_of_node.peek().is_none() { + default_group.0.push(node_index); + continue; } - (Some(group_schema), _, _) => { - group_schema.validate_node(index, attributes, false)?; - Ok(()) + for group in groups_of_node { + let group_nodes = &mut group_mapping.get_mut(&group).expect("Group must exist.").0; + + group_nodes.push(node_index); } - (_, Some(defalt_schema), Some(true)) => { - defalt_schema.validate_node(index, attributes, true) + } + + for edge_index in medrecord.edge_indices() { + let mut groups_of_edge = medrecord + .groups_of_edge(edge_index) + .expect("Edge must exist.") + .peekable(); + + if groups_of_edge.peek().is_none() { + default_group.1.push(edge_index); + continue; } - (_, Some(default_schema), _) => default_schema.validate_node(index, attributes, false), - (None, None, None) | (None, None, Some(false)) => Ok(()), - _ => Err(GraphError::SchemaError(format!( - "No schema provided for node {} wit no group", - index - ))), + for group in groups_of_edge { + let group_edges = &mut group_mapping.get_mut(&group).expect("Group must exist.").1; + + group_edges.push(edge_index); + } + } + + let group_schemas = + group_mapping + .into_iter() + .map(|(group, (nodes_in_group, edges_in_group))| { + let node_attributes = nodes_in_group + .into_iter() + .map(|node| medrecord.node_attributes(node).expect("Node must exist.")) + .collect::>(); + let edge_attributes = edges_in_group + .into_iter() + .map(|edge| medrecord.edge_attributes(edge).expect("Edge must exist.")) + .collect::>(); + + let schema = GroupSchema::infer(node_attributes, edge_attributes); + + (group.clone(), schema) + }); + + let default_schema = GroupSchema::infer( + default_group + .0 + .into_iter() + .map(|node| medrecord.node_attributes(node).expect("Node must exist.")) + .collect::>(), + default_group + .1 + .into_iter() + .map(|edge| medrecord.edge_attributes(edge).expect("Edge must exist.")) + .collect::>(), + ); + + Self { + groups: group_schemas.collect(), + default: default_schema, + schema_type: SchemaType::Inferred, + } + } + + pub fn groups(&self) -> &HashMap { + &self.groups + } + + pub fn group(&self, group: &Group) -> Result<&GroupSchema, GraphError> { + self.groups + .get(group) + .ok_or(GraphError::SchemaError(format!( + "Group {} not found in schema.", + group + ))) + } + + pub fn default(&self) -> &GroupSchema { + &self.default + } + + pub fn schema_type(&self) -> &SchemaType { + &self.schema_type + } + + pub fn validate_node<'a>( + &self, + index: &'a NodeIndex, + attributes: &'a Attributes, + group: Option<&'a Group>, + ) -> Result<(), GraphError> { + let group_schema = group.and_then(|group| self.groups.get(group)); + + match group_schema { + Some(group_schema) => group_schema.validate_node(index, attributes), + None => self.default.validate_node(index, attributes), } } @@ -215,413 +531,1221 @@ impl Schema { ) -> Result<(), GraphError> { let group_schema = group.and_then(|group| self.groups.get(group)); - match (group_schema, &self.default, self.strict) { - (Some(group_schema), _, Some(true)) => { - group_schema.validate_edge(index, attributes, true)?; + match group_schema { + Some(group_schema) => group_schema.validate_edge(index, attributes), + None => self.default.validate_edge(index, attributes), + } + } - Ok(()) + pub(crate) fn update_node( + &mut self, + attributes: &Attributes, + group: Option<&Group>, + empty: bool, + ) { + match group { + Some(group) => { + self.groups + .entry(group.clone()) + .or_default() + .update_node(attributes, empty); } - (Some(group_schema), _, _) => { - group_schema.validate_edge(index, attributes, false)?; + None => self.default.update_node(attributes, empty), + } + } - Ok(()) + pub(crate) fn update_edge( + &mut self, + attributes: &Attributes, + group: Option<&Group>, + empty: bool, + ) { + match group { + Some(group) => { + self.groups + .entry(group.clone()) + .or_default() + .update_edge(attributes, empty); } - (_, Some(defalt_schema), Some(true)) => { - defalt_schema.validate_edge(index, attributes, true) + None => self.default.update_edge(attributes, empty), + } + } + + pub fn set_node_attribute( + &mut self, + attribute: &MedRecordAttribute, + data_type: DataType, + attribute_type: AttributeType, + group: Option<&Group>, + ) -> Result<(), GraphError> { + let attribute_data_type = AttributeDataType::new(data_type, attribute_type)?; + + match group { + Some(group) => { + let group_schema = self.groups.entry(group.clone()).or_default(); + group_schema + .nodes + .0 + .insert(attribute.clone(), attribute_data_type.clone()); } - (_, Some(default_schema), _) => default_schema.validate_edge(index, attributes, false), - (None, None, None) | (None, None, Some(false)) => Ok(()), + None => { + self.default + .nodes + .0 + .insert(attribute.clone(), attribute_data_type.clone()); + } + } - _ => Err(GraphError::SchemaError(format!( - "No schema provided for edge {} wit no group", - index - ))), + Ok(()) + } + + pub fn set_edge_attribute( + &mut self, + attribute: &MedRecordAttribute, + data_type: DataType, + attribute_type: AttributeType, + group: Option<&Group>, + ) -> Result<(), GraphError> { + let attribute_data_type = AttributeDataType::new(data_type, attribute_type)?; + + match group { + Some(group) => { + let group_schema = self.groups.entry(group.clone()).or_default(); + group_schema + .edges + .0 + .insert(attribute.clone(), attribute_data_type.clone()); + } + None => { + self.default + .edges + .0 + .insert(attribute.clone(), attribute_data_type.clone()); + } } + + Ok(()) } -} -impl Default for Schema { - fn default() -> Self { - Schema { - groups: HashMap::new(), - default: None, - strict: Some(false), + pub fn update_node_attribute( + &mut self, + attribute: &MedRecordAttribute, + data_type: DataType, + attribute_type: AttributeType, + group: Option<&Group>, + ) -> Result<(), GraphError> { + let attribute_data_type = AttributeDataType::new(data_type, attribute_type)?; + + match group { + Some(group) => { + let group_schema = self.groups.entry(group.clone()).or_default(); + group_schema + .nodes + .0 + .entry(attribute.clone()) + .and_modify(|value| value.merge(&attribute_data_type)) + .or_insert(attribute_data_type); + } + None => { + self.default + .nodes + .0 + .entry(attribute.clone()) + .and_modify(|value| value.merge(&attribute_data_type)) + .or_insert(attribute_data_type); + } + } + + Ok(()) + } + + pub fn update_edge_attribute( + &mut self, + attribute: &MedRecordAttribute, + data_type: DataType, + attribute_type: AttributeType, + group: Option<&Group>, + ) -> Result<(), GraphError> { + let attribute_data_type = AttributeDataType::new(data_type, attribute_type)?; + + match group { + Some(group) => { + let group_schema = self.groups.entry(group.clone()).or_default(); + group_schema + .edges + .0 + .entry(attribute.clone()) + .and_modify(|value| value.merge(&attribute_data_type)) + .or_insert(attribute_data_type); + } + None => { + self.default + .edges + .0 + .entry(attribute.clone()) + .and_modify(|value| value.merge(&attribute_data_type)) + .or_insert(attribute_data_type); + } + } + + Ok(()) + } + + pub fn remove_node_attribute(&mut self, attribute: &MedRecordAttribute, group: Option<&Group>) { + match group { + Some(group) => { + if let Some(group_schema) = self.groups.get_mut(group) { + group_schema.nodes.0.remove(attribute); + } + } + None => { + self.default.nodes.0.remove(attribute); + } + } + } + + pub fn remove_edge_attribute(&mut self, attribute: &MedRecordAttribute, group: Option<&Group>) { + match group { + Some(group) => { + if let Some(group_schema) = self.groups.get_mut(group) { + group_schema.edges.0.remove(attribute); + } + } + None => { + self.default.edges.0.remove(attribute); + } } } + + pub fn remove_group(&mut self, group: &Group) { + self.groups.remove(group); + } + + pub fn freeze(&mut self) { + self.schema_type = SchemaType::Provided; + } + + pub fn unfreeze(&mut self) { + self.schema_type = SchemaType::Inferred; + } } #[cfg(test)] mod test { - use super::{GroupSchema, Schema}; + use super::{AttributeDataType, GroupSchema}; use crate::{ - errors::GraphError, - medrecord::{Attributes, DataType, EdgeIndex, NodeIndex}, + medrecord::{ + schema::{AttributeSchema, AttributeSchemaKind}, + AttributeType, Attributes, DataType, Schema, SchemaType, + }, + MedRecord, }; use std::collections::HashMap; #[test] - fn test_validate_node_default_schema() { - let strict_schema = Schema { - groups: Default::default(), - default: Some(GroupSchema { - nodes: HashMap::from([("attribute".into(), DataType::Int.into())]), - edges: Default::default(), - strict: None, - }), - strict: Some(true), - }; - let second_strict_schema = Schema { - groups: Default::default(), - default: Some(GroupSchema { - nodes: HashMap::from([("attribute".into(), DataType::Int.into())]), - edges: Default::default(), - strict: Some(true), - }), - strict: Some(false), - }; - let non_strict_schema = Schema { - groups: Default::default(), - default: Some(GroupSchema { - nodes: HashMap::from([("attribute".into(), DataType::Int.into())]), - edges: Default::default(), - strict: None, - }), - strict: Some(false), - }; - let second_non_strict_schema = Schema { - groups: Default::default(), - default: Some(GroupSchema { - nodes: HashMap::from([("attribute".into(), DataType::Int.into())]), - edges: Default::default(), - strict: Some(false), - }), - strict: Some(true), - }; - - let attributes: Attributes = HashMap::from([("attribute".into(), 1.into())]); - let index: NodeIndex = 0.into(); - - assert!(strict_schema - .validate_node(&index, &attributes, None) - .is_ok()); - assert!(second_strict_schema - .validate_node(&index, &attributes, None) - .is_ok()); + fn test_attribute_type_infer() { + assert_eq!( + AttributeType::infer(&DataType::String), + AttributeType::Unstructured + ); + assert_eq!( + AttributeType::infer(&DataType::Int), + AttributeType::Continuous + ); + assert_eq!( + AttributeType::infer(&DataType::Float), + AttributeType::Continuous + ); + assert_eq!( + AttributeType::infer(&DataType::Bool), + AttributeType::Categorical + ); + assert_eq!( + AttributeType::infer(&DataType::DateTime), + AttributeType::Temporal + ); + assert_eq!( + AttributeType::infer(&DataType::Duration), + AttributeType::Continuous + ); + assert_eq!( + AttributeType::infer(&DataType::Null), + AttributeType::Unstructured + ); + assert_eq!( + AttributeType::infer(&DataType::Any), + AttributeType::Unstructured + ); + assert_eq!( + AttributeType::infer(&DataType::Union(( + Box::new(DataType::Int), + Box::new(DataType::Float) + ))), + AttributeType::Continuous + ); + assert_eq!( + AttributeType::infer(&DataType::Option(Box::new(DataType::Int))), + AttributeType::Continuous + ); + } - let attributes: Attributes = HashMap::from([("attribute".into(), "1".into())]); + #[test] + fn test_attribute_type_merge() { + assert_eq!( + AttributeType::Categorical.merge(&AttributeType::Categorical), + AttributeType::Categorical + ); + assert_eq!( + AttributeType::Continuous.merge(&AttributeType::Continuous), + AttributeType::Continuous + ); + assert_eq!( + AttributeType::Temporal.merge(&AttributeType::Temporal), + AttributeType::Temporal + ); + assert_eq!( + AttributeType::Categorical.merge(&AttributeType::Continuous), + AttributeType::Unstructured + ); + assert_eq!( + AttributeType::Categorical.merge(&AttributeType::Temporal), + AttributeType::Unstructured + ); + assert_eq!( + AttributeType::Continuous.merge(&AttributeType::Temporal), + AttributeType::Unstructured + ); + } - assert!(strict_schema - .validate_node(&index, &attributes, None) - .is_err_and(|e| matches!(e, GraphError::SchemaError(_)))); - assert!(second_strict_schema - .validate_node(&index, &attributes, None) - .is_err_and(|e| matches!(e, GraphError::SchemaError(_)))); + #[test] + fn test_data_type_merge() { + assert_eq!(DataType::Int.merge(&DataType::Int), DataType::Int); + assert_eq!( + DataType::Int.merge(&DataType::Float), + DataType::Union((Box::new(DataType::Int), Box::new(DataType::Float))) + ); + assert_eq!( + DataType::Int.merge(&DataType::Null), + DataType::Option(Box::new(DataType::Int)) + ); + assert_eq!( + DataType::Null.merge(&DataType::Int), + DataType::Option(Box::new(DataType::Int)) + ); + assert_eq!(DataType::Null.merge(&DataType::Null), DataType::Null); + assert_eq!(DataType::Int.merge(&DataType::Any), DataType::Any); + assert_eq!(DataType::Any.merge(&DataType::Int), DataType::Any); + } - let attributes: Attributes = - HashMap::from([("attribute".into(), 1.into()), ("extra".into(), 1.into())]); - - assert!(strict_schema - .validate_node(&index, &attributes, None) - .is_err_and(|e| matches!(e, GraphError::SchemaError(_)))); - assert!(second_strict_schema - .validate_node(&index, &attributes, None) - .is_err_and(|e| matches!(e, GraphError::SchemaError(_)))); - assert!(non_strict_schema - .validate_node(&index, &attributes, None) - .is_ok()); - assert!(second_non_strict_schema - .validate_node(&index, &attributes, None) - .is_ok()); + #[test] + fn test_attribute_data_type_new() { + assert!(AttributeDataType::new(DataType::String, AttributeType::Categorical).is_ok()); + assert!(AttributeDataType::new(DataType::String, AttributeType::Continuous).is_err()); + assert!(AttributeDataType::new(DataType::String, AttributeType::Temporal).is_err()); + assert!(AttributeDataType::new(DataType::String, AttributeType::Unstructured).is_ok()); + + assert!(AttributeDataType::new(DataType::Int, AttributeType::Categorical).is_ok()); + assert!(AttributeDataType::new(DataType::Int, AttributeType::Continuous).is_ok()); + assert!(AttributeDataType::new(DataType::Int, AttributeType::Temporal).is_err()); + assert!(AttributeDataType::new(DataType::Int, AttributeType::Unstructured).is_ok()); + + assert!(AttributeDataType::new(DataType::Float, AttributeType::Categorical).is_ok()); + assert!(AttributeDataType::new(DataType::Float, AttributeType::Continuous).is_ok()); + assert!(AttributeDataType::new(DataType::Float, AttributeType::Temporal).is_err()); + assert!(AttributeDataType::new(DataType::Float, AttributeType::Unstructured).is_ok()); + + assert!(AttributeDataType::new(DataType::Bool, AttributeType::Categorical).is_ok()); + assert!(AttributeDataType::new(DataType::Bool, AttributeType::Continuous).is_err()); + assert!(AttributeDataType::new(DataType::Bool, AttributeType::Temporal).is_err()); + assert!(AttributeDataType::new(DataType::Bool, AttributeType::Unstructured).is_ok()); + + assert!(AttributeDataType::new(DataType::DateTime, AttributeType::Categorical).is_ok()); + assert!(AttributeDataType::new(DataType::DateTime, AttributeType::Continuous).is_err()); + assert!(AttributeDataType::new(DataType::DateTime, AttributeType::Temporal).is_ok()); + assert!(AttributeDataType::new(DataType::DateTime, AttributeType::Unstructured).is_ok()); + + assert!(AttributeDataType::new(DataType::Duration, AttributeType::Categorical).is_ok()); + assert!(AttributeDataType::new(DataType::Duration, AttributeType::Continuous).is_err()); + assert!(AttributeDataType::new(DataType::Duration, AttributeType::Temporal).is_ok()); + assert!(AttributeDataType::new(DataType::Duration, AttributeType::Unstructured).is_ok()); + + assert!(AttributeDataType::new(DataType::Null, AttributeType::Categorical).is_ok()); + assert!(AttributeDataType::new(DataType::Null, AttributeType::Continuous).is_ok()); + assert!(AttributeDataType::new(DataType::Null, AttributeType::Temporal).is_ok()); + assert!(AttributeDataType::new(DataType::Null, AttributeType::Unstructured).is_ok()); + + assert!(AttributeDataType::new(DataType::Any, AttributeType::Categorical).is_ok()); + assert!(AttributeDataType::new(DataType::Any, AttributeType::Continuous).is_err()); + assert!(AttributeDataType::new(DataType::Any, AttributeType::Temporal).is_err()); + assert!(AttributeDataType::new(DataType::Any, AttributeType::Unstructured).is_ok()); + + assert!(AttributeDataType::new( + DataType::Option(Box::new(DataType::Int)), + AttributeType::Categorical + ) + .is_ok()); + assert!(AttributeDataType::new( + DataType::Option(Box::new(DataType::Int)), + AttributeType::Continuous + ) + .is_ok()); + assert!(AttributeDataType::new( + DataType::Option(Box::new(DataType::Int)), + AttributeType::Temporal + ) + .is_err()); + assert!(AttributeDataType::new( + DataType::Option(Box::new(DataType::Int)), + AttributeType::Unstructured + ) + .is_ok()); + + assert!(AttributeDataType::new( + DataType::Union((Box::new(DataType::Int), Box::new(DataType::Float))), + AttributeType::Categorical + ) + .is_ok()); + assert!(AttributeDataType::new( + DataType::Union((Box::new(DataType::Int), Box::new(DataType::Float))), + AttributeType::Continuous + ) + .is_ok()); + assert!(AttributeDataType::new( + DataType::Union((Box::new(DataType::Int), Box::new(DataType::Float))), + AttributeType::Temporal + ) + .is_err()); + assert!(AttributeDataType::new( + DataType::Union((Box::new(DataType::Int), Box::new(DataType::Float))), + AttributeType::Unstructured + ) + .is_ok()); } #[test] - fn test_invalid_validate_node_default_schema() { - let schema = Schema { - groups: Default::default(), - default: None, - strict: Some(true), - }; + fn test_attribute_data_type_data_type() { + let attribute_data_type = AttributeDataType::new(DataType::Int, AttributeType::Categorical) + .expect("AttributeType was infered from DataType."); - let attributes: Attributes = HashMap::from([("attribute".into(), 1.into())]); - let index: NodeIndex = 0.into(); + assert_eq!(attribute_data_type.data_type(), &DataType::Int); + } - assert!(schema - .validate_node(&index, &attributes, None) - .is_err_and(|e| matches!(e, GraphError::SchemaError(_)))); - } - - #[test] - fn test_validate_node_group_schema() { - let strict_schema = Schema { - groups: HashMap::from([( - "group".into(), - GroupSchema { - nodes: HashMap::from([("attribute".into(), DataType::Int.into())]), - edges: Default::default(), - strict: None, - }, - )]), - default: None, - strict: Some(true), - }; - let second_strict_schema = Schema { - groups: HashMap::from([( - "group".into(), - GroupSchema { - nodes: HashMap::from([("attribute".into(), DataType::Int.into())]), - edges: Default::default(), - strict: Some(true), - }, - )]), - default: None, - strict: Some(false), - }; - let non_strict_schema = Schema { - groups: HashMap::from([( - "group".into(), - GroupSchema { - nodes: HashMap::from([("attribute".into(), DataType::Int.into())]), - edges: Default::default(), - strict: None, - }, - )]), - default: None, - strict: Some(false), - }; - let second_non_strict_schema = Schema { - groups: HashMap::from([( - "group".into(), - GroupSchema { - nodes: HashMap::from([("attribute".into(), DataType::Int.into())]), - edges: Default::default(), - strict: Some(false), - }, - )]), - default: None, - strict: Some(true), - }; - - let attributes: Attributes = HashMap::from([("attribute".into(), 1.into())]); - let index: NodeIndex = 0.into(); - - assert!(strict_schema - .validate_node(&index, &attributes, Some(&"group".into())) - .is_ok()); - assert!(second_strict_schema - .validate_node(&index, &attributes, Some(&"group".into())) - .is_ok()); + #[test] + fn test_attribute_data_type_attribute_type() { + let attribute_data_type = AttributeDataType::new(DataType::Int, AttributeType::Categorical) + .expect("AttributeType was infered from DataType."); + + assert_eq!( + attribute_data_type.attribute_type(), + &AttributeType::Categorical + ); + } - let attributes: Attributes = HashMap::from([("attribute".into(), "1".into())]); + #[test] + fn test_attribute_data_type_merge() { + let mut attribute_data_type = + AttributeDataType::new(DataType::Int, AttributeType::Categorical) + .expect("AttributeType was infered from DataType."); + + attribute_data_type.merge( + &AttributeDataType::new(DataType::Float, AttributeType::Continuous) + .expect("AttributeType was infered from DataType."), + ); + + assert_eq!( + attribute_data_type.data_type(), + &DataType::Union((Box::new(DataType::Int), Box::new(DataType::Float))) + ); + assert_eq!( + attribute_data_type.attribute_type(), + &AttributeType::Unstructured + ); + } - assert!(strict_schema - .validate_node(&index, &attributes, Some(&"group".into())) - .is_err_and(|e| matches!(e, GraphError::SchemaError(_)))); - assert!(second_strict_schema - .validate_node(&index, &attributes, Some(&"group".into())) - .is_err_and(|e| matches!(e, GraphError::SchemaError(_)))); + #[test] + fn test_attribute_data_type_from_data_type() { + let attribute_data_type: AttributeDataType = DataType::Int.into(); + + assert_eq!(attribute_data_type.data_type(), &DataType::Int); + assert_eq!( + attribute_data_type.attribute_type(), + &AttributeType::Continuous + ); + } - let attributes: Attributes = - HashMap::from([("attribute".into(), 1.into()), ("extra".into(), 1.into())]); - - assert!(strict_schema - .validate_node(&index, &attributes, Some(&"group".into())) - .is_err_and(|e| matches!(e, GraphError::SchemaError(_)))); - assert!(second_strict_schema - .validate_node(&index, &attributes, Some(&"group".into())) - .is_err_and(|e| matches!(e, GraphError::SchemaError(_)))); - assert!(non_strict_schema - .validate_node(&index, &attributes, Some(&"group".into())) - .is_ok()); - assert!(second_non_strict_schema - .validate_node(&index, &attributes, Some(&"group".into())) - .is_ok()); + #[test] + fn test_attribute_data_type_from_tuple() { + let attribute_data_type: AttributeDataType = + (DataType::Int, AttributeType::Categorical).into(); + + assert_eq!(attribute_data_type.data_type(), &DataType::Int); + assert_eq!( + attribute_data_type.attribute_type(), + &AttributeType::Categorical + ); + } - // Checking schema of non existing group should fail because no default schema exists - assert!(strict_schema - .validate_node(&index, &attributes, Some(&"test".into())) - .is_err_and(|e| matches!(e, GraphError::SchemaError(_)))); - } - - #[test] - fn test_validate_edge_default_schema() { - let strict_schema = Schema { - groups: Default::default(), - default: Some(GroupSchema { - nodes: Default::default(), - edges: HashMap::from([("attribute".into(), DataType::Int.into())]), - strict: None, - }), - strict: Some(true), - }; - let second_strict_schema = Schema { - groups: Default::default(), - default: Some(GroupSchema { - nodes: Default::default(), - edges: HashMap::from([("attribute".into(), DataType::Int.into())]), - strict: Some(true), - }), - strict: Some(false), - }; - let non_strict_schema = Schema { - groups: Default::default(), - default: Some(GroupSchema { - nodes: Default::default(), - edges: HashMap::from([("attribute".into(), DataType::Int.into())]), - strict: None, - }), - strict: Some(false), - }; - let second_non_strict_schema = Schema { - groups: Default::default(), - default: Some(GroupSchema { - nodes: Default::default(), - edges: HashMap::from([("attribute".into(), DataType::Int.into())]), - strict: Some(false), - }), - strict: Some(true), - }; - - let attributes: Attributes = HashMap::from([("attribute".into(), 1.into())]); - let index: EdgeIndex = 0; - - assert!(strict_schema - .validate_edge(&index, &attributes, None) - .is_ok()); - assert!(second_strict_schema - .validate_edge(&index, &attributes, None) + #[test] + fn test_attribute_schema_kind_error_message() { + let index = 0; + let key = "key"; + let data_type = DataType::Int; + + assert_eq!( + AttributeSchemaKind::Node(&(index.into())).error_message(&(key.into()), &data_type), + "Attribute key of type Int not found on node with index 0" + ); + assert_eq!( + AttributeSchemaKind::Edge(&(index as u32)).error_message(&(key.into()), &data_type), + "Attribute key of type Int not found on edge with index 0" + ); + } + + #[test] + fn test_attribute_schema_kind_error_message_expected() { + let index = 0; + let key = "key"; + let data_type = DataType::Int; + let expected_data_type = DataType::Float; + + assert_eq!( + AttributeSchemaKind::Node(&(index.into())).error_message_expected( + &(key.into()), + &data_type, + &expected_data_type + ), + "Attribute key of node with index 0 is of type Int. Expected Float." + ); + assert_eq!( + AttributeSchemaKind::Edge(&(index as u32)).error_message_expected( + &(key.into()), + &data_type, + &expected_data_type + ), + "Attribute key of node with index 0 is of type Int. Expected Float." + ); + } + + #[test] + fn test_attribute_schema_kind_error_message_too_many() { + let index = 0; + let attributes = vec!["key1".to_string(), "key2".to_string()]; + + assert_eq!( + AttributeSchemaKind::Node(&(index.into())).error_message_too_many(attributes.clone()), + "Attributes [key1, key2] of node with index 0 do not exist in schema." + ); + assert_eq!( + AttributeSchemaKind::Edge(&(index as u32)).error_message_too_many(attributes), + "Attributes [key1, key2] of edge with index 0 do not exist in schema." + ); + } + + #[test] + fn test_group_schema_nodes() { + let nodes = AttributeSchema::new( + vec![ + ( + "key1".into(), + AttributeDataType::new(DataType::Int, AttributeType::Categorical) + .expect("AttributeType was infered from DataType."), + ), + ( + "key2".into(), + AttributeDataType::new(DataType::Float, AttributeType::Continuous) + .expect("AttributeType was infered from DataType."), + ), + ] + .into_iter() + .collect(), + ); + + let group_schema = GroupSchema::new(nodes.clone(), AttributeSchema::default()); + + assert_eq!(group_schema.nodes(), &nodes.0); + } + + #[test] + fn test_group_schema_edges() { + let edges = AttributeSchema::new( + vec![ + ( + "key1".into(), + AttributeDataType::new(DataType::Int, AttributeType::Categorical) + .expect("AttributeType was infered from DataType."), + ), + ( + "key2".into(), + AttributeDataType::new(DataType::Float, AttributeType::Continuous) + .expect("AttributeType was infered from DataType."), + ), + ] + .into_iter() + .collect(), + ); + + let group_schema = GroupSchema::new(AttributeSchema::default(), edges.clone()); + + assert_eq!(group_schema.edges(), &edges.0); + } + + #[test] + fn test_group_schema_validate_attribute_schema() { + let attributes: Attributes = vec![("key1".into(), 0.into()), ("key2".into(), 0.0.into())] + .into_iter() + .collect(); + + let attribute_schema = AttributeSchema::new( + vec![ + ( + "key1".into(), + AttributeDataType::new(DataType::Int, AttributeType::Categorical) + .expect("AttributeType was infered from DataType."), + ), + ( + "key2".into(), + AttributeDataType::new(DataType::Float, AttributeType::Continuous) + .expect("AttributeType was infered from DataType."), + ), + ] + .into_iter() + .collect(), + ); + + assert!(attribute_schema + .validate(&attributes, AttributeSchemaKind::Node(&0.into())) .is_ok()); - let attributes: Attributes = HashMap::from([("attribute".into(), "1".into())]); + let attributes: Attributes = vec![("key1".into(), 0.0.into()), ("key2".into(), 0.into())] + .into_iter() + .collect(); + + assert!(attribute_schema + .validate(&attributes, AttributeSchemaKind::Node(&0.into())) + .is_err_and(|error| { matches!(error, crate::errors::GraphError::SchemaError(_)) })); + } + + #[test] + fn test_group_schema_validate_node() { + let nodes = AttributeSchema::new( + vec![ + ( + "key1".into(), + AttributeDataType::new(DataType::Int, AttributeType::Categorical) + .expect("AttributeType was infered from DataType."), + ), + ( + "key2".into(), + AttributeDataType::new(DataType::Float, AttributeType::Continuous) + .expect("AttributeType was infered from DataType."), + ), + ] + .into_iter() + .collect(), + ); + + let group_schema = GroupSchema::new(nodes, AttributeSchema::default()); + + let attributes: Attributes = vec![("key1".into(), 0.into()), ("key2".into(), 0.0.into())] + .into_iter() + .collect(); + + assert!(group_schema.validate_node(&0.into(), &attributes).is_ok()); + + let attributes: Attributes = vec![("key1".into(), 0.0.into()), ("key2".into(), 0.into())] + .into_iter() + .collect(); + + assert!(group_schema + .validate_node(&0.into(), &attributes) + .is_err_and(|error| { matches!(error, crate::errors::GraphError::SchemaError(_)) })); + } - assert!(strict_schema - .validate_edge(&index, &attributes, None) - .is_err_and(|e| matches!(e, GraphError::SchemaError(_)))); - assert!(second_strict_schema - .validate_edge(&index, &attributes, None) - .is_err_and(|e| matches!(e, GraphError::SchemaError(_)))); + #[test] + fn test_group_schema_validate_edge() { + let edges = AttributeSchema::new( + vec![ + ( + "key1".into(), + AttributeDataType::new(DataType::Int, AttributeType::Categorical) + .expect("AttributeType was infered from DataType."), + ), + ( + "key2".into(), + AttributeDataType::new(DataType::Float, AttributeType::Continuous) + .expect("AttributeType was infered from DataType."), + ), + ] + .into_iter() + .collect(), + ); + + let group_schema = GroupSchema::new(AttributeSchema::default(), edges); + + let attributes: Attributes = vec![("key1".into(), 0.into()), ("key2".into(), 0.0.into())] + .into_iter() + .collect(); + + assert!(group_schema.validate_edge(&0, &attributes).is_ok()); + + let attributes: Attributes = vec![("key1".into(), 0.0.into()), ("key2".into(), 0.into())] + .into_iter() + .collect(); + + assert!(group_schema + .validate_edge(&0, &attributes) + .is_err_and(|error| { matches!(error, crate::errors::GraphError::SchemaError(_)) })); + } + #[test] + fn test_update_attribute_schema() { + let mut schema = AttributeSchema::default(); let attributes: Attributes = - HashMap::from([("attribute".into(), 1.into()), ("extra".into(), 1.into())]); - - assert!(strict_schema - .validate_edge(&index, &attributes, None) - .is_err_and(|e| matches!(e, GraphError::SchemaError(_)))); - assert!(second_strict_schema - .validate_edge(&index, &attributes, None) - .is_err_and(|e| matches!(e, GraphError::SchemaError(_)))); - assert!(non_strict_schema - .validate_edge(&index, &attributes, None) - .is_ok()); - assert!(second_non_strict_schema - .validate_edge(&index, &attributes, None) - .is_ok()); + vec![("key1".into(), 0.into()), ("key2".into(), "test".into())] + .into_iter() + .collect(); + + schema.update(&attributes, true); + + assert_eq!(schema.0.len(), 2); + assert_eq!( + schema.0.get(&"key1".into()).unwrap().data_type(), + &DataType::Int + ); + assert_eq!( + schema.0.get(&"key2".into()).unwrap().data_type(), + &DataType::String + ); + + let new_attributes: Attributes = + vec![("key1".into(), 0.5.into()), ("key3".into(), true.into())] + .into_iter() + .collect(); + + schema.update(&new_attributes, false); + + assert_eq!(schema.0.len(), 3); + assert_eq!( + schema.0.get(&"key1".into()).unwrap().data_type(), + &DataType::Union((Box::new(DataType::Int), Box::new(DataType::Float))) + ); + assert_eq!( + schema.0.get(&"key2".into()).unwrap().data_type(), + &DataType::Option(Box::new(DataType::String)) + ); + assert_eq!( + schema.0.get(&"key3".into()).unwrap().data_type(), + &DataType::Option(Box::new(DataType::Bool)) + ); + } + + #[test] + fn test_infer_attribute_schema() { + let attributes1: Attributes = + vec![("key1".into(), 0.into()), ("key2".into(), "test".into())] + .into_iter() + .collect(); + + let attributes2: Attributes = vec![("key1".into(), 1.into()), ("key3".into(), true.into())] + .into_iter() + .collect(); + + let schema = AttributeSchema::infer(vec![&attributes1, &attributes2]); + + assert_eq!(schema.0.len(), 3); + assert_eq!( + schema.0.get(&"key1".into()).unwrap().data_type(), + &DataType::Int + ); + assert_eq!( + schema.0.get(&"key2".into()).unwrap().data_type(), + &DataType::Option(Box::new(DataType::String)) + ); + assert_eq!( + schema.0.get(&"key3".into()).unwrap().data_type(), + &DataType::Option(Box::new(DataType::Bool)) + ); + } + + #[test] + fn test_group_schema_infer() { + let node_attributes1: Attributes = + vec![("key1".into(), 0.into()), ("key2".into(), "test".into())] + .into_iter() + .collect(); + + let node_attributes2: Attributes = + vec![("key1".into(), 1.into()), ("key3".into(), true.into())] + .into_iter() + .collect(); + + let edge_attributes: Attributes = + vec![("key4".into(), 0.5.into()), ("key5".into(), "edge".into())] + .into_iter() + .collect(); + + let group_schema = GroupSchema::infer( + vec![&node_attributes1, &node_attributes2], + vec![&edge_attributes], + ); + + assert_eq!(group_schema.nodes().len(), 3); + assert_eq!(group_schema.edges().len(), 2); + + assert_eq!( + group_schema + .nodes() + .get(&"key1".into()) + .unwrap() + .data_type(), + &DataType::Int + ); + assert_eq!( + group_schema + .nodes() + .get(&"key2".into()) + .unwrap() + .data_type(), + &DataType::Option(Box::new(DataType::String)) + ); + assert_eq!( + group_schema + .nodes() + .get(&"key3".into()) + .unwrap() + .data_type(), + &DataType::Option(Box::new(DataType::Bool)) + ); + + assert_eq!( + group_schema + .edges() + .get(&"key4".into()) + .unwrap() + .data_type(), + &DataType::Float + ); + assert_eq!( + group_schema + .edges() + .get(&"key5".into()) + .unwrap() + .data_type(), + &DataType::String + ); + } + + #[test] + fn test_group_schema_update_node() { + let mut group_schema = GroupSchema::default(); + let attributes = Attributes::from([("key1".into(), 0.into()), ("key2".into(), 0.0.into())]); + + group_schema.update_node(&attributes, true); + + assert_eq!(group_schema.nodes().len(), 2); + assert_eq!( + group_schema + .nodes() + .get(&"key1".into()) + .unwrap() + .data_type(), + &DataType::Int + ); + assert_eq!( + group_schema + .nodes() + .get(&"key2".into()) + .unwrap() + .data_type(), + &DataType::Float + ); + } + + #[test] + fn test_group_schema_update_edge() { + let mut group_schema = GroupSchema::default(); + let attributes = + Attributes::from([("key3".into(), true.into()), ("key4".into(), "test".into())]); + + group_schema.update_edge(&attributes, true); + + assert_eq!(group_schema.edges().len(), 2); + assert_eq!( + group_schema + .edges() + .get(&"key3".into()) + .unwrap() + .data_type(), + &DataType::Bool + ); + assert_eq!( + group_schema + .edges() + .get(&"key4".into()) + .unwrap() + .data_type(), + &DataType::String + ); + } + + #[test] + fn test_schema_infer() { + let mut medrecord = MedRecord::new(); + medrecord + .add_node(0.into(), Attributes::from([("key1".into(), 0.into())])) + .unwrap(); + medrecord + .add_node(1.into(), Attributes::from([("key2".into(), 0.0.into())])) + .unwrap(); + medrecord + .add_edge( + 0.into(), + 1.into(), + Attributes::from([("key3".into(), true.into())]), + ) + .unwrap(); + + let schema = Schema::infer(&medrecord); + + assert_eq!(schema.default().nodes().len(), 2); + assert_eq!(schema.default().edges().len(), 1); + } + + #[test] + fn test_schema_groups() { + let schema = Schema::new_inferred( + vec![("group1".into(), GroupSchema::default())] + .into_iter() + .collect(), + GroupSchema::default(), + ); + assert_eq!(schema.groups().len(), 1); + assert!(schema.groups().contains_key(&"group1".into())); } #[test] - fn test_invalid_validate_edge_default_schema() { - let schema = Schema { - groups: Default::default(), - default: None, - strict: Some(true), - }; + fn test_schema_group() { + let schema = Schema::new_inferred( + vec![("group1".into(), GroupSchema::default())] + .into_iter() + .collect(), + GroupSchema::default(), + ); + assert!(schema.group(&"group1".into()).is_ok()); + assert!(schema.group(&"non_existent".into()).is_err()); + } - let attributes: Attributes = HashMap::from([("attribute".into(), 1.into())]); - let index: EdgeIndex = 0; + #[test] + fn test_schema_default() { + let default_schema = GroupSchema::default(); + let schema = Schema::new_inferred(HashMap::new(), default_schema.clone()); + assert_eq!(schema.default(), &default_schema); + } + + #[test] + fn test_schema_schema_type() { + let schema = Schema::new_inferred(HashMap::new(), GroupSchema::default()); + assert_eq!(schema.schema_type(), &SchemaType::Inferred); + } + #[test] + fn test_schema_validate_node() { + let mut schema = Schema::new_inferred( + HashMap::new(), + GroupSchema::new(AttributeSchema::default(), AttributeSchema::default()), + ); + schema + .set_node_attribute( + &"key1".into(), + DataType::Int, + AttributeType::Continuous, + None, + ) + .unwrap(); + + let attributes = Attributes::from([("key1".into(), 0.into())]); + assert!(schema.validate_node(&0.into(), &attributes, None).is_ok()); + + let invalid_attributes = Attributes::from([("key1".into(), "invalid".into())]); + assert!(schema + .validate_node(&0.into(), &invalid_attributes, None) + .is_err()); + } + + #[test] + fn test_schema_validate_edge() { + let mut schema = Schema::new_inferred( + HashMap::new(), + GroupSchema::new(AttributeSchema::default(), AttributeSchema::default()), + ); + schema + .set_edge_attribute( + &"key1".into(), + DataType::Bool, + AttributeType::Categorical, + None, + ) + .unwrap(); + + let attributes = Attributes::from([("key1".into(), true.into())]); + assert!(schema.validate_edge(&0, &attributes, None).is_ok()); + + let invalid_attributes = Attributes::from([("key1".into(), 0.into())]); + assert!(schema.validate_edge(&0, &invalid_attributes, None).is_err()); + } + + #[test] + fn test_schema_update_node() { + let mut schema = Schema::new_inferred( + HashMap::new(), + GroupSchema::new(AttributeSchema::default(), AttributeSchema::default()), + ); + let attributes = Attributes::from([("key1".into(), 0.into()), ("key2".into(), 0.0.into())]); + + schema.update_node(&attributes, None, true); + + assert_eq!(schema.default().nodes().len(), 2); + assert_eq!( + schema + .default() + .nodes() + .get(&"key1".into()) + .unwrap() + .data_type(), + &DataType::Int + ); + assert_eq!( + schema + .default() + .nodes() + .get(&"key2".into()) + .unwrap() + .data_type(), + &DataType::Float + ); + } + + #[test] + fn test_schema_update_edge() { + let mut schema = Schema::new_inferred( + HashMap::new(), + GroupSchema::new(AttributeSchema::default(), AttributeSchema::default()), + ); + let attributes = + Attributes::from([("key3".into(), true.into()), ("key4".into(), "test".into())]); + + schema.update_edge(&attributes, None, true); + + assert_eq!(schema.default().edges().len(), 2); + assert_eq!( + schema + .default() + .edges() + .get(&"key3".into()) + .unwrap() + .data_type(), + &DataType::Bool + ); + assert_eq!( + schema + .default() + .edges() + .get(&"key4".into()) + .unwrap() + .data_type(), + &DataType::String + ); + } + + #[test] + fn test_schema_set_node_attribute() { + let mut schema = Schema::new_inferred(HashMap::new(), GroupSchema::default()); assert!(schema - .validate_edge(&index, &attributes, None) - .is_err_and(|e| matches!(e, GraphError::SchemaError(_)))); - } - - #[test] - fn test_validate_edge_group_schema() { - let strict_schema = Schema { - groups: HashMap::from([( - "group".into(), - GroupSchema { - nodes: Default::default(), - edges: HashMap::from([("attribute".into(), DataType::Int.into())]), - strict: None, - }, - )]), - default: None, - strict: Some(true), - }; - let second_strict_schema = Schema { - groups: HashMap::from([( - "group".into(), - GroupSchema { - nodes: Default::default(), - edges: HashMap::from([("attribute".into(), DataType::Int.into())]), - strict: Some(true), - }, - )]), - default: None, - strict: Some(false), - }; - let non_strict_schema = Schema { - groups: HashMap::from([( - "group".into(), - GroupSchema { - nodes: Default::default(), - edges: HashMap::from([("attribute".into(), DataType::Int.into())]), - strict: None, - }, - )]), - default: None, - strict: Some(false), - }; - let second_non_strict_schema = Schema { - groups: HashMap::from([( - "group".into(), - GroupSchema { - nodes: Default::default(), - edges: HashMap::from([("attribute".into(), DataType::Int.into())]), - strict: Some(false), - }, - )]), - default: None, - strict: Some(true), - }; - - let attributes: Attributes = HashMap::from([("attribute".into(), 1.into())]); - let index: EdgeIndex = 0; - - assert!(strict_schema - .validate_edge(&index, &attributes, Some(&"group".into())) + .set_node_attribute( + &"key1".into(), + DataType::Int, + AttributeType::Continuous, + None + ) .is_ok()); - assert!(second_strict_schema - .validate_edge(&index, &attributes, Some(&"group".into())) + assert_eq!( + schema + .default() + .nodes() + .get(&"key1".into()) + .unwrap() + .data_type(), + &DataType::Int + ); + assert!(schema + .set_node_attribute( + &"key1".into(), + DataType::Float, + AttributeType::Continuous, + None + ) .is_ok()); + assert_eq!( + schema + .default() + .nodes() + .get(&"key1".into()) + .unwrap() + .data_type(), + &DataType::Float + ); + } - let attributes: Attributes = HashMap::from([("attribute".into(), "1".into())]); - - assert!(strict_schema - .validate_edge(&index, &attributes, Some(&"group".into())) - .is_err_and(|e| matches!(e, GraphError::SchemaError(_)))); - assert!(second_strict_schema - .validate_edge(&index, &attributes, Some(&"group".into())) - .is_err_and(|e| matches!(e, GraphError::SchemaError(_)))); + #[test] + fn test_schema_set_edge_attribute() { + let mut schema = Schema::new_inferred(HashMap::new(), GroupSchema::default()); + assert!(schema + .set_edge_attribute( + &"key1".into(), + DataType::Bool, + AttributeType::Categorical, + None + ) + .is_ok()); + assert_eq!( + schema + .default() + .edges() + .get(&"key1".into()) + .unwrap() + .data_type(), + &DataType::Bool + ); + assert!(schema + .set_edge_attribute( + &"key1".into(), + DataType::Float, + AttributeType::Continuous, + None + ) + .is_ok()); + assert_eq!( + schema + .default() + .edges() + .get(&"key1".into()) + .unwrap() + .data_type(), + &DataType::Float + ); + } - let attributes: Attributes = - HashMap::from([("attribute".into(), 1.into()), ("extra".into(), 1.into())]); - - assert!(strict_schema - .validate_edge(&index, &attributes, Some(&"group".into())) - .is_err_and(|e| matches!(e, GraphError::SchemaError(_)))); - assert!(second_strict_schema - .validate_edge(&index, &attributes, Some(&"group".into())) - .is_err_and(|e| matches!(e, GraphError::SchemaError(_)))); - assert!(non_strict_schema - .validate_edge(&index, &attributes, Some(&"group".into())) + #[test] + fn test_schema_update_node_attribute() { + let mut schema = Schema::new_inferred(HashMap::new(), GroupSchema::default()); + schema + .set_node_attribute( + &"key1".into(), + DataType::Int, + AttributeType::Continuous, + None, + ) + .unwrap(); + assert!(schema + .update_node_attribute( + &"key1".into(), + DataType::Float, + AttributeType::Continuous, + None + ) .is_ok()); - assert!(second_non_strict_schema - .validate_edge(&index, &attributes, Some(&"group".into())) + assert_eq!( + schema + .default() + .nodes() + .get(&"key1".into()) + .unwrap() + .data_type(), + &DataType::Union((Box::new(DataType::Int), Box::new(DataType::Float))) + ); + } + + #[test] + fn test_schema_update_edge_attribute() { + let mut schema = Schema::new_inferred(HashMap::new(), GroupSchema::default()); + schema + .set_edge_attribute( + &"key1".into(), + DataType::Bool, + AttributeType::Categorical, + None, + ) + .unwrap(); + assert!(schema + .update_edge_attribute( + &"key1".into(), + DataType::String, + AttributeType::Unstructured, + None + ) .is_ok()); + assert_eq!( + schema + .default() + .edges() + .get(&"key1".into()) + .unwrap() + .data_type(), + &DataType::Union((Box::new(DataType::Bool), Box::new(DataType::String))) + ); + } + + #[test] + fn test_schema_remove_node_attribute() { + let mut schema = Schema::new_inferred(HashMap::new(), GroupSchema::default()); + schema + .set_node_attribute( + &"key1".into(), + DataType::Int, + AttributeType::Continuous, + None, + ) + .unwrap(); + schema.remove_node_attribute(&"key1".into(), None); + assert!(!schema.default().nodes().contains_key(&"key1".into())); + } + + #[test] + fn test_schema_remove_edge_attribute() { + let mut schema = Schema::new_inferred(HashMap::new(), GroupSchema::default()); + schema + .set_edge_attribute( + &"key1".into(), + DataType::Bool, + AttributeType::Categorical, + None, + ) + .unwrap(); + schema.remove_edge_attribute(&"key1".into(), None); + assert!(!schema.default().edges().contains_key(&"key1".into())); + } + + #[test] + fn test_schema_remove_group() { + let mut schema = Schema::new_inferred( + vec![("group1".into(), GroupSchema::default())] + .into_iter() + .collect(), + GroupSchema::default(), + ); + schema.remove_group(&"group1".into()); + assert!(!schema.groups().contains_key(&"group1".into())); + } + + #[test] + fn test_schema_freeze_unfreeze() { + let mut schema = Schema::new_inferred(HashMap::new(), GroupSchema::default()); + assert_eq!(schema.schema_type(), &SchemaType::Inferred); + + schema.freeze(); + assert_eq!(schema.schema_type(), &SchemaType::Provided); - // Checking schema of non existing group should fail because no default schema exists - assert!(strict_schema - .validate_edge(&index, &attributes, Some(&"test".into())) - .is_err_and(|e| matches!(e, GraphError::SchemaError(_)))); + schema.unfreeze(); + assert_eq!(schema.schema_type(), &SchemaType::Inferred); } } diff --git a/medmodels/_medmodels.pyi b/medmodels/_medmodels.pyi index c6e672eb..f56749d1 100644 --- a/medmodels/_medmodels.pyi +++ b/medmodels/_medmodels.pyi @@ -55,41 +55,94 @@ class PyAttributeType(Enum): Categorical = ... Continuous = ... Temporal = ... + Unstructured = ... + + @staticmethod + def infer(data_type: PyDataType) -> PyAttributeType: ... class PyAttributeDataType: data_type: PyDataType - attribute_type: Optional[PyAttributeType] + attribute_type: PyAttributeType def __init__( - self, data_type: PyDataType, attribute_type: Optional[PyAttributeType] + self, data_type: PyDataType, attribute_type: PyAttributeType ) -> None: ... class PyGroupSchema: nodes: Dict[MedRecordAttribute, PyAttributeDataType] edges: Dict[MedRecordAttribute, PyAttributeDataType] - strict: Optional[bool] def __init__( self, *, nodes: Dict[MedRecordAttribute, PyAttributeDataType], edges: Dict[MedRecordAttribute, PyAttributeDataType], - strict: Optional[bool] = None, ) -> None: ... + def validate_node(self, index: NodeIndex, attributes: Attributes) -> None: ... + def validate_edge(self, index: EdgeIndex, attributes: Attributes) -> None: ... + +class PySchemaType(Enum): + Provided = ... + Inferred = ... class PySchema: groups: List[Group] - default: Optional[PyGroupSchema] - strict: Optional[bool] + default: PyGroupSchema + schema_type: PySchemaType def __init__( self, *, groups: Dict[Group, PyGroupSchema], - default: Optional[PyGroupSchema] = None, - strict: Optional[bool] = None, + default: PyGroupSchema, + schema_type: PySchemaType = ..., ) -> None: ... + @staticmethod + def infer(medrecord: PyMedRecord) -> PySchema: ... def group(self, group: Group) -> PyGroupSchema: ... + def validate_node( + self, index: NodeIndex, attributes: Attributes, group: Optional[Group] = None + ) -> None: ... + def validate_edge( + self, index: EdgeIndex, attributes: Attributes, group: Optional[Group] = None + ) -> None: ... + def set_node_attribute( + self, + attribute: MedRecordAttribute, + data_type: PyDataType, + attribute_type: PyAttributeType, + group: Optional[Group] = None, + ) -> None: ... + def set_edge_attribute( + self, + attribute: MedRecordAttribute, + data_type: PyDataType, + attribute_type: PyAttributeType, + group: Optional[Group] = None, + ) -> None: ... + def update_node_attribute( + self, + attribute: MedRecordAttribute, + data_type: PyDataType, + attribute_type: PyAttributeType, + group: Optional[Group] = None, + ) -> None: ... + def update_edge_attribute( + self, + attribute: MedRecordAttribute, + data_type: PyDataType, + attribute_type: PyAttributeType, + group: Optional[Group] = None, + ) -> None: ... + def remove_node_attribute( + self, attribute: MedRecordAttribute, group: Optional[Group] = None + ) -> None: ... + def remove_edge_attribute( + self, attribute: MedRecordAttribute, group: Optional[Group] = None + ) -> None: ... + def remove_group(self, group: Group) -> None: ... + def freeze(self) -> None: ... + def unfreeze(self) -> None: ... class PyMedRecord: schema: PySchema @@ -122,6 +175,8 @@ class PyMedRecord: def from_ron(path: str) -> PyMedRecord: ... def to_ron(self, path: str) -> None: ... def update_schema(self, schema: PySchema) -> None: ... + def freeze_schema(self) -> None: ... + def unfreeze_schema(self) -> None: ... def node(self, node_index: NodeIndexInputList) -> Dict[NodeIndex, Attributes]: ... def edge(self, edge_index: EdgeIndexInputList) -> Dict[EdgeIndex, Attributes]: ... def outgoing_edges( @@ -230,9 +285,9 @@ class PyMedRecord: def clone(self) -> PyMedRecord: ... class PyEdgeDirection(Enum): - Incoming = 0 - Outgoing = 1 - Both = 2 + Incoming = ... + Outgoing = ... + Both = ... class PyNodeOperand: def attribute(self, attribute: MedRecordAttribute) -> PyMultipleValuesOperand: ... diff --git a/medmodels/medrecord/datatype.py b/medmodels/medrecord/datatype.py index 92a3100e..9aa92df4 100644 --- a/medmodels/medrecord/datatype.py +++ b/medmodels/medrecord/datatype.py @@ -4,7 +4,7 @@ import typing from abc import ABC, abstractmethod -from typing import TypeAlias +from typing import Generic, TypeAlias, TypeVar from medmodels._medmodels import ( PyAny, @@ -344,28 +344,23 @@ def __eq__(self, value: object) -> bool: return isinstance(value, Any) -class Union(DataType): +U1 = TypeVar("U1", bound=DataType) +U2 = TypeVar("U2", bound=DataType) + + +class Union(DataType, Generic[U1, U2]): """Data type for unions of data types.""" _union: PyUnion - def __init__(self, *dtypes: DataType) -> None: + def __init__(self, dtype1: U1, dtype2: U2) -> None: """Initializes the Union data type. Args: - *dtypes (DataType): The data types to include in the union. - - Raises: - ValueError: If the union does not have at least two arguments. + dtype1 (U1): The first data type of the union. + dtype2 (U2): The second data type of the union. """ - if len(dtypes) < 2: - msg = "Union must have at least two arguments" - raise ValueError(msg) - - if len(dtypes) == 2: - self._union = PyUnion(dtypes[0]._inner(), dtypes[1]._inner()) - else: - self._union = PyUnion(dtypes[0]._inner(), Union(*dtypes[1:])._inner()) + self._union = PyUnion(dtype1._inner(), dtype2._inner()) def _inner(self) -> PyDataType: return self._union @@ -397,16 +392,19 @@ def __eq__(self, value: object) -> bool: ) -class Option(DataType): +T = TypeVar("T", bound=DataType) + + +class Option(DataType, Generic[T]): """Data type for optional values.""" _option: PyOption - def __init__(self, dtype: DataType) -> None: + def __init__(self, dtype: T) -> None: """Initializes the Option data type. Args: - dtype (DataType): The data type of the optional value. + dtype (T): The data type of the optional value. """ self._option = PyOption(dtype._inner()) diff --git a/medmodels/medrecord/schema.py b/medmodels/medrecord/schema.py index 0e332fd1..ee00798f 100644 --- a/medmodels/medrecord/schema.py +++ b/medmodels/medrecord/schema.py @@ -3,7 +3,17 @@ from __future__ import annotations from enum import Enum, auto -from typing import TYPE_CHECKING, Dict, Iterator, List, Optional, Tuple, Union, overload +from typing import ( + TYPE_CHECKING, + Dict, + List, + Literal, + Optional, + Tuple, + TypeAlias, + Union, + overload, +) from medmodels._medmodels import ( PyAttributeDataType, @@ -11,10 +21,28 @@ PyGroupSchema, PySchema, ) -from medmodels.medrecord.datatype import DataType +from medmodels.medrecord.datatype import ( + DataType, + DateTime, + Duration, + Float, + Int, + Null, + Option, +) +from medmodels.medrecord.datatype import ( + Union as DataTypeUnion, +) +from medmodels.medrecord.types import ( + Attributes, + EdgeIndex, + MedRecordAttribute, + NodeIndex, +) if TYPE_CHECKING: - from medmodels.medrecord.types import Group, MedRecordAttribute + from medmodels.medrecord.medrecord import MedRecord + from medmodels.medrecord.types import Group class AttributeType(Enum): @@ -23,6 +51,7 @@ class AttributeType(Enum): Categorical = auto() Continuous = auto() Temporal = auto() + Unstructured = auto() @staticmethod def _from_py_attribute_type(py_attribute_type: PyAttributeType) -> AttributeType: @@ -40,7 +69,24 @@ def _from_py_attribute_type(py_attribute_type: PyAttributeType) -> AttributeType return AttributeType.Continuous if py_attribute_type == PyAttributeType.Temporal: return AttributeType.Temporal - return None + if py_attribute_type == PyAttributeType.Unstructured: + return AttributeType.Unstructured + msg = "Should never be reached" + raise NotImplementedError(msg) + + @staticmethod + def infer(data_type: DataType) -> AttributeType: + """Infers the attribute type from the data type. + + Args: + data_type (DataType): The data type to infer the attribute type from. + + Returns: + AttributeType: The inferred attribute type. + """ + return AttributeType._from_py_attribute_type( + PyAttributeType.infer(data_type._inner()) + ) def _into_py_attribute_type(self) -> PyAttributeType: """Converts an AttributeType to a PyAttributeType. @@ -54,6 +100,8 @@ def _into_py_attribute_type(self) -> PyAttributeType: return PyAttributeType.Continuous if self == AttributeType.Temporal: return PyAttributeType.Temporal + if self == AttributeType.Unstructured: + return PyAttributeType.Unstructured msg = "Should never be reached" raise NotImplementedError(msg) @@ -73,6 +121,14 @@ def __str__(self) -> str: """ return self.name + def __hash__(self) -> int: + """Returns the hash of the AttributeType instance. + + Returns: + int: The hash of the AttributeType instance. + """ + return hash(self.name) + def __eq__(self, value: object) -> bool: """Compares the AttributeType instance to another object for equality. @@ -90,164 +146,37 @@ def __eq__(self, value: object) -> bool: return False -class AttributesSchema: - """A schema for a collection of attributes.""" - - _attributes_schema: Dict[ - MedRecordAttribute, Tuple[DataType, Optional[AttributeType]] - ] - - def __init__( - self, - attributes_schema: Dict[ - MedRecordAttribute, Tuple[DataType, Optional[AttributeType]] - ], - ) -> None: - """Initializes a new instance of AttributesSchema. - - Args: - attributes_schema (Dict[MedRecordAttribute, Tuple[DataType, Optional[AttributeType]]]): - A dictionary mapping MedRecordAttributes to their data types and - optional attribute types. - """ # noqa: W505 - self._attributes_schema = attributes_schema - - def __repr__(self) -> str: - """Returns a string representation of the AttributesSchema instance. - - Returns: - str: String representation of the attribute schema. - """ - return self._attributes_schema.__repr__() - - def __getitem__( - self, key: MedRecordAttribute - ) -> Tuple[DataType, Optional[AttributeType]]: - """Gets the type and optional attribute type for a given MedRecordAttribute. - - Args: - key (MedRecordAttribute): The attribute for which the data type is - requested. - - Returns: - Tuple[DataType, Optional[AttributeType]]: The data type and optional - attribute type of the given attribute. - """ - return self._attributes_schema[key] - - def __contains__(self, key: MedRecordAttribute) -> bool: - """Checks if a given MedRecordAttribute is in the attributes schema. - - Args: - key (MedRecordAttribute): The attribute to check. - - Returns: - bool: True if the attribute exists in the schema, False otherwise. - """ - return key in self._attributes_schema - - def __iter__(self) -> Iterator[MedRecordAttribute]: - """Returns an iterator over the attributes schema. - - Returns: - Iterator: An iterator over the attribute keys. - """ - return self._attributes_schema.__iter__() - - def __len__(self) -> int: - """Returns the number of attributes in the schema. - - Returns: - int: The number of attributes. - """ - return len(self._attributes_schema) +CategoricalType: TypeAlias = DataType +CategoricalPair: TypeAlias = Tuple[CategoricalType, Literal[AttributeType.Categorical]] - def __eq__(self, value: object) -> bool: - """Compares the AttributesSchema instance to another object for equality. +ContinuousType: TypeAlias = Union[ + Int, + Float, + Null, + Option["ContinuousType"], + DataTypeUnion["ContinuousType", "ContinuousType"], +] +ContinuousPair: TypeAlias = Tuple[ContinuousType, Literal[AttributeType.Continuous]] - Args: - value (object): The object to compare against. +TemporalType = Union[ + DateTime, + Duration, + Null, + Option["TemporalType"], + DataTypeUnion["TemporalType", "TemporalType"], +] +TemporalPair: TypeAlias = Tuple[TemporalType, Literal[AttributeType.Temporal]] - Returns: - bool: True if the objects are equal, False otherwise. - """ - if not (isinstance(value, (AttributesSchema, dict))): - return False - - attribute_schema = ( - value._attributes_schema if isinstance(value, AttributesSchema) else value - ) - - if not attribute_schema.keys() == self._attributes_schema.keys(): - return False - - for key in self._attributes_schema: - if ( - not isinstance(attribute_schema[key], tuple) - or not isinstance( - attribute_schema[key][0], type(self._attributes_schema[key][0]) - ) - or attribute_schema[key][1] != self._attributes_schema[key][1] - ): - return False +UnstructuredType: TypeAlias = DataType +UnstructuredPair: TypeAlias = Tuple[ + UnstructuredType, Literal[AttributeType.Unstructured] +] - return True +AttributeDataType: TypeAlias = Union[ + CategoricalPair, ContinuousPair, TemporalPair, UnstructuredPair +] - def keys(self): # noqa: ANN201 - """Returns the attribute keys in the schema. - - Returns: - KeysView: A view object displaying a list of dictionary's keys. - """ - return self._attributes_schema.keys() - - def values(self): # noqa: ANN201 - """Returns the attribute values in the schema. - - Returns: - ValuesView: A view object displaying a list of dictionary's values. - """ - return self._attributes_schema.values() - - def items(self): # noqa: ANN201 - """Returns the attribute key-value pairs in the schema. - - Returns: - ItemsView: A set-like object providing a view on D's items. - """ - return self._attributes_schema.items() - - @overload - def get( - self, key: MedRecordAttribute - ) -> Optional[Tuple[DataType, Optional[AttributeType]]]: ... - - @overload - def get( - self, key: MedRecordAttribute, default: Tuple[DataType, Optional[AttributeType]] - ) -> Tuple[DataType, Optional[AttributeType]]: ... - - def get( - self, - key: MedRecordAttribute, - default: Optional[Tuple[DataType, Optional[AttributeType]]] = None, - ) -> Optional[Tuple[DataType, Optional[AttributeType]]]: - """Gets the data type and optional attribute type for a given attribute. - - It returns a default value if the attribute is not present. - - Args: - key (MedRecordAttribute): The attribute for which the data type is - requested. - default (Optional[Tuple[DataType, Optional[AttributeType]]], optional): - The default data type and attribute type to return if the attribute - is not found. Defaults to None. - - Returns: - Optional[Tuple[DataType, Optional[AttributeType]]]: The data type and - optional attribute type of the given attribute or the default value. - """ - return self._attributes_schema.get(key, default) +AttributesSchema: TypeAlias = Dict[MedRecordAttribute, AttributeDataType] class GroupSchema: @@ -259,43 +188,50 @@ def __init__( self, *, nodes: Optional[ - Dict[MedRecordAttribute, Union[DataType, Tuple[DataType, AttributeType]]] + Dict[ + MedRecordAttribute, + Union[DataType, AttributeDataType], + ], ] = None, edges: Optional[ - Dict[MedRecordAttribute, Union[DataType, Tuple[DataType, AttributeType]]] + Dict[ + MedRecordAttribute, + Union[DataType, AttributeDataType], + ], ] = None, - strict: bool = False, ) -> None: """Initializes a new instance of GroupSchema. Args: - nodes (Dict[MedRecordAttribute, Union[DataType, Tuple[DataType, AttributeType]]]): + nodes (Dict[MedRecordAttribute, Union[DataType, AttributeDataType]]): A dictionary mapping node attributes to their data types and optional attribute types. Defaults to an empty dictionary. - edges (Dict[MedRecordAttribute, Union[DataType, Tuple[DataType, AttributeType]]]): + When no attribute type is provided, it is inferred from the data type. + edges (Dict[MedRecordAttribute, Union[DataType, AttributeDataType]]): A dictionary mapping edge attributes to their data types and optional attribute types. Defaults to an empty dictionary. - strict (bool, optional): Indicates whether the schema should be strict. - Defaults to False. - """ # noqa: W505 + When no attribute type is provided, it is inferred from the data type. + """ if edges is None: edges = {} if nodes is None: nodes = {} def _convert_input( - input: Union[DataType, Tuple[DataType, AttributeType]], + input: Union[DataType, AttributeDataType], ) -> PyAttributeDataType: if isinstance(input, tuple): return PyAttributeDataType( input[0]._inner(), input[1]._into_py_attribute_type() ) - return PyAttributeDataType(input._inner(), None) + + return PyAttributeDataType( + input._inner(), PyAttributeType.infer(input._inner()) + ) self._group_schema = PyGroupSchema( nodes={x: _convert_input(nodes[x]) for x in nodes}, edges={x: _convert_input(edges[x]) for x in edges}, - strict=strict, ) @classmethod @@ -323,20 +259,17 @@ def nodes(self) -> AttributesSchema: def _convert_node( input: PyAttributeDataType, - ) -> Tuple[DataType, Optional[AttributeType]]: + ) -> AttributeDataType: + # SAFETY: The typing is guaranteed to be correct return ( DataType._from_py_data_type(input.data_type), - AttributeType._from_py_attribute_type(input.attribute_type) - if input.attribute_type is not None - else None, - ) + AttributeType._from_py_attribute_type(input.attribute_type), + ) # pyright: ignore[reportReturnType] - return AttributesSchema( - { - x: _convert_node(self._group_schema.nodes[x]) - for x in self._group_schema.nodes - } - ) + return { + x: _convert_node(self._group_schema.nodes[x]) + for x in self._group_schema.nodes + } @property def edges(self) -> AttributesSchema: @@ -349,29 +282,35 @@ def edges(self) -> AttributesSchema: def _convert_edge( input: PyAttributeDataType, - ) -> Tuple[DataType, Optional[AttributeType]]: + ) -> AttributeDataType: + # SAFETY: The typing is guaranteed to be correct return ( DataType._from_py_data_type(input.data_type), - AttributeType._from_py_attribute_type(input.attribute_type) - if input.attribute_type is not None - else None, - ) + AttributeType._from_py_attribute_type(input.attribute_type), + ) # pyright: ignore[reportReturnType] - return AttributesSchema( - { - x: _convert_edge(self._group_schema.edges[x]) - for x in self._group_schema.edges - } - ) + return { + x: _convert_edge(self._group_schema.edges[x]) + for x in self._group_schema.edges + } - @property - def strict(self) -> Optional[bool]: - """Indicates whether the GroupSchema instance is strict. + def validate_node(self, index: NodeIndex, attributes: Attributes) -> None: + """Validates the attributes of a node. - Returns: - Optional[bool]: True if the schema is strict, False otherwise. + Args: + index (NodeIndex): The index of the node. + attributes (Attributes): The attributes of the node. """ - return self._group_schema.strict + self._group_schema.validate_node(index, attributes) + + def validate_edge(self, index: EdgeIndex, attributes: Attributes) -> None: + """Validates the attributes of an edge. + + Args: + index (EdgeIndex): The index of the edge. + attributes (Attributes): The attributes of the edge. + """ + self._group_schema.validate_edge(index, attributes) class Schema: @@ -384,7 +323,6 @@ def __init__( *, groups: Optional[Dict[Group, GroupSchema]] = None, default: Optional[GroupSchema] = None, - strict: bool = False, ) -> None: """Initializes a new instance of Schema. @@ -392,23 +330,32 @@ def __init__( groups (Dict[Group, GroupSchema], optional): A dictionary of group names to their schemas. Defaults to an empty dictionary. default (Optional[GroupSchema], optional): The default group schema. - Defaults to None. - strict (bool, optional): Indicates whether the schema should be strict. - Defaults to False. + If not provided, an empty group schema is used. Defaults to None. """ + if not default: + default = GroupSchema() + if groups is None: groups = {} - if default is not None: - self._schema = PySchema( - groups={x: groups[x]._group_schema for x in groups}, - default=default._group_schema, - strict=strict, - ) - else: - self._schema = PySchema( - groups={x: groups[x]._group_schema for x in groups}, - strict=strict, - ) + + self._schema = PySchema( + groups={x: groups[x]._group_schema for x in groups}, + default=default._group_schema, + ) + + @classmethod + def infer(cls, medrecord: MedRecord) -> Schema: + """Infers a schema from a MedRecord instance. + + Args: + medrecord (MedRecord): The MedRecord instance to infer the schema from. + + Returns: + Schema: The inferred schema. + """ + new_schema = cls() + new_schema._schema = PySchema.infer(medrecord._medrecord) + return new_schema @classmethod def _from_py_schema(cls, schema: PySchema) -> Schema: @@ -448,23 +395,279 @@ def group(self, group: Group) -> GroupSchema: return GroupSchema._from_pygroupschema(self._schema.group(group)) @property - def default(self) -> Optional[GroupSchema]: + def default(self) -> GroupSchema: """Retrieves the default group schema. Returns: Optional[GroupSchema]: The default group schema if it exists, otherwise None. """ - if self._schema.default is None: - return None - return GroupSchema._from_pygroupschema(self._schema.default) - @property - def strict(self) -> Optional[bool]: - """Indicates whether the Schema instance is strict. + def validate_node( + self, index: NodeIndex, attributes: Attributes, group: Optional[Group] = None + ) -> None: + """Validates the attributes of a node. - Returns: - Optional[bool]: True if the schema is strict, False otherwise. + Args: + index (NodeIndex): The index of the node. + attributes (Attributes): The attributes of the node. + group (Optional[Group], optional): The group to validate the node against. + If not provided, the default group is used. Defaults to None. + """ + self._schema.validate_node(index, attributes, group) + + def validate_edge( + self, index: EdgeIndex, attributes: Attributes, group: Optional[Group] = None + ) -> None: + """Validates the attributes of an edge. + + Args: + index (EdgeIndex): The index of the edge. + attributes (Attributes): The attributes of the edge. + group (Optional[Group], optional): The group to validate the edge against. + If not provided, the default group is used. Defaults to None. """ - return self._schema.strict + self._schema.validate_edge(index, attributes, group) + + @overload + def set_node_attribute( + self, + attribute: MedRecordAttribute, + data_type: DataType, + attribute_type: Optional[ + Literal[AttributeType.Categorical, AttributeType.Unstructured] + ], + group: Optional[Group] = None, + ) -> None: ... + + @overload + def set_node_attribute( + self, + attribute: MedRecordAttribute, + data_type: ContinuousType, + attribute_type: Literal[AttributeType.Continuous], + group: Optional[Group] = None, + ) -> None: ... + + @overload + def set_node_attribute( + self, + attribute: MedRecordAttribute, + data_type: TemporalType, + attribute_type: Literal[AttributeType.Temporal], + group: Optional[Group] = None, + ) -> None: ... + + def set_node_attribute( + self, + attribute: MedRecordAttribute, + data_type: DataType, + attribute_type: Optional[AttributeType] = None, + group: Optional[Group] = None, + ) -> None: + """Sets the data type and attribute type of a node attribute. + + If a data type for the attribute already exists, it is overwritten. + + Args: + attribute (MedRecordAttribute): The name of the attribute. + data_type (DataType): The data type of the attribute. + attribute_type (Optional[AttributeType], optional): The attribute type of + the attribute. If not provided, the attribute type is inferred + from the data type. Defaults to None. + group (Optional[Group], optional): The group to set the attribute for. + If no schema for the group exists, a new schema is created. + If not provided, the default group is used. Defaults to None. + """ + if not attribute_type: + attribute_type = AttributeType.infer(data_type) + + self._schema.set_node_attribute( + attribute, + data_type._inner(), + attribute_type._into_py_attribute_type(), + group, + ) + + @overload + def set_edge_attribute( + self, + attribute: MedRecordAttribute, + data_type: DataType, + attribute_type: Optional[ + Literal[AttributeType.Categorical, AttributeType.Unstructured] + ], + group: Optional[Group] = None, + ) -> None: ... + + @overload + def set_edge_attribute( + self, + attribute: MedRecordAttribute, + data_type: ContinuousType, + attribute_type: Literal[AttributeType.Continuous], + group: Optional[Group] = None, + ) -> None: ... + + @overload + def set_edge_attribute( + self, + attribute: MedRecordAttribute, + data_type: TemporalType, + attribute_type: Literal[AttributeType.Temporal], + group: Optional[Group] = None, + ) -> None: ... + + def set_edge_attribute( + self, + attribute: MedRecordAttribute, + data_type: DataType, + attribute_type: Optional[AttributeType] = None, + group: Optional[Group] = None, + ) -> None: + """Sets the data type and attribute type of an edge attribute. + + If a data type for the attribute already exists, it is overwritten. + + Args: + attribute (MedRecordAttribute): The name of the attribute. + data_type (DataType): The data type of the attribute. + attribute_type (Optional[AttributeType], optional): The attribute type of + the attribute. If not provided, the attribute type is inferred + from the data type. Defaults to None. + group (Optional[Group], optional): The group to set the attribute for. + If no schema for this group exists, a new schema is created. + If not provided, the default group is used. Defaults to None. + """ + if not attribute_type: + attribute_type = AttributeType.infer(data_type) + + self._schema.set_edge_attribute( + attribute, + data_type._inner(), + attribute_type._into_py_attribute_type(), + group, + ) + + @overload + def update_node_attribute( + self, + attribute: MedRecordAttribute, + data_type: DataType, + attribute_type: Optional[ + Literal[AttributeType.Categorical, AttributeType.Unstructured] + ], + group: Optional[Group] = None, + ) -> None: ... + + @overload + def update_node_attribute( + self, + attribute: MedRecordAttribute, + data_type: ContinuousType, + attribute_type: Literal[AttributeType.Continuous], + group: Optional[Group] = None, + ) -> None: ... + + @overload + def update_node_attribute( + self, + attribute: MedRecordAttribute, + data_type: TemporalType, + attribute_type: Literal[AttributeType.Temporal], + group: Optional[Group] = None, + ) -> None: ... + + def update_node_attribute( + self, + attribute: MedRecordAttribute, + data_type: DataType, + attribute_type: Optional[AttributeType] = None, + group: Optional[Group] = None, + ) -> None: + """Updates the data type and attribute type of a node attribute. + + If a data type for the attribute already exists, it is merged + with the new data type. + + Args: + attribute (MedRecordAttribute): The name of the attribute. + data_type (DataType): The data type of the attribute. + attribute_type (Optional[AttributeType], optional): The attribute type of + the attribute. If not provided, the attribute type is inferred + from the data type. Defaults to None. + group (Optional[Group], optional): The group to update the attribute for. + If no schema for this group exists, a new schema is created. + If not provided, the default group is used. Defaults to None. + """ + if not attribute_type: + attribute_type = AttributeType.infer(data_type) + + self._schema.update_node_attribute( + attribute, + data_type._inner(), + attribute_type._into_py_attribute_type(), + group, + ) + + @overload + def update_edge_attribute( + self, + attribute: MedRecordAttribute, + data_type: DataType, + attribute_type: Optional[ + Literal[AttributeType.Categorical, AttributeType.Unstructured] + ], + group: Optional[Group] = None, + ) -> None: ... + + @overload + def update_edge_attribute( + self, + attribute: MedRecordAttribute, + data_type: ContinuousType, + attribute_type: Literal[AttributeType.Continuous], + group: Optional[Group] = None, + ) -> None: ... + + @overload + def update_edge_attribute( + self, + attribute: MedRecordAttribute, + data_type: TemporalType, + attribute_type: Literal[AttributeType.Temporal], + group: Optional[Group] = None, + ) -> None: ... + + def update_edge_attribute( + self, + attribute: MedRecordAttribute, + data_type: DataType, + attribute_type: Optional[AttributeType] = None, + group: Optional[Group] = None, + ) -> None: + """Updates the data type and attribute type of an edge attribute. + + If a data type for the attribute already exists, it is merged + with the new data type. + + Args: + attribute (MedRecordAttribute): The name of the attribute. + data_type (DataType): The data type of the attribute. + attribute_type (Optional[AttributeType], optional): The attribute type of + the attribute. If not provided, the attribute type is inferred + from the data type. Defaults to None. + group (Optional[Group], optional): The group to update the attribute for. + If no schema for this group exists, a new schema is created. + If not provided, the default group is used. Defaults to None. + """ + if not attribute_type: + attribute_type = AttributeType.infer(data_type) + + self._schema.update_edge_attribute( + attribute, + data_type._inner(), + attribute_type._into_py_attribute_type(), + group, + ) diff --git a/medmodels/medrecord/tests/test_builder.py b/medmodels/medrecord/tests/test_builder.py index d4570fde..3b77aa86 100644 --- a/medmodels/medrecord/tests/test_builder.py +++ b/medmodels/medrecord/tests/test_builder.py @@ -64,7 +64,7 @@ def test_with_schema(self) -> None: with pytest.raises( ValueError, - match=r"Attribute attribute of node with index node2 is of type String\. Expected Integer\.", + match=r"Attribute attribute of node with index node2 is of type String\. Expected Int\.", ): medrecord.add_nodes(("node2", {"attribute": "1"})) diff --git a/medmodels/medrecord/tests/test_datatype.py b/medmodels/medrecord/tests/test_datatype.py index 05a6bb1d..6a3c1735 100644 --- a/medmodels/medrecord/tests/test_datatype.py +++ b/medmodels/medrecord/tests/test_datatype.py @@ -1,7 +1,5 @@ import unittest -import pytest - import medmodels.medrecord as mr from medmodels._medmodels import ( PyAny, @@ -102,7 +100,7 @@ def test_union(self) -> None: assert union.__repr__() == "DataType.Union(DataType.String, DataType.Int)" - union = mr.Union(mr.String(), mr.Int(), mr.Bool()) + union = mr.Union(mr.String(), mr.Union(mr.Int(), mr.Bool())) assert isinstance(union._inner(), PyUnion) assert str(union) == "Union(String, Union(Int, Bool))" @@ -115,10 +113,6 @@ def test_union(self) -> None: assert mr.Union(mr.String(), mr.Int()) == mr.Union(mr.String(), mr.Int()) assert mr.Union(mr.String(), mr.Int()) != mr.Union(mr.Int(), mr.String()) - def test_invalid_union(self) -> None: - with pytest.raises(ValueError, match="Union must have at least two arguments"): - mr.Union(mr.String()) - def test_option(self) -> None: option = mr.Option(mr.String()) assert isinstance(option._inner(), PyOption) diff --git a/medmodels/medrecord/tests/test_medrecord.py b/medmodels/medrecord/tests/test_medrecord.py index 1d604d93..73e6d071 100644 --- a/medmodels/medrecord/tests/test_medrecord.py +++ b/medmodels/medrecord/tests/test_medrecord.py @@ -6,7 +6,6 @@ import polars as pl import pytest -import medmodels.medrecord as mr from medmodels import MedRecord from medmodels.medrecord.medrecord import EdgesDirected from medmodels.medrecord.querying import EdgeOperand, NodeOperand @@ -268,67 +267,6 @@ def test_ron(self) -> None: assert medrecord.node_count() == loaded_medrecord.node_count() assert medrecord.edge_count() == loaded_medrecord.edge_count() - def test_schema(self) -> None: - schema = mr.Schema( - groups={ - "group": mr.GroupSchema( - nodes={"attribute2": mr.Int()}, edges={"attribute2": mr.Int()} - ) - }, - default=mr.GroupSchema( - nodes={"attribute": mr.Int()}, edges={"attribute": mr.Int()} - ), - ) - - medrecord = MedRecord.with_schema(schema) - medrecord.add_group("group") - - medrecord.add_nodes(("0", {"attribute": 1})) - - with pytest.raises( - ValueError, - match=r"Attribute [^\s]+ of node with index [^\s]+ is of type [^\s]+. Expected [^\s]+.", - ): - medrecord.add_nodes(("1", {"attribute": "1"})) - - medrecord.add_nodes(("1", {"attribute": 1, "attribute2": 1})) - - medrecord.add_nodes_to_group("group", "1") - - medrecord.add_nodes(("2", {"attribute": 1, "attribute2": "1"})) - - with pytest.raises( - ValueError, - match=r"Attribute [^\s]+ of node with index [^\s]+ is of type [^\s]+. Expected [^\s]+.", - ): - medrecord.add_nodes_to_group("group", "2") - - medrecord.add_edges(("0", "1", {"attribute": 1})) - - with pytest.raises( - ValueError, - match=r"Attribute [^\s]+ of edge with index [^\s]+ is of type [^\s]+. Expected [^\s]+.", - ): - medrecord.add_edges(("0", "1", {"attribute": "1"})) - - edge_index = medrecord.add_edges(("0", "1", {"attribute": 1, "attribute2": 1})) - - medrecord.add_edges_to_group("group", edge_index) - - edge_index = medrecord.add_edges( - ( - "0", - "1", - {"attribute": 1, "attribute2": "1"}, - ) - ) - - with pytest.raises( - ValueError, - match=r"Attribute [^\s]+ of edge with index [^\s]+ is of type [^\s]+. Expected [^\s]+.", - ): - medrecord.add_edges_to_group("group", edge_index) - def test_nodes(self) -> None: medrecord = create_medrecord() diff --git a/medmodels/medrecord/tests/test_schema.py b/medmodels/medrecord/tests/test_schema.py index 2d2e2eb2..ebf6bf92 100644 --- a/medmodels/medrecord/tests/test_schema.py +++ b/medmodels/medrecord/tests/test_schema.py @@ -1,230 +1,106 @@ import unittest -import pytest - import medmodels.medrecord as mr from medmodels._medmodels import PyAttributeType -from medmodels.medrecord.schema import GroupSchema, Schema def create_medrecord() -> mr.MedRecord: return mr.MedRecord.from_simple_example_dataset() -class TestSchema(unittest.TestCase): - def setUp(self) -> None: - self.schema = create_medrecord().schema - - def test_groups(self) -> None: - assert sorted( - [ - "diagnosis", - "drug", - "patient_diagnosis", - "patient_drug", - "patient_procedure", - "patient", - "procedure", - ] - ) == sorted(self.schema.groups) - - def test_group(self) -> None: - assert isinstance(self.schema.group("patient"), mr.GroupSchema) # pyright: ignore[reportUnnecessaryIsInstance] - - with pytest.raises(ValueError, match="No schema found for group: nonexistent"): - self.schema.group("nonexistent") - - def test_default(self) -> None: - assert None is self.schema.default - schema = Schema(default=GroupSchema(nodes={"description": mr.String()})) - - assert isinstance(schema.default, mr.GroupSchema) - - def test_strict(self) -> None: - assert True is self.schema.strict - - -class TestGroupSchema(unittest.TestCase): - def setUp(self) -> None: - self.schema = create_medrecord().schema - - def test_nodes(self) -> None: - assert self.schema.group("patient").nodes == { - "age": (mr.Int(), mr.AttributeType.Continuous), - "gender": (mr.String(), mr.AttributeType.Categorical), - } - - def test_edges(self) -> None: - assert self.schema.group("patient_diagnosis").edges == { - "time": (mr.DateTime(), mr.AttributeType.Temporal), - "duration_days": (mr.Option(mr.Float()), mr.AttributeType.Continuous), - } - - def test_strict(self) -> None: - assert True is self.schema.group("patient").strict - - -class TestAttributesSchema(unittest.TestCase): - def setUp(self) -> None: - self.attributes_schema = ( - Schema( - groups={"diagnosis": GroupSchema(nodes={"description": mr.String()})}, - strict=False, - ) - .group("diagnosis") - .nodes - ) - - def test_repr(self) -> None: +class TestAttributeType(unittest.TestCase): + def test_from_py_attribute_type(self) -> None: assert ( - repr(self.attributes_schema) == "{'description': (DataType.String, None)}" + mr.AttributeType._from_py_attribute_type(PyAttributeType.Categorical) + == mr.AttributeType.Categorical ) - - second_attributes_schema = ( - Schema( - groups={ - "diagnosis": GroupSchema( - nodes={ - "description": (mr.String(), mr.AttributeType.Categorical) - } - ) - }, - strict=False, - ) - .group("diagnosis") - .nodes - ) - assert ( - repr(second_attributes_schema) - == "{'description': (DataType.String, AttributeType.Categorical)}" + mr.AttributeType._from_py_attribute_type(PyAttributeType.Continuous) + == mr.AttributeType.Continuous ) - - def test_getitem(self) -> None: - assert (mr.String(), None) == self.attributes_schema["description"] - - with pytest.raises(KeyError): - self.attributes_schema["nonexistent"] - - def test_contains(self) -> None: - assert "description" in self.attributes_schema - assert "nonexistent" not in self.attributes_schema - - def test_len(self) -> None: - assert len(self.attributes_schema) == 1 - - def test_eq(self) -> None: - comparison_attributes_schema = ( - Schema( - groups={"diagnosis": GroupSchema(nodes={"description": mr.String()})}, - strict=False, - ) - .group("diagnosis") - .nodes + assert ( + mr.AttributeType._from_py_attribute_type(PyAttributeType.Temporal) + == mr.AttributeType.Temporal ) - - assert self.attributes_schema == comparison_attributes_schema - - comparison_attributes_schema = ( - Schema( - groups={"diagnosis": GroupSchema(nodes={"description": mr.Int()})}, - strict=False, - ) - .group("diagnosis") - .nodes + assert ( + mr.AttributeType._from_py_attribute_type(PyAttributeType.Unstructured) + == mr.AttributeType.Unstructured ) - assert self.attributes_schema != comparison_attributes_schema - - comparison_attributes_schema = ( - Schema( - groups={ - "diagnosis": GroupSchema( - nodes={ - "description": (mr.String(), mr.AttributeType.Categorical) - } - ) - }, - strict=False, - ) - .group("diagnosis") - .nodes + def test_into_py_attribute_type(self) -> None: + assert ( + mr.AttributeType.Categorical._into_py_attribute_type() + == PyAttributeType.Categorical ) - - assert self.attributes_schema != comparison_attributes_schema - - comparison_attributes_schema = ( - Schema( - groups={ - "diagnosis": GroupSchema( - nodes={ - "description2": (mr.String(), mr.AttributeType.Categorical) - } - ) - }, - strict=False, - ) - .group("diagnosis") - .nodes + assert ( + mr.AttributeType.Continuous._into_py_attribute_type() + == PyAttributeType.Continuous ) - - assert self.attributes_schema != comparison_attributes_schema - - assert self.attributes_schema is not None - - def test_keys(self) -> None: - assert list(self.attributes_schema.keys()) == ["description"] - - def test_values(self) -> None: - assert [(mr.String(), None)] == list(self.attributes_schema.values()) - - def test_items(self) -> None: - assert [("description", (mr.String(), None))] == list( - self.attributes_schema.items() + assert ( + mr.AttributeType.Temporal._into_py_attribute_type() + == PyAttributeType.Temporal ) - - def test_get(self) -> None: - assert (mr.String(), None) == self.attributes_schema.get("description") - - assert None is self.attributes_schema.get("nonexistent") - - assert (mr.String(), None) == self.attributes_schema.get( - "nonexistent", (mr.String(), None) + assert ( + mr.AttributeType.Unstructured._into_py_attribute_type() + == PyAttributeType.Unstructured ) + def test_repr(self) -> None: + assert repr(mr.AttributeType.Categorical) == "AttributeType.Categorical" + assert repr(mr.AttributeType.Continuous) == "AttributeType.Continuous" + assert repr(mr.AttributeType.Temporal) == "AttributeType.Temporal" + assert repr(mr.AttributeType.Unstructured) == "AttributeType.Unstructured" -class TestAttributeType(unittest.TestCase): def test_str(self) -> None: assert str(mr.AttributeType.Categorical) == "Categorical" assert str(mr.AttributeType.Continuous) == "Continuous" assert str(mr.AttributeType.Temporal) == "Temporal" + assert str(mr.AttributeType.Unstructured) == "Unstructured" + + def test_hash(self) -> None: + assert hash(mr.AttributeType.Categorical) == hash("Categorical") + assert hash(mr.AttributeType.Continuous) == hash("Continuous") + assert hash(mr.AttributeType.Temporal) == hash("Temporal") + assert hash(mr.AttributeType.Unstructured) == hash("Unstructured") def test_eq(self) -> None: assert mr.AttributeType.Categorical == mr.AttributeType.Categorical assert mr.AttributeType.Categorical == PyAttributeType.Categorical assert mr.AttributeType.Categorical != mr.AttributeType.Continuous + assert mr.AttributeType.Categorical != mr.AttributeType.Temporal + assert mr.AttributeType.Categorical != mr.AttributeType.Unstructured assert mr.AttributeType.Categorical != PyAttributeType.Continuous + assert mr.AttributeType.Categorical != PyAttributeType.Temporal + assert mr.AttributeType.Categorical != PyAttributeType.Unstructured assert mr.AttributeType.Continuous == mr.AttributeType.Continuous assert mr.AttributeType.Continuous == PyAttributeType.Continuous assert mr.AttributeType.Continuous != mr.AttributeType.Categorical + assert mr.AttributeType.Continuous != mr.AttributeType.Temporal + assert mr.AttributeType.Continuous != mr.AttributeType.Unstructured assert mr.AttributeType.Continuous != PyAttributeType.Categorical + assert mr.AttributeType.Continuous != PyAttributeType.Temporal + assert mr.AttributeType.Continuous != PyAttributeType.Unstructured assert mr.AttributeType.Temporal == mr.AttributeType.Temporal assert mr.AttributeType.Temporal == PyAttributeType.Temporal assert mr.AttributeType.Temporal != mr.AttributeType.Categorical + assert mr.AttributeType.Temporal != mr.AttributeType.Continuous + assert mr.AttributeType.Temporal != mr.AttributeType.Unstructured assert mr.AttributeType.Temporal != PyAttributeType.Categorical + assert mr.AttributeType.Temporal != PyAttributeType.Continuous + assert mr.AttributeType.Temporal != PyAttributeType.Unstructured + assert mr.AttributeType.Unstructured == mr.AttributeType.Unstructured + assert mr.AttributeType.Unstructured == PyAttributeType.Unstructured + assert mr.AttributeType.Unstructured != mr.AttributeType.Categorical + assert mr.AttributeType.Unstructured != mr.AttributeType.Continuous + assert mr.AttributeType.Unstructured != mr.AttributeType.Temporal + assert mr.AttributeType.Unstructured != PyAttributeType.Categorical + assert mr.AttributeType.Unstructured != PyAttributeType.Continuous + assert mr.AttributeType.Unstructured != PyAttributeType.Temporal -if __name__ == "__main__": - run_test = unittest.TestLoader().loadTestsFromTestCase(TestSchema) - unittest.TextTestRunner(verbosity=2).run(run_test) - - run_test = unittest.TestLoader().loadTestsFromTestCase(TestGroupSchema) - unittest.TextTestRunner(verbosity=2).run(run_test) - - run_test = unittest.TestLoader().loadTestsFromTestCase(TestAttributesSchema) - unittest.TextTestRunner(verbosity=2).run(run_test) +if __name__ == "__main__": run_test = unittest.TestLoader().loadTestsFromTestCase(TestAttributeType) unittest.TextTestRunner(verbosity=2).run(run_test) diff --git a/rustmodels/src/medrecord/errors.rs b/rustmodels/src/medrecord/errors.rs index f96f9a32..dac90703 100644 --- a/rustmodels/src/medrecord/errors.rs +++ b/rustmodels/src/medrecord/errors.rs @@ -1,4 +1,4 @@ -use medmodels_core::errors::MedRecordError; +use medmodels_core::errors::{GraphError, MedRecordError}; use pyo3::{ exceptions::{PyAssertionError, PyIndexError, PyKeyError, PyRuntimeError, PyValueError}, PyErr, @@ -13,6 +13,12 @@ impl From for PyMedRecordError { } } +impl From for PyMedRecordError { + fn from(error: GraphError) -> Self { + Self(MedRecordError::from(error)) + } +} + impl From for PyErr { fn from(error: PyMedRecordError) -> Self { match error.0 { diff --git a/rustmodels/src/medrecord/mod.rs b/rustmodels/src/medrecord/mod.rs index 30dd570c..16bcd944 100644 --- a/rustmodels/src/medrecord/mod.rs +++ b/rustmodels/src/medrecord/mod.rs @@ -30,8 +30,21 @@ type Lut = GILHashMap) -> PyResult>; #[pyclass] #[repr(transparent)] +#[derive(Debug, Clone)] pub struct PyMedRecord(MedRecord); +impl From for PyMedRecord { + fn from(value: MedRecord) -> Self { + Self(value) + } +} + +impl From for MedRecord { + fn from(value: PyMedRecord) -> Self { + value.0 + } +} + #[pymethods] impl PyMedRecord { #[new] @@ -105,7 +118,15 @@ impl PyMedRecord { #[getter] pub fn schema(&self) -> PySchema { - self.0.get_schema().clone().into() + self.0.schema().clone().into() + } + + pub fn freeze_schema(&mut self) { + self.0.freeze_schema() + } + + pub fn unfreeze_schema(&mut self) { + self.0.unfreeze_schema() } #[getter] diff --git a/rustmodels/src/medrecord/schema.rs b/rustmodels/src/medrecord/schema.rs index fa47e5a4..4813ab7b 100644 --- a/rustmodels/src/medrecord/schema.rs +++ b/rustmodels/src/medrecord/schema.rs @@ -1,8 +1,3 @@ -use medmodels_core::{ - errors::MedRecordError, - medrecord::{AttributeDataType, AttributeType, GroupSchema, Schema}, -}; -use pyo3::prelude::*; use std::collections::HashMap; use super::{ @@ -10,8 +5,15 @@ use super::{ datatype::PyDataType, errors::PyMedRecordError, traits::{DeepFrom, DeepInto}, - PyGroup, + PyAttributes, PyGroup, PyMedRecord, PyNodeIndex, }; +use medmodels_core::{ + errors::GraphError, + medrecord::{ + AttributeDataType, AttributeType, EdgeIndex, Group, GroupSchema, Schema, SchemaType, + }, +}; +use pyo3::prelude::*; #[pyclass(eq, eq_int)] #[derive(Debug, Clone, PartialEq)] @@ -19,6 +21,7 @@ pub enum PyAttributeType { Categorical = 0, Continuous = 1, Temporal = 2, + Unstructured = 3, } impl From for PyAttributeType { @@ -27,6 +30,7 @@ impl From for PyAttributeType { AttributeType::Categorical => Self::Categorical, AttributeType::Continuous => Self::Continuous, AttributeType::Temporal => Self::Temporal, + AttributeType::Unstructured => Self::Unstructured, } } } @@ -37,43 +41,45 @@ impl From for AttributeType { PyAttributeType::Categorical => Self::Categorical, PyAttributeType::Continuous => Self::Continuous, PyAttributeType::Temporal => Self::Temporal, + PyAttributeType::Unstructured => Self::Unstructured, } } } +#[pymethods] +impl PyAttributeType { + #[staticmethod] + pub fn infer(data_type: PyDataType) -> Self { + AttributeType::infer(&data_type.into()).into() + } +} + #[pyclass] #[derive(Debug, Clone)] pub struct PyAttributeDataType { data_type: PyDataType, - attribute_type: Option, -} - -impl From for AttributeDataType { - fn from(value: PyAttributeDataType) -> Self { - Self { - data_type: value.data_type.into(), - attribute_type: value.attribute_type.map(|t| t.into()), - } - } + attribute_type: PyAttributeType, } impl From for PyAttributeDataType { fn from(value: AttributeDataType) -> Self { Self { - data_type: value.data_type.into(), - attribute_type: value.attribute_type.map(|t| t.into()), + data_type: value.data_type().clone().into(), + attribute_type: (*value.attribute_type()).into(), } } } -impl DeepFrom for AttributeDataType { - fn deep_from(value: PyAttributeDataType) -> AttributeDataType { - value.into() +impl TryFrom for AttributeDataType { + type Error = GraphError; + + fn try_from(value: PyAttributeDataType) -> Result { + Self::new(value.data_type.into(), value.attribute_type.into()) } } impl DeepFrom for PyAttributeDataType { - fn deep_from(value: AttributeDataType) -> PyAttributeDataType { + fn deep_from(value: AttributeDataType) -> Self { value.into() } } @@ -81,8 +87,8 @@ impl DeepFrom for PyAttributeDataType { #[pymethods] impl PyAttributeDataType { #[new] - #[pyo3(signature = (data_type, attribute_type=None))] - pub fn new(data_type: PyDataType, attribute_type: Option) -> Self { + #[pyo3(signature = (data_type, attribute_type))] + pub fn new(data_type: PyDataType, attribute_type: PyAttributeType) -> Self { Self { data_type, attribute_type, @@ -95,7 +101,7 @@ impl PyAttributeDataType { } #[getter] - pub fn attribute_type(&self) -> Option { + pub fn attribute_type(&self) -> PyAttributeType { self.attribute_type.clone() } } @@ -117,14 +123,14 @@ impl From for GroupSchema { } } -impl DeepFrom for GroupSchema { - fn deep_from(value: PyGroupSchema) -> GroupSchema { +impl DeepFrom for PyGroupSchema { + fn deep_from(value: GroupSchema) -> Self { value.into() } } -impl DeepFrom for PyGroupSchema { - fn deep_from(value: GroupSchema) -> PyGroupSchema { +impl DeepFrom for GroupSchema { + fn deep_from(value: PyGroupSchema) -> Self { value.into() } } @@ -132,32 +138,73 @@ impl DeepFrom for PyGroupSchema { #[pymethods] impl PyGroupSchema { #[new] - #[pyo3(signature = (nodes, edges, strict=None))] - fn new( + pub fn new( nodes: HashMap, edges: HashMap, - strict: Option, - ) -> Self { - PyGroupSchema(GroupSchema { - nodes: nodes.deep_into(), - edges: edges.deep_into(), - strict, - }) + ) -> PyResult { + let nodes = nodes + .into_iter() + .map(|(k, v)| Ok((k.into(), v.try_into()?))) + .collect::, GraphError>>() + .map_err(PyMedRecordError::from)? + .into(); + let edges = edges + .into_iter() + .map(|(k, v)| Ok((k.into(), v.try_into()?))) + .collect::, GraphError>>() + .map_err(PyMedRecordError::from)? + .into(); + + Ok(Self(GroupSchema::new(nodes, edges))) } #[getter] - fn nodes(&self) -> HashMap { - self.0.nodes.clone().deep_into() + pub fn nodes(&self) -> HashMap { + self.0.nodes().clone().deep_into() } #[getter] - fn edges(&self) -> HashMap { - self.0.edges.clone().deep_into() + pub fn edges(&self) -> HashMap { + self.0.edges().clone().deep_into() } - #[getter] - fn strict(&self) -> Option { - self.0.strict + pub fn validate_node(&self, index: PyNodeIndex, attributes: PyAttributes) -> PyResult<()> { + Ok(self + .0 + .validate_node(&index.into(), &attributes.deep_into()) + .map_err(PyMedRecordError::from)?) + } + + pub fn validate_edge(&self, index: EdgeIndex, attributes: PyAttributes) -> PyResult<()> { + Ok(self + .0 + .validate_edge(&index, &attributes.deep_into()) + .map_err(PyMedRecordError::from)?) + } +} + +#[pyclass(eq, eq_int)] +#[derive(Debug, Clone, PartialEq)] +pub enum PySchemaType { + Provided = 0, + Inferred = 1, +} + +impl From for PySchemaType { + fn from(value: SchemaType) -> Self { + match value { + SchemaType::Provided => Self::Provided, + SchemaType::Inferred => Self::Inferred, + } + } +} + +impl From for SchemaType { + fn from(value: PySchemaType) -> Self { + match value { + PySchemaType::Provided => Self::Provided, + PySchemaType::Inferred => Self::Inferred, + } } } @@ -181,47 +228,190 @@ impl From for Schema { #[pymethods] impl PySchema { #[new] - #[pyo3(signature = (groups, default=None, strict=None))] - fn new( + #[pyo3(signature = (groups, default, schema_type=PySchemaType::Provided))] + pub fn new( groups: HashMap, - default: Option, - strict: Option, + default: PyGroupSchema, + schema_type: PySchemaType, ) -> Self { - PySchema(Schema { - groups: groups - .into_iter() - .map(|(k, v)| (k.into(), v.into())) - .collect(), - default: default.deep_into(), - strict, + Self(match schema_type { + PySchemaType::Provided => Schema::new_provided(groups.deep_into(), default.deep_into()), + PySchemaType::Inferred => Schema::new_inferred(groups.deep_into(), default.deep_into()), }) } - #[getter] - fn groups(&self) -> Vec { - self.0.groups.keys().map(|g| g.clone().into()).collect() + #[staticmethod] + pub fn infer(medrecord: PyMedRecord) -> Self { + Self(Schema::infer(&medrecord.into())) } - fn group(&self, group: PyGroup) -> PyResult { - let group = group.into(); + #[getter] + pub fn groups(&self) -> Vec { + self.0 + .groups() + .keys() + .cloned() + .collect::>() + .deep_into() + } + pub fn group(&self, group: PyGroup) -> PyResult { Ok(self .0 - .groups - .get(&group) + .group(&group.into()) .map(|g| g.clone().into()) - .ok_or(PyMedRecordError::from(MedRecordError::SchemaError( - format!("No schema found for group: {}", group), - )))?) + .map_err(PyMedRecordError::from)?) } #[getter] - fn default(&self) -> Option { - self.0.default.clone().map(|g| g.into()) + pub fn default(&self) -> PyGroupSchema { + self.0.default().clone().into() } #[getter] - fn strict(&self) -> Option { - self.0.strict + pub fn schema_type(&self) -> PySchemaType { + self.0.schema_type().clone().into() + } + + #[pyo3(signature = (index, attributes, group=None))] + pub fn validate_node( + &self, + index: PyNodeIndex, + attributes: PyAttributes, + group: Option, + ) -> PyResult<()> { + Ok(self + .0 + .validate_node( + &index.into(), + &attributes.deep_into(), + group.map(|g| g.into()).as_ref(), + ) + .map_err(PyMedRecordError::from)?) + } + + #[pyo3(signature = (index, attributes, group=None))] + pub fn validate_edge( + &self, + index: EdgeIndex, + attributes: PyAttributes, + group: Option, + ) -> PyResult<()> { + Ok(self + .0 + .validate_edge( + &index, + &attributes.deep_into(), + group.map(|g| g.into()).as_ref(), + ) + .map_err(PyMedRecordError::from)?) + } + + #[pyo3(signature = (attribute, data_type, attribute_type, group=None))] + pub fn set_node_attribute( + &mut self, + attribute: PyMedRecordAttribute, + data_type: PyDataType, + attribute_type: PyAttributeType, + group: Option, + ) -> PyResult<()> { + Ok(self + .0 + .set_node_attribute( + &attribute.into(), + data_type.into(), + attribute_type.into(), + group.map(|g| g.into()).as_ref(), + ) + .map_err(PyMedRecordError::from)?) + } + + #[pyo3(signature = (attribute, data_type, attribute_type, group=None))] + pub fn set_edge_attribute( + &mut self, + attribute: PyMedRecordAttribute, + data_type: PyDataType, + attribute_type: PyAttributeType, + group: Option, + ) -> PyResult<()> { + Ok(self + .0 + .set_edge_attribute( + &attribute.into(), + data_type.into(), + attribute_type.into(), + group.map(|g| g.into()).as_ref(), + ) + .map_err(PyMedRecordError::from)?) + } + + #[pyo3(signature = (attribute, data_type, attribute_type, group=None))] + pub fn update_node_attribute( + &mut self, + attribute: PyMedRecordAttribute, + data_type: PyDataType, + attribute_type: PyAttributeType, + group: Option, + ) -> PyResult<()> { + Ok(self + .0 + .update_node_attribute( + &attribute.into(), + data_type.into(), + attribute_type.into(), + group.map(|g| g.into()).as_ref(), + ) + .map_err(PyMedRecordError::from)?) + } + + #[pyo3(signature = (attribute, data_type, attribute_type, group=None))] + pub fn update_edge_attribute( + &mut self, + attribute: PyMedRecordAttribute, + data_type: PyDataType, + attribute_type: PyAttributeType, + group: Option, + ) -> PyResult<()> { + Ok(self + .0 + .update_edge_attribute( + &attribute.into(), + data_type.into(), + attribute_type.into(), + group.map(|g| g.into()).as_ref(), + ) + .map_err(PyMedRecordError::from)?) + } + + #[pyo3(signature = (attribute, group=None))] + pub fn remove_node_attribute( + &mut self, + attribute: PyMedRecordAttribute, + group: Option, + ) { + self.0 + .remove_node_attribute(&attribute.into(), group.map(|g| g.into()).as_ref()); + } + + #[pyo3(signature = (attribute, group=None))] + pub fn remove_edge_attribute( + &mut self, + attribute: PyMedRecordAttribute, + group: Option, + ) { + self.0 + .remove_edge_attribute(&attribute.into(), group.map(|g| g.into()).as_ref()); + } + + pub fn remove_group(&mut self, group: PyGroup) { + self.0.remove_group(&group.into()); + } + + pub fn freeze(&mut self) { + self.0.freeze(); + } + + pub fn unfreeze(&mut self) { + self.0.unfreeze(); } }