Skip to content

Commit

Permalink
refactor: improve edge case handling
Browse files Browse the repository at this point in the history
  • Loading branch information
JabobKrauskopf committed Mar 5, 2025
1 parent 443679f commit 5089833
Show file tree
Hide file tree
Showing 3 changed files with 196 additions and 93 deletions.
8 changes: 4 additions & 4 deletions crates/medmodels-core/src/medrecord/group_mapping.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,10 +7,10 @@ pub type Group = MedRecordAttribute;

#[derive(Debug, Serialize, Deserialize, Clone)]
pub(super) struct GroupMapping {
nodes_in_group: MrHashMap<Group, MrHashSet<NodeIndex>>,
edges_in_group: MrHashMap<Group, MrHashSet<EdgeIndex>>,
groups_of_node: MrHashMap<NodeIndex, MrHashSet<Group>>,
groups_of_edge: MrHashMap<EdgeIndex, MrHashSet<Group>>,
pub(super) nodes_in_group: MrHashMap<Group, MrHashSet<NodeIndex>>,
pub(super) edges_in_group: MrHashMap<Group, MrHashSet<EdgeIndex>>,
pub(super) groups_of_node: MrHashMap<NodeIndex, MrHashSet<Group>>,
pub(super) groups_of_edge: MrHashMap<EdgeIndex, MrHashSet<Group>>,
}

impl GroupMapping {
Expand Down
111 changes: 98 additions & 13 deletions crates/medmodels-core/src/medrecord/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,11 @@ use group_mapping::GroupMapping;
use polars::{dataframe_to_edges, dataframe_to_nodes};
use querying::{edges::EdgeSelection, nodes::NodeSelection};
use serde::{Deserialize, Serialize};
use std::{fs, mem, path::Path};
use std::{
collections::{hash_map::Entry, HashMap},
fs, mem,
path::Path,
};

pub struct NodeDataFrameInput {
dataframe: DataFrame,
Expand Down Expand Up @@ -242,6 +246,11 @@ impl MedRecord {
}

pub fn update_schema(&mut self, mut schema: Schema) -> Result<(), MedRecordError> {
let mut nodes_group_cache = HashMap::<&Group, usize>::new();
let mut nodes_default_visited = false;
let mut edges_group_cache = HashMap::<&Group, usize>::new();
let mut edges_default_visited = false;

for (node_index, node) in self.graph.nodes.iter() {
let groups_of_node = self
.groups_of_node(node_index)
Expand All @@ -251,9 +260,24 @@ impl MedRecord {
if !groups_of_node.is_empty() {
for group in groups_of_node {
match schema.schema_type() {
SchemaType::Inferred => {
schema.update_node(&node.attributes, Some(group));
}
SchemaType::Inferred => match nodes_group_cache.entry(group) {
Entry::Occupied(entry) => {
schema.update_node(
&node.attributes,
Some(group),
*entry.get() == 0,
);
}
Entry::Vacant(entry) => {
entry.insert(
self.group_mapping
.nodes_in_group
.get(group)
.map(|nodes| nodes.len())
.unwrap_or(0),
);
}
},
SchemaType::Provided => {
schema.validate_node(node_index, &node.attributes, Some(group))?
}
Expand All @@ -262,7 +286,17 @@ impl MedRecord {
} else {
match schema.schema_type() {
SchemaType::Inferred => {
schema.update_node(&node.attributes, None);
let nodes_in_groups = self.group_mapping.nodes_in_group.len();

let nodes_not_in_groups = self.graph.node_count() - nodes_in_groups;

schema.update_node(
&node.attributes,
None,
nodes_not_in_groups == 0 || !nodes_default_visited,
);

nodes_default_visited = true;
}
SchemaType::Provided => {
schema.validate_node(node_index, &node.attributes, None)?;
Expand All @@ -280,9 +314,24 @@ impl MedRecord {
if !groups_of_edge.is_empty() {
for group in groups_of_edge {
match schema.schema_type() {
SchemaType::Inferred => {
schema.update_edge(&edge.attributes, Some(group));
}
SchemaType::Inferred => match edges_group_cache.entry(group) {
Entry::Occupied(entry) => {
schema.update_edge(
&edge.attributes,
Some(group),
*entry.get() == 0,
);
}
Entry::Vacant(entry) => {
entry.insert(
self.group_mapping
.edges_in_group
.get(group)
.map(|edges| edges.len())
.unwrap_or(0),
);
}
},
SchemaType::Provided => {
schema.validate_edge(edge_index, &edge.attributes, Some(group))?;
}
Expand All @@ -291,7 +340,17 @@ impl MedRecord {
} else {
match schema.schema_type() {
SchemaType::Inferred => {
schema.update_edge(&edge.attributes, None);
let edges_in_groups = self.group_mapping.edges_in_group.len();

let edges_not_in_groups = self.graph.edge_count() - edges_in_groups;

schema.update_edge(
&edge.attributes,
None,
edges_not_in_groups == 0 || !edges_default_visited,
);

edges_default_visited = true;
}
SchemaType::Provided => {
schema.validate_edge(edge_index, &edge.attributes, None)?;
Expand Down Expand Up @@ -414,7 +473,12 @@ impl MedRecord {
) -> Result<(), MedRecordError> {
match self.schema.schema_type() {
SchemaType::Inferred => {
self.schema.update_node(&attributes, None);
let nodes_in_groups = self.group_mapping.nodes_in_group.len();

let nodes_not_in_groups = self.graph.node_count() - nodes_in_groups;

self.schema
.update_node(&attributes, None, nodes_not_in_groups == 0);
}
SchemaType::Provided => {
self.schema.validate_node(&node_index, &attributes, None)?;
Expand Down Expand Up @@ -474,7 +538,12 @@ impl MedRecord {

match self.schema.schema_type() {
SchemaType::Inferred => {
self.schema.update_edge(&attributes, None);
let edges_in_groups = self.group_mapping.edges_in_group.len();

let edges_not_in_groups = self.graph.edge_count() - edges_in_groups;

self.schema
.update_edge(&attributes, None, edges_not_in_groups == 0);

Ok(edge_index)
}
Expand Down Expand Up @@ -616,7 +685,15 @@ impl MedRecord {

match self.schema.schema_type() {
SchemaType::Inferred => {
self.schema.update_node(node_attributes, Some(&group));
let nodes_in_group = self
.group_mapping
.nodes_in_group
.get(&group)
.map(|nodes| nodes.len())
.unwrap_or(0);

self.schema
.update_node(node_attributes, Some(&group), nodes_in_group == 0);
}
SchemaType::Provided => {
self.schema
Expand All @@ -636,7 +713,15 @@ impl MedRecord {

match self.schema.schema_type() {
SchemaType::Inferred => {
self.schema.update_edge(edge_attributes, Some(&group));
let edges_in_group = self
.group_mapping
.edges_in_group
.get(&group)
.map(|edges| edges.len())
.unwrap_or(0);

self.schema
.update_edge(edge_attributes, Some(&group), edges_in_group == 0);
}
SchemaType::Provided => {
self.schema
Expand Down
Loading

0 comments on commit 5089833

Please sign in to comment.