Skip to content

Commit

Permalink
update for exports via cli
Browse files Browse the repository at this point in the history
  • Loading branch information
d33bs committed Jun 19, 2024
1 parent 11cd3b0 commit 90c2088
Show file tree
Hide file tree
Showing 2 changed files with 49 additions and 10 deletions.
2 changes: 1 addition & 1 deletion src/cosmicqc/scdataframe.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,7 @@ def __init__(
# if data is a pd.DataFrame, remember this within the data_source attr
self.data_source = "pandas.DataFrame"
self.data = data

elif isinstance(data, pd.Series):
# if data is a pd.DataFrame, remember this within the data_source attr
self.data_source = "pandas.Series"
Expand Down
57 changes: 48 additions & 9 deletions tests/test_cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,13 @@
Tests cosmicqc cli module
"""

from .utils import run_cli_command
import pathlib

from pyarrow import parquet

from .utils import run_cli_command


def test_cli_util():
"""
Test the `identify_outliers` function of the CLI.
Expand All @@ -27,7 +30,6 @@ def test_cli_identify_outliers(tmp_path: pathlib.Path, basic_outlier_csv: str):
f"""cosmicqc identify_outliers --df {basic_outlier_csv}"""
""" --feature_thresholds {"example_feature":1.0}"""
f" --export_path {tmp_path}/identify_outliers_output.parquet"

)
)

Expand All @@ -48,12 +50,25 @@ def test_cli_identify_outliers(tmp_path: pathlib.Path, basic_outlier_csv: str):
Name: Z_Score_example_feature, dtype: bool""".strip()
)


print(parquet.read_table(f"{tmp_path}/identify_outliers_output.parquet").to_pydict())
assert parquet.read_table(f"{tmp_path}/identify_outliers_output.parquet").to_pydict() == {}


def test_cli_find_outliers(basic_outlier_csv: str):
assert parquet.read_table(
f"{tmp_path}/identify_outliers_output.parquet"
).to_pydict() == {
"Z_Score_example_feature": [
False,
False,
False,
False,
False,
False,
False,
False,
True,
True,
]
}


def test_cli_find_outliers(tmp_path: pathlib.Path, basic_outlier_csv: str):
"""
Test the `find_outliers` function of the CLI.
"""
Expand All @@ -62,6 +77,7 @@ def test_cli_find_outliers(basic_outlier_csv: str):
(
f"""cosmicqc find_outliers --df {basic_outlier_csv}"""
""" --metadata_columns [] --feature_thresholds {"example_feature":1.0}"""
f" --export_path {tmp_path}/find_outliers_output.parquet"
)
)

Expand All @@ -77,8 +93,12 @@ def test_cli_find_outliers(basic_outlier_csv: str):
9 10""".strip()
)

assert parquet.read_table(
f"{tmp_path}/find_outliers_output.parquet"
).to_pydict() == {"example_feature": [9, 10], "__index_level_0__": [8, 9]}

def test_cli_label_outliers(basic_outlier_csv: str):

def test_cli_label_outliers(tmp_path: pathlib.Path, basic_outlier_csv: str):
"""
Test the `label_outliers` function of the CLI.
"""
Expand All @@ -87,6 +107,7 @@ def test_cli_label_outliers(basic_outlier_csv: str):
(
f"""cosmicqc label_outliers --df {basic_outlier_csv}"""
""" --feature_thresholds {"example_feature":1.0}"""
f" --export_path {tmp_path}/label_outliers_output.parquet"
)
)

Expand All @@ -105,3 +126,21 @@ def test_cli_label_outliers(basic_outlier_csv: str):
8 9 True
9 10 True""".strip()
)

assert parquet.read_table(
f"{tmp_path}/label_outliers_output.parquet"
).to_pydict() == {
"example_feature": [1, 2, 3, 4, 5, 6, 7, 8, 9, 10],
"outlier_custom": [
False,
False,
False,
False,
False,
False,
False,
False,
True,
True,
],
}

0 comments on commit 90c2088

Please sign in to comment.