Skip to content

Commit

Permalink
Fix broken tests
Browse files Browse the repository at this point in the history
  • Loading branch information
gabrielguarisa committed Oct 25, 2023
1 parent ab62802 commit 0ec0ab7
Show file tree
Hide file tree
Showing 5 changed files with 63 additions and 25 deletions.
27 changes: 27 additions & 0 deletions retrack/engine/parser.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
import typing

import hashlib

from retrack import nodes, validators
from retrack.utils.registry import Registry

Expand Down Expand Up @@ -29,10 +31,35 @@ def __init__(
self._set_execution_order()
self._set_indexes_by_memory_type_map()

self._version = self.graph_data.get("version", None)

if self._version is None:
self._version = "{}.dynamic".format(
hashlib.sha256(str(self.graph_data).encode("utf-8")).hexdigest()[:10],
)
else:
graph_data_without_version = self.graph_data.copy()
file_version_hash = graph_data_without_version["version"].split(".")[0]
del graph_data_without_version["version"]

if (
file_version_hash
!= hashlib.sha256(
str(graph_data_without_version).encode("utf-8")
).hexdigest()[:10]
):
raise ValueError(
"Invalid version. Graph data has changed and the hash is different"
)

@property
def graph_data(self) -> dict:
return self.__graph_data

@property
def version(self) -> str:
return self._version

@staticmethod
def _check_input_data(data: dict):
if not isinstance(data, dict):
Expand Down
19 changes: 4 additions & 15 deletions retrack/engine/request_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,17 +84,8 @@ def __create_model(
**fields,
)

def __create_dataframe_model(
self, model_name: str = "RequestModel"
) -> pandera.DataFrameSchema:
"""Create a pydantic model from the RequestManager's inputs
Args:
model_name (str, optional): The name of the model. Defaults to "RequestModel".
Returns:
typing.Type[pydantic.BaseModel]: The pydantic model
"""
def __create_dataframe_model(self) -> pandera.DataFrameSchema:
"""Create a pydantic model from the RequestManager's inputs"""
fields = {}
for input_field in self.inputs:
fields[input_field.data.name] = pandera.Column(
Expand All @@ -113,14 +104,12 @@ def __create_dataframe_model(

def validate(
self,
payload: typing.Union[
typing.Dict[str, str], typing.List[typing.Dict[str, str]]
],
payload: pd.DataFrame,
) -> typing.List[pydantic.BaseModel]:
"""Validate the payload against the RequestManager's model
Args:
payload (typing.Union[typing.Dict[str, str], typing.List[typing.Dict[str, str]]]): The payload to validate
payload (pandas.DataFrame): The payload to validate
Raises:
ValueError: If the RequestManager has no model
Expand Down
15 changes: 13 additions & 2 deletions retrack/engine/runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -189,10 +189,21 @@ def __run_node(self, node_id: str):

def execute(
self,
payload_df: typing.Union[dict, pd.DataFrame],
payload_df: pd.DataFrame,
return_all_states: bool = False,
) -> pd.DataFrame:
"""Executes the flow with the given payload"""
"""Executes the flow with the given payload.
Args:
payload_df (pd.DataFrame): The payload to be used as input.
return_all_states (bool, optional): If True, returns all states. Defaults to False.
Returns:
pd.DataFrame: The output of the flow.
"""
if not isinstance(payload_df, pd.DataFrame):
raise ValueError("payload_df must be a pandas.DataFrame")

self.reset()
self._states = self._create_initial_state_from_payload(payload_df)

Expand Down
19 changes: 15 additions & 4 deletions tests/test_engine/test_request_manager.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,11 @@
import pandas as pd
import pandera
import pydantic
import pytest

from retrack.engine.request_manager import RequestManager
from retrack.nodes.inputs import Input
import pandas as pd


def test_create_request_manager(valid_input_dict_before_validation):
with pytest.raises(TypeError):
Expand Down Expand Up @@ -34,6 +37,14 @@ def test_create_request_manager_with_invalid_input(valid_input_dict_before_valid


def test_validate_payload_with_valid_payload(valid_input_dict_before_validation):
pm = RequestManager([Input(**valid_input_dict_before_validation)])
payload = pm.model(example="test")
assert pm.validate(pd.DataFrame([{{"example": "test"}}]))[0] == payload
rm = RequestManager([Input(**valid_input_dict_before_validation)])

assert issubclass(rm.model, pydantic.BaseModel)

assert isinstance(rm.dataframe_model, pandera.api.pandas.container.DataFrameSchema)

payload = rm.model(example="test")

assert isinstance(payload, pydantic.BaseModel)
result = rm.validate(pd.DataFrame([{"example": "test"}]))
assert isinstance(result, pd.DataFrame)
8 changes: 4 additions & 4 deletions tests/test_engine/test_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ def test_flows_with_single_element(filename, in_values, expected_out_values):
rule = json.load(f)

runner = Runner(Parser(rule))
out_values = runner.execute(in_values)
out_values = runner.execute(pd.DataFrame([in_values]))

assert isinstance(out_values, pd.DataFrame)
assert out_values.to_dict(orient="records") == expected_out_values
Expand Down Expand Up @@ -90,7 +90,7 @@ def test_flows(filename, in_values, expected_out_values):
rule = json.load(f)

runner = Runner(Parser(rule))
out_values = runner.execute(in_values)
out_values = runner.execute(pd.DataFrame(in_values))

assert isinstance(out_values, pd.DataFrame)
assert out_values.to_dict(orient="records") == expected_out_values
Expand Down Expand Up @@ -147,7 +147,7 @@ def test_flows(filename, in_values, expected_out_values):
)
def test_create_from_json(filename, in_values, expected_out_values):
runner = Runner.from_json(f"tests/resources/{filename}.json")
out_values = runner.execute(in_values)
out_values = runner.execute(pd.DataFrame(in_values))

assert isinstance(out_values, pd.DataFrame)
assert out_values.to_dict(orient="records") == expected_out_values
Expand All @@ -171,7 +171,7 @@ def test_csv_table_with_if():
{"in_a": 1, "in_b": 1, "in_d": -1, "in_e": 0},
]

out_values = runner.execute(in_values)
out_values = runner.execute(pd.DataFrame(in_values))

assert isinstance(out_values, pd.DataFrame)
assert len(out_values) == len(in_values)
Expand Down

0 comments on commit 0ec0ab7

Please sign in to comment.