Skip to content

Commit

Permalink
use the correct library
Browse files (browse the repository at this point in the history)
  • Loading branch information
leoschwarz committed Jul 8, 2024
1 parent 1c45b95 commit 55c8eed
Show file tree
Hide file tree
Showing 3 changed files with 18 additions and 11 deletions.
8 changes: 4 additions & 4 deletions bfabric/entities/dataset.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -3,7 +3,7 @@
from pathlib import Path
from typing import Any

from pandas import DataFrame
from polars import DataFrame


class Dataset:
Expand All @@ -28,9 +28,9 @@ def to_polars(self) -> DataFrame:
data.append(dict(zip(column_names, row_values)))
return DataFrame(data)

def write_csv(self, path: Path, sep: str = ",") -> None:
"""Writes the dataset to a csv file at `path`, using `sep` as the separator."""
self.to_polars().to_csv(path, sep=sep, index=False)
def write_csv(self, path: Path, separator: str = ",") -> None:
"""Writes the dataset to a csv file at `path`, using the specified column `separator`."""
self.to_polars().write_csv(path, separator=separator)

def __repr__(self) -> str:
"""Returns the string representation of the dataset."""
Expand Down
2 changes: 1 addition & 1 deletion bfabric/scripts/bfabric_save_dataset2csv.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -33,7 +33,7 @@ def bfabric_save_dataset2csv(client: Bfabric, dataset_id: int, out_dir: Path, ou
dataset = Dataset(results[0])
output_path = out_dir / out_filename
try:
dataset.write_csv(output_path, sep=sep)
dataset.write_csv(output_path, separator=sep)
except Exception:
print(f"The writing process to '{output_path}' failed.")
raise
Expand Down
19 changes: 13 additions & 6 deletions bfabric/tests/unit/entities/test_dataset.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -2,6 +2,8 @@

from typing import Any

import polars as pl
import polars.testing
import pytest
from pytest_mock import MockFixture

Expand Down Expand Up @@ -40,10 +42,15 @@ def test_dict(mock_dataset: Dataset, mock_data_dict: dict[str, Any]) -> None:

def test_to_polars(mock_dataset: Dataset) -> None:
df = mock_dataset.to_polars()
assert df.columns.to_list() == ["Color", "Shape"]
assert df.shape == (2, 2)
assert df["Color"].to_list() == ["Red", "Blue"]
assert df["Shape"].to_list() == ["Square", "Circle"]
pl.testing.assert_frame_equal(
df,
pl.DataFrame(
{
"Color": ["Red", "Blue"],
"Shape": ["Square", "Circle"],
}
),
)


def test_write_csv(mocker: MockFixture, mock_dataset: Dataset) -> None:
Expand All @@ -53,9 +60,9 @@ def test_write_csv(mocker: MockFixture, mock_dataset: Dataset) -> None:
mock_path = mocker.MagicMock(name="mock_path")
mock_sep = mocker.MagicMock(name="mock_sep")

mock_dataset.write_csv(mock_path, sep=mock_sep)
mock_dataset.write_csv(mock_path, separator=mock_sep)

mock_df.to_csv.assert_called_once_with(mock_path, sep=mock_sep, index=False)
mock_df.write_csv.assert_called_once_with(mock_path, separator=mock_sep)


def test_repr(mock_empty_dataset: Dataset) -> None:
Expand Down

0 comments on commit 55c8eed

Please sign in to comment.