Skip to content

Commit

Permalink
feat: support polars lazyframe in add_lags (#675)
Browse files Browse the repository at this point in the history
  • Loading branch information
MarcoGorelli authored May 26, 2024
1 parent fbb8e57 commit 0964662
Show file tree
Hide file tree
Showing 2 changed files with 6 additions and 1 deletion.
1 change: 1 addition & 0 deletions sklego/pandas_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -272,6 +272,7 @@ def add_lags(X, cols, lags, drop_na=True):
X = nw.from_native(X, strict=False)
allowed_inputs = {
nw.DataFrame: _add_lagged_dataframe_columns,
nw.LazyFrame: _add_lagged_dataframe_columns,
np.ndarray: _add_lagged_numpy_columns,
}

Expand Down
6 changes: 5 additions & 1 deletion tests/test_pandas_utils/test_pandas_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,11 +38,15 @@ def test_add_lags_wrong_inputs(data, frame_func):
add_lags(invalid_df, ["X1"], 1)


@pytest.mark.parametrize("frame_func", [pd.DataFrame, pl.DataFrame])
@pytest.mark.parametrize("frame_func", [pd.DataFrame, pl.DataFrame, pl.LazyFrame])
def test_add_lags_correct_df(data, frame_func):
test_df = frame_func(data)
expected = frame_func({"X1": [1, 2], "X2": ["178", "154"], "X1-1": [0, 1]})
ans = add_lags(test_df, "X1", -1)
if isinstance(ans, pl.LazyFrame):
ans = ans.collect()
if isinstance(expected, pl.LazyFrame):
expected = expected.collect()
assert [x for x in ans.columns] == [x for x in expected.columns]
assert (ans.to_numpy() == expected.to_numpy()).all()

Expand Down

0 comments on commit 0964662

Please sign in to comment.