Skip to content

Commit

Permalink
feat: implement new functions
Browse files Browse the repository at this point in the history
  • Loading branch information
JabobKrauskopf committed Feb 12, 2025
1 parent ce84a11 commit 0af9fd4
Show file tree
Hide file tree
Showing 3 changed files with 270 additions and 165 deletions.
6 changes: 3 additions & 3 deletions medmodels/_medmodels.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -285,9 +285,9 @@ class PyMedRecord:
def clone(self) -> PyMedRecord: ...

class PyEdgeDirection(Enum):
Incoming = 0
Outgoing = 1
Both = 2
Incoming = ...
Outgoing = ...
Both = ...

class PyNodeOperand:
def attribute(self, attribute: MedRecordAttribute) -> PyMultipleValuesOperand: ...
Expand Down
211 changes: 209 additions & 2 deletions medmodels/medrecord/schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,9 +31,15 @@
from medmodels.medrecord.datatype import (
Union as DataTypeUnion,
)
from medmodels.medrecord.types import MedRecordAttribute
from medmodels.medrecord.types import (
Attributes,
EdgeIndex,
MedRecordAttribute,
NodeIndex,
)

if TYPE_CHECKING:
from medmodels.medrecord.medrecord import MedRecord
from medmodels.medrecord.types import Group


Expand Down Expand Up @@ -63,7 +69,22 @@ def _from_py_attribute_type(py_attribute_type: PyAttributeType) -> AttributeType
return AttributeType.Temporal
if py_attribute_type == PyAttributeType.Unstructured:
return AttributeType.Unstructured
return None
msg = "Should never be reached"
raise NotImplementedError(msg)

@staticmethod
def infer(data_type: DataType) -> AttributeType:
"""Infers the attribute type from the data type.
Args:
data_type (DataType): The data type to infer the attribute type from.
Returns:
AttributeType: The inferred attribute type.
"""
return AttributeType._from_py_attribute_type(
PyAttributeType.infer(data_type._inner())
)

def _into_py_attribute_type(self) -> PyAttributeType:
"""Converts an AttributeType to a PyAttributeType.
Expand All @@ -77,6 +98,8 @@ def _into_py_attribute_type(self) -> PyAttributeType:
return PyAttributeType.Continuous
if self == AttributeType.Temporal:
return PyAttributeType.Temporal
if self == AttributeType.Unstructured:
return PyAttributeType.Unstructured
msg = "Should never be reached"
raise NotImplementedError(msg)

Expand Down Expand Up @@ -267,6 +290,24 @@ def _convert_edge(
for x in self._group_schema.edges
}

def validate_node(self, index: NodeIndex, attributes: Attributes) -> None:
"""Validates the attributes of a node.
Args:
index (NodeIndex): The index of the node.
attributes (Attributes): The attributes of the node.
"""
self._group_schema.validate_node(index, attributes)

def validate_edge(self, index: EdgeIndex, attributes: Attributes) -> None:
"""Validates the attributes of an edge.
Args:
index (EdgeIndex): The index of the edge.
attributes (Attributes): The attributes of the edge.
"""
self._group_schema.validate_edge(index, attributes)


class Schema:
"""A schema for a collection of groups."""
Expand Down Expand Up @@ -298,6 +339,20 @@ def __init__(
default=default._group_schema,
)

@classmethod
def infer(cls, medrecord: MedRecord) -> Schema:
"""Infers a schema from a MedRecord instance.
Args:
medrecord (MedRecord): The MedRecord instance to infer the schema from.
Returns:
Schema: The inferred schema.
"""
new_schema = cls()
new_schema._schema = PySchema.infer(medrecord._medrecord)
return new_schema

@classmethod
def _from_py_schema(cls, schema: PySchema) -> Schema:
"""Creates a Schema instance from an existing PySchema.
Expand Down Expand Up @@ -344,3 +399,155 @@ def default(self) -> GroupSchema:
None.
"""
return GroupSchema._from_pygroupschema(self._schema.default)

def validate_node(
self, index: NodeIndex, attributes: Attributes, group: Optional[Group] = None
) -> None:
"""Validates the attributes of a node.
Args:
index (NodeIndex): The index of the node.
attributes (Attributes): The attributes of the node.
group (Optional[Group], optional): The group to validate the node against.
If not provided, the default group is used. Defaults to None.
"""
self._schema.validate_node(index, attributes, group)

def validate_edge(
self, index: EdgeIndex, attributes: Attributes, group: Optional[Group] = None
) -> None:
"""Validates the attributes of an edge.
Args:
index (EdgeIndex): The index of the edge.
attributes (Attributes): The attributes of the edge.
group (Optional[Group], optional): The group to validate the edge against.
If not provided, the default group is used. Defaults to None.
"""
self._schema.validate_edge(index, attributes, group)

def set_node_attribute(
self,
attribute: MedRecordAttribute,
data_type: DataType,
attribute_type: Optional[AttributeType] = None,
group: Optional[Group] = None,
) -> None:
"""Sets the data type and attribute type of a node attribute.
If a data type for the attribute already exists, it is overwritten.
Args:
attribute (MedRecordAttribute): The name of the attribute.
data_type (DataType): The data type of the attribute.
attribute_type (Optional[AttributeType], optional): The attribute type of
the attribute. If not provided, the attribute type is inferred
from the data type. Defaults to None.
group (Optional[Group], optional): The group to set the attribute for.
If no schema for the group exists, a new schema is created.
If not provided, the default group is used. Defaults to None.
"""
if not attribute_type:
attribute_type = AttributeType.infer(data_type)

self._schema.set_node_attribute(
attribute,
data_type._inner(),
attribute_type._into_py_attribute_type(),
group,
)

def set_edge_attribute(
self,
attribute: MedRecordAttribute,
data_type: DataType,
attribute_type: Optional[AttributeType] = None,
group: Optional[Group] = None,
) -> None:
"""Sets the data type and attribute type of an edge attribute.
If a data type for the attribute already exists, it is overwritten.
Args:
attribute (MedRecordAttribute): The name of the attribute.
data_type (DataType): The data type of the attribute.
attribute_type (Optional[AttributeType], optional): The attribute type of
the attribute. If not provided, the attribute type is inferred
from the data type. Defaults to None.
group (Optional[Group], optional): The group to set the attribute for.
If no schema for this group exists, a new schema is created.
If not provided, the default group is used. Defaults to None.
"""
if not attribute_type:
attribute_type = AttributeType.infer(data_type)

self._schema.set_edge_attribute(
attribute,
data_type._inner(),
attribute_type._into_py_attribute_type(),
group,
)

def update_node_attribute(
self,
attribute: MedRecordAttribute,
data_type: DataType,
attribute_type: Optional[AttributeType] = None,
group: Optional[Group] = None,
) -> None:
"""Updates the data type and attribute type of a node attribute.
If a data type for the attribute already exists, it is merged
with the new data type.
Args:
attribute (MedRecordAttribute): The name of the attribute.
data_type (DataType): The data type of the attribute.
attribute_type (Optional[AttributeType], optional): The attribute type of
the attribute. If not provided, the attribute type is inferred
from the data type. Defaults to None.
group (Optional[Group], optional): The group to update the attribute for.
If no schema for this group exists, a new schema is created.
If not provided, the default group is used. Defaults to None.
"""
if not attribute_type:
attribute_type = AttributeType.infer(data_type)

self._schema.update_node_attribute(
attribute,
data_type._inner(),
attribute_type._into_py_attribute_type(),
group,
)

def update_edge_attribute(
self,
attribute: MedRecordAttribute,
data_type: DataType,
attribute_type: Optional[AttributeType] = None,
group: Optional[Group] = None,
) -> None:
"""Updates the data type and attribute type of an edge attribute.
If a data type for the attribute already exists, it is merged
with the new data type.
Args:
attribute (MedRecordAttribute): The name of the attribute.
data_type (DataType): The data type of the attribute.
attribute_type (Optional[AttributeType], optional): The attribute type of
the attribute. If not provided, the attribute type is inferred
from the data type. Defaults to None.
group (Optional[Group], optional): The group to update the attribute for.
If no schema for this group exists, a new schema is created.
If not provided, the default group is used. Defaults to None.
"""
if not attribute_type:
attribute_type = AttributeType.infer(data_type)

self._schema.update_edge_attribute(
attribute,
data_type._inner(),
attribute_type._into_py_attribute_type(),
group,
)
Loading

0 comments on commit 0af9fd4

Please sign in to comment.