Skip to content

Commit

Permalink
test: add test
Browse files Browse the repository at this point in the history
  • Loading branch information
bokajgd committed Oct 4, 2023
1 parent 5f97278 commit 4992d87
Showing 1 changed file with 56 additions and 2 deletions.
Original file line number Diff line number Diff line change
@@ -1,12 +1,13 @@
"""Tests for adding values to a flattened dataset."""

import datetime as dt

import numpy as np
import pandas as pd
import pytest

from timeseriesflattener import TimeseriesFlattener
from timeseriesflattener.aggregation_fns import maximum, minimum
from timeseriesflattener.aggregation_fns import latest, maximum, minimum
from timeseriesflattener.feature_specs.single_specs import (
OutcomeSpec,
PredictorSpec,
Expand Down Expand Up @@ -499,7 +500,6 @@ def test_add_temporal_predictors_then_temporal_outcome():
check_dtype=False,
)


def test_add_temporal_incident_binary_outcome():
prediction_times_str = """entity_id,timestamp,
1,2021-11-05 00:00:00
Expand Down Expand Up @@ -548,3 +548,57 @@ def test_add_temporal_incident_binary_outcome():
df[col] = df[col].astype("int32")

pd.testing.assert_series_equal(outcome_df[col], expected_df[col])


def test_add_outcome_timestamps():
    """Check that an outcome spec can return the timestamp of the latest event.

    The event dataframe carries the event timestamp in its ``value`` column,
    so aggregating with ``latest`` over a 10-day lookahead should give, for
    each prediction time, the timestamp of the last event in that window:

    * 2021-11-05 -> both events are within 10 days, latest is 2021-11-13
    * 2021-11-01 -> only the 2021-11-06 event is within 10 days
    * 2023-11-05 -> no events in the window, so the fallback (NaN) is used
    """
    prediction_times_str = """entity_id,timestamp,
1,2021-11-05 00:00:00
1,2021-11-01 00:00:00
1,2023-11-05 00:00:00
"""

    event_times_str = """entity_id,timestamp,value,
1,2021-11-06 00:00:01,2021-11-06 00:00:01
1,2021-11-13 00:00:01,2021-11-13 00:00:01
"""

    # NOTE(review): the first column is headed ``entity_id`` but holds 0/1/2,
    # which looks like a row position rather than an entity id. Only the
    # ``outc`` column is compared below, so it does not affect the assertion.
    expected_df_str = """entity_id,outc_timestamp_within_10_days_latest_fallback_nan_dichotomous,
0,2021-11-13 00:00:01
1,2021-11-06 00:00:01
2,
"""

    outcome_col = "outc_timestamp_within_10_days_latest_fallback_nan_dichotomous"

    prediction_times_df = str_to_df(prediction_times_str)
    event_times_df = str_to_df(event_times_str)
    expected_df = str_to_df(expected_df_str)
    # Compare as strings, with the missing row represented as NaN rather
    # than the literal string "NaT". ``np.nan`` is used instead of the
    # ``np.NaN`` alias because the alias was removed in NumPy 2.0.
    expected_df[outcome_col] = (
        expected_df[outcome_col].astype(str).replace("NaT", np.nan)
    )

    flattened_dataset = TimeseriesFlattener(
        prediction_times_df=prediction_times_df,
        timestamp_col_name="timestamp",
        entity_id_col_name="entity_id",
        n_workers=4,
        drop_pred_times_with_insufficient_look_distance=False,
    )

    flattened_dataset.add_spec(
        spec=OutcomeSpec(
            timeseries_df=event_times_df,
            lookahead_days=10,
            incident=False,
            fallback=np.nan,
            feature_base_name="timestamp",
            aggregation_fn=latest,
        ),
    )

    outcome_df = flattened_dataset.get_df()

    for col in [c for c in expected_df.columns if "outc" in c]:
        for df in (outcome_df, expected_df):
            # Windows and Linux have different default dtypes for ints,
            # which is not a meaningful error here. So we force the dtype.
            if df[col].dtype == "int64":
                df[col] = df[col].astype("int32")

        pd.testing.assert_series_equal(outcome_df[col], expected_df[col])

0 comments on commit 4992d87

Please sign in to comment.