Skip to content

Commit

Permalink
wavg with tests
Browse files Browse the repository at this point in the history
  • Loading branch information
oelarnes committed Jan 11, 2025
1 parent 4138f87 commit 52a2309
Show file tree
Hide file tree
Showing 2 changed files with 188 additions and 5 deletions.
19 changes: 14 additions & 5 deletions spells/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,17 +30,26 @@ def wavg(
name_list = list(new_names)

assert len(name_list) == len(col_list), f"{len(name_list)} names provided for {len(col_list)} columns"
assert len(weight_list) == len(col_list), f"{len(weight_list)} weights provided for {len(col_list)} columns"
assert len(name_list) == len(set(name_list)), "Output names must be unique"
assert len(weight_list) == len(col_list) or len(weight_list) == 1, f"{len(weight_list)} weights provided for {len(col_list)} columns"

enum_wl = weight_list * int(len(col_list) / len(weight_list))
wl_names = [w.meta.output_name() for w in weight_list]
assert len(wl_names) == len(set(wl_names)), "Weights must have unique names. Send one weight column or n uniquely named ones"

to_group = df.select(gbs + weight_list + [
(c * weight_list[i]) for i, c in enumerate(col_list)
(c * enum_wl[i]).alias(name_list[i]) for i, c in enumerate(col_list)
])

grouped = to_group if not gbs else to_group.group_by(gbs)

return grouped.sum().select(
ret_df = grouped.sum().select(
gbs +
[pl.col(c.meta.output_name()).alias(name_list[i]) for i, c in enumerate(col_list)]
wl_names +
[(pl.col(name) / pl.col(enum_wl[i].meta.output_name())) for i, name in enumerate(name_list)]
)

if gbs:
ret_df = ret_df.sort(by=gbs)


return ret_df
174 changes: 174 additions & 0 deletions tests/utils_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,174 @@

"""
Test behavior of wavg utility for Polars DataFrames
"""

import pytest
import polars as pl

import spells.utils as utils

def format_test_string(test_string: str) -> str:
    """
    Normalise a pasted dataframe printout so it can be compared as text.

    Each line is stripped of leading and trailing whitespace, lines that are
    empty after stripping are discarded, and the survivors are joined with
    single newline characters.
    """
    trimmed = (raw.strip() for raw in test_string.splitlines())
    return "\n".join(text for text in trimmed if text)

# Shared input frame for the wavg tests: a category column ('cat'), two float
# value columns ('va1', 'va2') and two integer weight columns ('wt1', 'wt2').
test_df = pl.DataFrame(
    {
        "cat": ["a", "a", "b", "b", "b", "c"],
        "va1": [1.0, -1.0, 0.2, 0.4, 0.0, 10.0],
        "va2": [4.0, 3.0, 1.0, -2.0, 2.0, 1.0],
        "wt1": [1, 2, 0, 2, 3, 1],
        "wt2": [2, 4, 1, 1, 1, 2],
    }
)


# Test wavg called with positional arguments only (no group_by, no new_names).
# Each case supplies the value column(s), the weight column(s), and the expected
# printed frame. In every expected frame the summed weight columns come first,
# followed by the weighted averages of the value columns.
@pytest.mark.parametrize(
"cols, weights, expected",
[
# one value column and one weight column, both given as plain strings
(
'va1',
'wt1',
"""
shape: (1, 2)
┌─────┬──────────┐
│ wt1 ┆ va1 │
│ --- ┆ --- │
│ i64 ┆ f64 │
╞═════╪══════════╡
│ 9 ┆ 1.088889 │
└─────┴──────────┘
"""
),
# two value columns sharing a single weight column
(
['va1', 'va2'],
'wt1',
"""
shape: (1, 3)
┌─────┬──────────┬──────────┐
│ wt1 ┆ va1 ┆ va2 │
│ --- ┆ --- ┆ --- │
│ i64 ┆ f64 ┆ f64 │
╞═════╪══════════╪══════════╡
│ 9 ┆ 1.088889 ┆ 1.444444 │
└─────┴──────────┴──────────┘
"""
),
# two value columns, each paired positionally with its own weight column
(
['va1', 'va2'],
['wt1', 'wt2'],
"""
shape: (1, 4)
┌─────┬─────┬──────────┬──────────┐
│ wt1 ┆ wt2 ┆ va1 ┆ va2 │
│ --- ┆ --- ┆ --- ┆ --- │
│ i64 ┆ i64 ┆ f64 ┆ f64 │
╞═════╪═════╪══════════╪══════════╡
│ 9 ┆ 11 ┆ 1.088889 ┆ 2.090909 │
└─────┴─────┴──────────┴──────────┘
"""
),
# polars expressions mixed with strings; the expected output keeps the
# expressions' root names ('va1', 'wt2') as column headers
(
[pl.col('va1') + 1, 'va2'],
['wt1', pl.col('wt2') + 1],
"""
shape: (1, 4)
┌─────┬─────┬──────────┬──────────┐
│ wt1 ┆ wt2 ┆ va1 ┆ va2 │
│ --- ┆ --- ┆ --- ┆ --- │
│ i64 ┆ i64 ┆ f64 ┆ f64 │
╞═════╪═════╪══════════╪══════════╡
│ 9 ┆ 17 ┆ 2.088889 ┆ 1.882353 │
└─────┴─────┴──────────┴──────────┘
"""
),
]
)
def test_wavg_defaults(cols: str | pl.Expr | list[str | pl.Expr], weights: str | pl.Expr | list[str | pl.Expr], expected: str):
# Run wavg on the shared test_df fixture with the parametrized columns/weights.
result = utils.wavg(test_df, cols, weights)

# Compare the frame's printed form with the whitespace-normalised expectation;
# the print makes the actual output visible in pytest's captured stdout.
test_str = str(result)
print(test_str)
assert test_str == format_test_string(expected)


# Test wavg with the group_by and new_names keyword arguments.
# Each case supplies the value column(s), the weight column(s), the grouping
# key(s), the output name(s), and the expected printed frame.
@pytest.mark.parametrize(
"cols, weights, group_by, new_names, expected",
[
# empty group_by list: a single-row result, with the average renamed to 'v1'
(
"va1",
"wt1",
[],
"v1",
"""
shape: (1, 2)
┌─────┬──────────┐
│ wt1 ┆ v1 │
│ --- ┆ --- │
│ i64 ┆ f64 │
╞═════╪══════════╡
│ 9 ┆ 1.088889 │
└─────┴──────────┘
"""
),
# group_by given as a bare string; one row per 'cat' value, listed a, b, c
(
"va1",
"wt1",
"cat",
"va1",
"""
shape: (3, 3)
┌─────┬─────┬───────────┐
│ cat ┆ wt1 ┆ va1 │
│ --- ┆ --- ┆ --- │
│ str ┆ i64 ┆ f64 │
╞═════╪═════╪═══════════╡
│ a ┆ 3 ┆ -0.333333 │
│ b ┆ 5 ┆ 0.16 │
│ c ┆ 1 ┆ 10.0 │
└─────┴─────┴───────────┘
"""
),
# the same value column averaged under two different weights; distinct
# new_names ('v@1', 'v@2') keep the two outputs apart
(
["va1", "va1"],
["wt1", "wt2"],
["cat"],
["v@1", "v@2"],
"""
shape: (3, 5)
┌─────┬─────┬─────┬───────────┬───────────┐
│ cat ┆ wt1 ┆ wt2 ┆ v@1 ┆ v@2 │
│ --- ┆ --- ┆ --- ┆ --- ┆ --- │
│ str ┆ i64 ┆ i64 ┆ f64 ┆ f64 │
╞═════╪═════╪═════╪═══════════╪═══════════╡
│ a ┆ 3 ┆ 6 ┆ -0.333333 ┆ -0.333333 │
│ b ┆ 5 ┆ 3 ┆ 0.16 ┆ 0.2 │
│ c ┆ 1 ┆ 2 ┆ 10.0 ┆ 10.0 │
└─────┴─────┴─────┴───────────┴───────────┘
"""
)
]
)
def test_wavg(
cols: str | pl.Expr | list[str | pl.Expr],
weights: str | pl.Expr | list[str | pl.Expr],
group_by: str | pl.Expr | list[str | pl.Expr],
new_names: str | list[str],
expected: str,
):
# Run wavg on the shared test_df fixture, passing grouping and naming by keyword.
result = utils.wavg(
test_df,
cols,
weights,
group_by=group_by,
new_names=new_names,
)

# Compare the frame's printed form with the whitespace-normalised expectation;
# the print makes the actual output visible in pytest's captured stdout.
test_str = str(result)
print(test_str)
assert test_str == format_test_string(expected)

0 comments on commit 52a2309

Please sign in to comment.