Skip to content

Commit 62320a3

Browse files
TarmandanwastedareasSimonmegalinter-bot
authored
feat: add RobustScaler (#874)
Closes #650 ### Summary of Changes Adds a RobustScaler class that works like the StandardScaler but uses median instead of mean and interquartile range instead of standard deviation. If the interquartile range is 0 it will only substract the median from all rows. For now cannot handle columns containing NaN-values. See Issue #873 --------- Co-authored-by: srose <118634249+wastedareas@users.noreply.github.com> Co-authored-by: Simon <simon@schwubbel.dip0.t-ipconnect.de> Co-authored-by: megalinter-bot <129584137+megalinter-bot@users.noreply.github.com>
1 parent 9fd888d commit 62320a3

File tree

3 files changed

+438
-0
lines changed

3 files changed

+438
-0
lines changed

src/safeds/data/tabular/transformation/__init__.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
from ._label_encoder import LabelEncoder
1111
from ._one_hot_encoder import OneHotEncoder
1212
from ._range_scaler import RangeScaler
13+
from ._robust_scaler import RobustScaler
1314
from ._simple_imputer import SimpleImputer
1415
from ._standard_scaler import StandardScaler
1516
from ._table_transformer import TableTransformer
@@ -22,6 +23,7 @@
2223
"LabelEncoder": "._label_encoder:LabelEncoder",
2324
"OneHotEncoder": "._one_hot_encoder:OneHotEncoder",
2425
"RangeScaler": "._range_scaler:RangeScaler",
26+
"RobustScaler": "._robust_scaler:RobustScaler",
2527
"SimpleImputer": "._simple_imputer:SimpleImputer",
2628
"StandardScaler": "._standard_scaler:StandardScaler",
2729
"TableTransformer": "._table_transformer:TableTransformer",
@@ -34,6 +36,7 @@
3436
"LabelEncoder",
3537
"OneHotEncoder",
3638
"RangeScaler",
39+
"RobustScaler",
3740
"SimpleImputer",
3841
"StandardScaler",
3942
"TableTransformer",
Lines changed: 199 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,199 @@
1+
from __future__ import annotations
2+
3+
from typing import TYPE_CHECKING
4+
5+
from safeds._validation import _check_columns_exist
6+
from safeds._validation._check_columns_are_numeric import _check_columns_are_numeric
7+
from safeds.data.tabular.containers import Table
8+
from safeds.exceptions import TransformerNotFittedError
9+
10+
from ._invertible_table_transformer import InvertibleTableTransformer
11+
12+
if TYPE_CHECKING:
13+
import polars as pl
14+
15+
16+
class RobustScaler(InvertibleTableTransformer):
17+
"""
18+
The RobustScaler transforms column values to a range by removing the median and scaling to the interquartile range.
19+
20+
Currently, for columns with high stability (IQR == 0), it will only substract the median and not scale to avoid dividing by zero.
21+
22+
Parameters
23+
----------
24+
column_names:
25+
The list of columns used to fit the transformer. If `None`, all numeric columns are used.
26+
"""
27+
28+
# ------------------------------------------------------------------------------------------------------------------
29+
# Dunder methods
30+
# ------------------------------------------------------------------------------------------------------------------
31+
32+
def __init__(self, *, column_names: str | list[str] | None = None) -> None:
33+
super().__init__(column_names)
34+
35+
# Internal state
36+
self._data_median: pl.DataFrame | None = None
37+
self._data_scale: pl.DataFrame | None = None
38+
39+
def __hash__(self) -> int:
40+
# Leave out the internal state for faster hashing
41+
return super().__hash__()
42+
43+
# ------------------------------------------------------------------------------------------------------------------
44+
# Properties
45+
# ------------------------------------------------------------------------------------------------------------------
46+
47+
@property
48+
def is_fitted(self) -> bool:
49+
"""Whether the transformer is fitted."""
50+
return self._data_median is not None and self._data_scale is not None
51+
52+
# ------------------------------------------------------------------------------------------------------------------
53+
# Learning and transformation
54+
# ------------------------------------------------------------------------------------------------------------------
55+
56+
def fit(self, table: Table) -> RobustScaler:
57+
"""
58+
Learn a transformation for a set of columns in a table.
59+
60+
**Note:** This transformer is not modified.
61+
62+
Parameters
63+
----------
64+
table:
65+
The table used to fit the transformer.
66+
67+
Returns
68+
-------
69+
fitted_transformer:
70+
The fitted transformer.
71+
72+
Raises
73+
------
74+
ColumnNotFoundError
75+
If column_names contain a column name that is missing in the table.
76+
ColumnTypeError
77+
If at least one of the specified columns in the table contains non-numerical data.
78+
ValueError
79+
If the table contains 0 rows.
80+
"""
81+
import polars as pl
82+
83+
if self._column_names is None:
84+
column_names = [name for name in table.column_names if table.get_column_type(name).is_numeric]
85+
else:
86+
column_names = self._column_names
87+
_check_columns_exist(table, column_names)
88+
_check_columns_are_numeric(table, column_names, operation="fit a RobustScaler")
89+
90+
if table.row_count == 0:
91+
raise ValueError("The RobustScaler cannot be fitted because the table contains 0 rows")
92+
93+
_data_median = table._lazy_frame.select(column_names).median().collect()
94+
q1 = table._lazy_frame.select(column_names).quantile(0.25).collect()
95+
q3 = table._lazy_frame.select(column_names).quantile(0.75).collect()
96+
_data_scale = q3 - q1
97+
98+
# To make sure there is no division by zero
99+
for col_e in column_names:
100+
_data_scale = _data_scale.with_columns(
101+
pl.when(pl.col(col_e) == 0).then(1).otherwise(pl.col(col_e)).alias(col_e),
102+
)
103+
104+
# Create a copy with the learned transformation
105+
result = RobustScaler(column_names=column_names)
106+
result._data_median = _data_median
107+
result._data_scale = _data_scale
108+
109+
return result
110+
111+
def transform(self, table: Table) -> Table:
112+
"""
113+
Apply the learned transformation to a table.
114+
115+
**Note:** The given table is not modified.
116+
117+
Parameters
118+
----------
119+
table:
120+
The table to which the learned transformation is applied.
121+
122+
Returns
123+
-------
124+
transformed_table:
125+
The transformed table.
126+
127+
Raises
128+
------
129+
TransformerNotFittedError
130+
If the transformer has not been fitted yet.
131+
ColumnNotFoundError
132+
If the input table does not contain all columns used to fit the transformer.
133+
ColumnTypeError
134+
If at least one of the columns in the input table that is used to fit contains non-numerical data.
135+
"""
136+
import polars as pl
137+
138+
# Used in favor of is_fitted, so the type checker is happy
139+
if self._column_names is None or self._data_median is None or self._data_scale is None:
140+
raise TransformerNotFittedError
141+
142+
_check_columns_exist(table, self._column_names)
143+
_check_columns_are_numeric(table, self._column_names, operation="transform with a RobustScaler")
144+
145+
columns = [
146+
(pl.col(name) - self._data_median.get_column(name)) / self._data_scale.get_column(name)
147+
for name in self._column_names
148+
]
149+
150+
return Table._from_polars_lazy_frame(
151+
table._lazy_frame.with_columns(columns),
152+
)
153+
154+
def inverse_transform(self, transformed_table: Table) -> Table:
155+
"""
156+
Undo the learned transformation.
157+
158+
**Note:** The given table is not modified.
159+
160+
Parameters
161+
----------
162+
transformed_table:
163+
The table to be transformed back to the original version.
164+
165+
Returns
166+
-------
167+
original_table:
168+
The original table.
169+
170+
Raises
171+
------
172+
TransformerNotFittedError
173+
If the transformer has not been fitted yet.
174+
ColumnNotFoundError
175+
If the input table does not contain all columns used to fit the transformer.
176+
ColumnTypeError
177+
If the transformed columns of the input table contain non-numerical data.
178+
"""
179+
import polars as pl
180+
181+
# Used in favor of is_fitted, so the type checker is happy
182+
if self._column_names is None or self._data_median is None or self._data_scale is None:
183+
raise TransformerNotFittedError
184+
185+
_check_columns_exist(transformed_table, self._column_names)
186+
_check_columns_are_numeric(
187+
transformed_table,
188+
self._column_names,
189+
operation="inverse-transform with a RobustScaler",
190+
)
191+
192+
columns = [
193+
pl.col(name) * self._data_scale.get_column(name) + self._data_median.get_column(name)
194+
for name in self._column_names
195+
]
196+
197+
return Table._from_polars_lazy_frame(
198+
transformed_table._lazy_frame.with_columns(columns),
199+
)

0 commit comments

Comments
 (0)