Skip to content

Commit

Permalink
Retrack only supports dataframes from now on
Browse files Browse the repository at this point in the history
  • Loading branch information
gabrielguarisa committed Oct 11, 2023
1 parent 2288394 commit ab62802
Show file tree
Hide file tree
Showing 9 changed files with 86 additions and 49 deletions.
9 changes: 5 additions & 4 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -15,13 +15,14 @@ pandas = "^1.2.0"
numpy = "^1.19.5"
pydantic = "2.4.2"
networkx = "^2.6.3"
pandera = "^0.17.2"

[tool.poetry.dev-dependencies]
pytest = "^6.2.4"
pytest-cov = "^3.0.0"
pytest = "*"
pytest-cov = "*"
black = "^22.6.0"
isort = {extras = ["colors"], version = "*"}
pytest-mock = "^3.10.0"
pytest-mock = "*"

[tool.black]
# https://github.com/psf/black
Expand Down Expand Up @@ -72,7 +73,7 @@ indent = 4
color_output = true

[tool.pytest.ini_options]
addopts = "--junitxml=pytest.xml -p no:warnings --cov-report term-missing:skip-covered --cov=retrack"
addopts = "-vv --junitxml=pytest.xml -p no:warnings --cov-report term-missing:skip-covered --cov=retrack"

[build-system]
requires = ["poetry-core>=1.0.0"]
Expand Down
51 changes: 45 additions & 6 deletions retrack/engine/request_manager.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
import typing

import pandas as pd
import pandera
import pydantic

from retrack.nodes.base import BaseNode, NodeKind
Expand All @@ -8,6 +10,7 @@
class RequestManager:
def __init__(self, inputs: typing.List[BaseNode]):
self._model = None
self._dataframe_model = None
self.inputs = inputs

@property
Expand Down Expand Up @@ -41,13 +44,19 @@ def inputs(self, inputs: typing.List[BaseNode]):

if len(self.inputs) > 0:
self._model = self.__create_model()
self._dataframe_model = self.__create_dataframe_model()
else:
self._model = None
self._dataframe_model = None

@property
def model(self) -> typing.Type[pydantic.BaseModel]:
return self._model

@property
def dataframe_model(self) -> pandera.DataFrameSchema:
return self._dataframe_model

def __create_model(
self, model_name: str = "RequestModel"
) -> typing.Type[pydantic.BaseModel]:
Expand All @@ -62,16 +71,46 @@ def __create_model(
fields = {}
for input_field in self.inputs:
fields[input_field.data.name] = (
(str, ...)
if input_field.data.default is None
else (str, input_field.data.default)
str,
pydantic.Field(
default=Ellipsis
if input_field.data.default is None
else input_field.data.default,
),
)

return pydantic.create_model(
model_name,
**fields,
)

def __create_dataframe_model(
self, model_name: str = "RequestModel"
) -> pandera.DataFrameSchema:
"""Create a pydantic model from the RequestManager's inputs
Args:
model_name (str, optional): The name of the model. Defaults to "RequestModel".
Returns:
typing.Type[pydantic.BaseModel]: The pydantic model
"""
fields = {}
for input_field in self.inputs:
fields[input_field.data.name] = pandera.Column(
str,
nullable=input_field.data.default is not None,
coerce=True,
default=input_field.data.default,
)

return pandera.DataFrameSchema(
fields,
index=pandera.Index(int),
strict=True,
coerce=True,
)

def validate(
self,
payload: typing.Union[
Expand All @@ -92,7 +131,7 @@ def validate(
if self.model is None:
raise ValueError("No inputs found")

if not isinstance(payload, list):
payload = [payload]
if not isinstance(payload, pd.DataFrame):
raise TypeError(f"payload must be a pandas.DataFrame, not {type(payload)}")

return pydantic.parse_obj_as(typing.List[self.model], payload)
return self.dataframe_model.validate(payload)
49 changes: 19 additions & 30 deletions retrack/engine/runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,11 +107,12 @@ def __set_output_connection_filters(
)

def _create_initial_state_from_payload(
self, payload: typing.Union[dict, list]
self, payload_df: pd.DataFrame
) -> pd.DataFrame:
"""Create initial state from payload. This is the first step of the runner."""
validated_payload = self.request_manager.validate(payload)
validated_payload = pd.DataFrame([p.model_dump() for p in validated_payload])
validated_payload = self.request_manager.validate(
payload_df.reset_index(drop=True)
)

state_df = pd.DataFrame([])
for node_id, input_name in self.input_columns.items():
Expand Down Expand Up @@ -186,34 +187,14 @@ def __run_node(self, node_id: str):
f"{node_id}@{output_name}", output_value, current_node_filter
)

def __get_output_states(self) -> pd.DataFrame:
"""Returns a dataframe with the final states of the flow"""
return pd.DataFrame(
{
"output": self.states[constants.OUTPUT_REFERENCE_COLUMN],
"message": self.states[constants.OUTPUT_MESSAGE_REFERENCE_COLUMN],
}
)

def __parse_payload(
self, payload: typing.Union[dict, list, pd.DataFrame]
) -> typing.List[dict]:
if isinstance(payload, dict):
payload = [payload]

if not isinstance(payload, pd.DataFrame):
payload = pd.DataFrame(payload, index=list(range(len(payload))))

for column in payload.columns:
payload[column] = payload[column].astype(str)

return payload.to_dict("records")

def execute(self, payload: typing.Union[dict, list, pd.DataFrame]) -> pd.DataFrame:
def execute(
self,
payload_df: typing.Union[dict, pd.DataFrame],
return_all_states: bool = False,
) -> pd.DataFrame:
"""Executes the flow with the given payload"""
self.reset()
payload = self.__parse_payload(payload)
self._states = self._create_initial_state_from_payload(payload)
self._states = self._create_initial_state_from_payload(payload_df)

for node_id in self.parser.execution_order:
try:
Expand All @@ -224,4 +205,12 @@ def execute(self, payload: typing.Union[dict, list, pd.DataFrame]) -> pd.DataFra
if self.states[constants.OUTPUT_REFERENCE_COLUMN].isna().sum() == 0:
break

return self.__get_output_states()
if return_all_states:
return self.states

return self.states[
[
constants.OUTPUT_REFERENCE_COLUMN,
constants.OUTPUT_MESSAGE_REFERENCE_COLUMN,
]
]
18 changes: 13 additions & 5 deletions retrack/nodes/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,10 +34,15 @@ class NodeMemoryType(str, enum.Enum):
# Connection Models
###############################################################


def cast_int_to_str(v: typing.Any, info: pydantic.ValidationInfo) -> str:
return str(v)

CastedToStringType = typing.Annotated[typing.Any, pydantic.BeforeValidator(cast_int_to_str)]

CastedToStringType = typing.Annotated[
typing.Any, pydantic.BeforeValidator(cast_int_to_str)
]


class OutputConnectionItemModel(pydantic.BaseModel):
node: CastedToStringType
Expand All @@ -61,12 +66,15 @@ class InputConnectionModel(pydantic.BaseModel):
# Base Node
###############################################################


class BaseNode(pydantic.BaseModel):
id: CastedToStringType
inputs: typing.Optional[typing.Dict[CastedToStringType, InputConnectionModel]] = None
outputs: typing.Optional[typing.Dict[CastedToStringType, OutputConnectionModel]] = None


inputs: typing.Optional[
typing.Dict[CastedToStringType, InputConnectionModel]
] = None
outputs: typing.Optional[
typing.Dict[CastedToStringType, OutputConnectionModel]
] = None

def run(self, **kwargs) -> typing.Dict[str, typing.Any]:
return {}
Expand Down
4 changes: 2 additions & 2 deletions retrack/utils/constants.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
OUTPUT_REFERENCE_COLUMN = "graph_output"
OUTPUT_MESSAGE_REFERENCE_COLUMN = "graph_output_message"
OUTPUT_REFERENCE_COLUMN = "output"
OUTPUT_MESSAGE_REFERENCE_COLUMN = "message"
NULL_SUFFIX = "_void"
FILTER_SUFFIX = "_filter"
INPUT_OUTPUT_VALUE_CONNECTOR_NAME = "output_value"
Expand Down
Empty file removed tests/__init__.py
Empty file.
Empty file removed tests/test_engine/__init__.py
Empty file.
4 changes: 2 additions & 2 deletions tests/test_engine/test_request_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

from retrack.engine.request_manager import RequestManager
from retrack.nodes.inputs import Input

import pandas as pd

def test_create_request_manager(valid_input_dict_before_validation):
with pytest.raises(TypeError):
Expand Down Expand Up @@ -36,4 +36,4 @@ def test_create_request_manager_with_invalid_input(valid_input_dict_before_valid
def test_validate_payload_with_valid_payload(valid_input_dict_before_validation):
pm = RequestManager([Input(**valid_input_dict_before_validation)])
payload = pm.model(example="test")
assert pm.validate({"example": "test"})[0] == payload
assert pm.validate(pd.DataFrame([{{"example": "test"}}]))[0] == payload
Empty file removed tests/test_nodes/__init__.py
Empty file.

0 comments on commit ab62802

Please sign in to comment.