Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: stabilize Row class #980

Merged
merged 2 commits into from
Jan 13, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ class _LazyVectorizedRow(Row):
up operations on the row.
Moreover, accessing a column only builds an expression that will be evaluated when needed. This is useful when later
operations remove more rows or columns, so we don't do unnecessary work upfront.
operations remove rows or columns, so we don't do unnecessary work upfront.
"""

# ------------------------------------------------------------------------------------------------------------------
Expand Down
35 changes: 18 additions & 17 deletions src/safeds/data/tabular/containers/_row.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,27 +2,29 @@

from abc import ABC, abstractmethod
from collections.abc import Iterator, Mapping
from typing import TYPE_CHECKING, Any
from typing import TYPE_CHECKING

from ._cell import Cell

if TYPE_CHECKING:
from safeds.data.tabular.typing import ColumnType, Schema

from ._cell import Cell


class Row(ABC, Mapping[str, Any]):
class Row(ABC, Mapping[str, Cell]):
"""
A one-dimensional collection of named, heterogeneous values.
This class cannot be instantiated directly. It is only used for arguments of callbacks.
You only need to interact with this class in callbacks passed to higher-order functions.
"""

# ------------------------------------------------------------------------------------------------------------------
# Dunder methods
# ------------------------------------------------------------------------------------------------------------------

def __contains__(self, name: Any) -> bool:
return self.has_column(name)
def __contains__(self, key: object, /) -> bool:
if not isinstance(key, str):
return False
return self.has_column(key)

@abstractmethod
def __eq__(self, other: object) -> bool: ...
Expand All @@ -33,7 +35,7 @@ def __getitem__(self, name: str) -> Cell:
@abstractmethod
def __hash__(self) -> int: ...

def __iter__(self) -> Iterator[Any]:
def __iter__(self) -> Iterator[str]:
return iter(self.column_names)

def __len__(self) -> int:
Expand All @@ -48,18 +50,18 @@ def __sizeof__(self) -> int: ...

@property
@abstractmethod
def column_names(self) -> list[str]:
"""The names of the columns in the row."""
def column_count(self) -> int:
"""The number of columns."""

@property
@abstractmethod
def column_count(self) -> int:
"""The number of columns in the row."""
def column_names(self) -> list[str]:
"""The names of the columns."""

@property
@abstractmethod
def schema(self) -> Schema:
"""The schema of the row."""
"""The schema, which is a mapping from column names to their types."""

# ------------------------------------------------------------------------------------------------------------------
# Column operations
Expand Down Expand Up @@ -98,7 +100,6 @@ def get_cell(self, name: str) -> Cell:
| 2 | 4 |
+------+------+
>>> table.remove_rows(lambda row: row["col1"] == 1)
+------+------+
| col1 | col2 |
Expand All @@ -112,7 +113,7 @@ def get_cell(self, name: str) -> Cell:
@abstractmethod
def get_column_type(self, name: str) -> ColumnType:
"""
Get the type of the specified column.
Get the type of a column. This is equivalent to using the `[]` operator (indexed access).
Parameters
----------
Expand All @@ -127,13 +128,13 @@ def get_column_type(self, name: str) -> ColumnType:
Raises
------
ColumnNotFoundError
If the column name does not exist.
If the column does not exist.
"""

@abstractmethod
def has_column(self, name: str) -> bool:
"""
Check if the row has a column with the specified name.
Check if the row has a column with a specific name. This is equivalent to using the `in` operator.
Parameters
----------
Expand Down
24 changes: 12 additions & 12 deletions src/safeds/data/tabular/containers/_table.py
Original file line number Diff line number Diff line change
Expand Up @@ -393,41 +393,41 @@ def _data_frame(self) -> pl.DataFrame:
return self.__data_frame_cache

@property
def column_names(self) -> list[str]:
def column_count(self) -> int:
"""
The names of the columns in the table.
The number of columns.
**Note:** This operation must compute the schema of the table, which can be expensive.
Examples
--------
>>> from safeds.data.tabular.containers import Table
>>> table = Table({"a": [1, 2, 3], "b": [4, 5, 6]})
>>> table.column_names
['a', 'b']
>>> table.column_count
2
"""
return self.schema.column_names
return len(self.column_names)

@property
def column_count(self) -> int:
def column_names(self) -> list[str]:
"""
The number of columns in the table.
The names of the columns in the table.
**Note:** This operation must compute the schema of the table, which can be expensive.
Examples
--------
>>> from safeds.data.tabular.containers import Table
>>> table = Table({"a": [1, 2, 3], "b": [4, 5, 6]})
>>> table.column_count
2
>>> table.column_names
['a', 'b']
"""
return len(self.column_names)
return self.schema.column_names

@property
def row_count(self) -> int:
"""
The number of rows in the table.
The number of rows.
**Note:** This operation must fully load the data into memory, which can be expensive.
Expand Down Expand Up @@ -458,7 +458,7 @@ def plot(self) -> TablePlotter:
@property
def schema(self) -> Schema:
"""
The schema of the table.
The schema, which is a mapping from column names to their types.
Examples
--------
Expand Down
33 changes: 24 additions & 9 deletions src/safeds/data/tabular/typing/_schema.py
Original file line number Diff line number Diff line change
@@ -1,22 +1,20 @@
from __future__ import annotations

import sys
from collections.abc import Iterator, Mapping
from typing import TYPE_CHECKING

from safeds._utils import _structural_hash
from safeds._validation import _check_columns_exist

from ._column_type import ColumnType
from ._polars_column_type import _PolarsColumnType

if TYPE_CHECKING:
from collections.abc import Mapping

import polars as pl

from ._column_type import ColumnType


class Schema:
class Schema(Mapping[str, ColumnType]):
"""The schema of a row or table."""

# ------------------------------------------------------------------------------------------------------------------
Expand All @@ -41,16 +39,30 @@ def __init__(self, schema: Mapping[str, ColumnType]) -> None:
check_dtypes=False,
)

def __contains__(self, key: object, /) -> bool:
if not isinstance(key, str):
return False
return self.has_column(key)

def __eq__(self, other: object) -> bool:
if not isinstance(other, Schema):
return NotImplemented
if self is other:
return True
return self._schema == other._schema

def __getitem__(self, key: str, /) -> ColumnType:
return self.get_column_type(key)

def __hash__(self) -> int:
return _structural_hash(tuple(self._schema.keys()), [str(type_) for type_ in self._schema.values()])

def __iter__(self) -> Iterator[str]:
return iter(self._schema.keys())

def __len__(self) -> int:
return self.column_count

def __repr__(self) -> str:
return f"Schema({self!s})"

Expand Down Expand Up @@ -108,7 +120,7 @@ def column_names(self) -> list[str]:

def get_column_type(self, name: str) -> ColumnType:
"""
Get the type of a column.
Get the type of a column. This is equivalent to using the `[]` operator (indexed access).

Parameters
----------
Expand All @@ -131,14 +143,17 @@ def get_column_type(self, name: str) -> ColumnType:
>>> schema = Schema({"a": ColumnType.int64(), "b": ColumnType.float32()})
>>> schema.get_column_type("a")
int64

>>> schema["b"]
float32
"""
_check_columns_exist(self, name)

return _PolarsColumnType(self._schema[name])

def has_column(self, name: str) -> bool:
"""
Check if the table has a column with a specific name.
Check if the schema has a column with a specific name. This is equivalent to using the `in` operator.

Parameters
----------
Expand All @@ -148,7 +163,7 @@ def has_column(self, name: str) -> bool:
Returns
-------
has_column:
Whether the table has a column with the specified name.
Whether the schema has a column with the specified name.

Examples
--------
Expand All @@ -157,7 +172,7 @@ def has_column(self, name: str) -> bool:
>>> schema.has_column("a")
True

>>> schema.has_column("c")
>>> "c" in schema
False
"""
return name in self._schema
Expand Down
65 changes: 46 additions & 19 deletions tests/helpers/_assertions.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
from polars.testing import assert_frame_equal

from safeds.data.labeled.containers import TabularDataset
from safeds.data.tabular.containers import Cell, Column, Table
from safeds.data.tabular.containers import Cell, Column, Row, Table


def assert_tables_are_equal(
Expand Down Expand Up @@ -62,44 +62,71 @@ def assert_that_tabular_datasets_are_equal(table1: TabularDataset, table2: Tabul


def assert_cell_operation_works(
input_value: Any,
value: Any,
transformer: Callable[[Cell], Cell],
expected_value: Any,
expected: Any,
) -> None:
"""
Assert that a cell operation works as expected.
Parameters
----------
input_value:
value:
The value in the input cell.
transformer:
The transformer to apply to the cells.
expected_value:
expected:
The expected value of the transformed cell.
"""
column = Column("A", [input_value])
column = Column("A", [value])
transformed_column = column.transform(transformer)
assert transformed_column == Column("A", [expected_value]), f"Expected: {expected_value}\nGot: {transformed_column}"
actual = transformed_column[0]
assert actual == expected


def assert_row_operation_works(
input_value: Any,
transformer: Callable[[Table], Table],
expected_value: Any,
table: Table,
computer: Callable[[Row], Cell],
expected: list[Any],
) -> None:
"""
Assert that a row operation works as expected.
Parameters
----------
input_value:
The value in the input row.
transformer:
The transformer to apply to the rows.
expected_value:
The expected value of the transformed row.
table:
The input table.
computer:
The function that computes the new column.
expected:
The expected values of the computed column.
"""
column_name = _find_free_column_name(table, "computed")

new_table = table.add_computed_column(column_name, computer)
actual = list(new_table.get_column(column_name))
assert actual == expected


def _find_free_column_name(table: Table, prefix: str) -> str:
"""
table = Table(input_value)
transformed_table = transformer(table)
assert transformed_table == Table(expected_value), f"Expected: {expected_value}\nGot: {transformed_table}"
Find a free column name in the table.
Parameters
----------
table:
The table to search for a free column name.
prefix:
The prefix to use for the column name.
Returns
-------
free_name:
A free column name.
"""
column_name = prefix

while column_name in table.column_names:
column_name += "_"

return column_name
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
# serializer version: 1
# name: TestContract.test_should_return_same_hash_in_different_processes[empty]
1789859531466043636
# ---
# name: TestContract.test_should_return_same_hash_in_different_processes[no rows]
585695607399955642
# ---
# name: TestContract.test_should_return_same_hash_in_different_processes[with data]
909875695937937648
# ---
Loading
Loading