diff --git a/src/safeds/ml/classical/classification/_logistic_classifier.py b/src/safeds/ml/classical/classification/_logistic_classifier.py
index e312e6b25..e00a8d399 100644
--- a/src/safeds/ml/classical/classification/_logistic_classifier.py
+++ b/src/safeds/ml/classical/classification/_logistic_classifier.py
@@ -3,6 +3,7 @@
 from typing import TYPE_CHECKING
 
 from safeds._utils import _get_random_seed, _structural_hash
+from safeds._validation import _check_bounds, _OpenBound
 
 from ._classifier import Classifier
 
@@ -11,26 +12,48 @@ class LogisticClassifier(Classifier):
-    """Regularized logistic regression for classification."""
+    """
+    Regularized logistic regression for classification.
+
+    Parameters
+    ----------
+    c:
+        The regularization strength. Lower values imply stronger regularization. Must be greater than 0.
+    """
 
     # ------------------------------------------------------------------------------------------------------------------
     # Dunder methods
     # ------------------------------------------------------------------------------------------------------------------
 
-    def __init__(self) -> None:
+    def __init__(self, *, c: float = 1.0) -> None:
         super().__init__()
 
+        # Validation
+        _check_bounds("c", c, lower_bound=_OpenBound(0))
+
+        # Hyperparameters
+        self._c: float = c
+
     def __hash__(self) -> int:
         return _structural_hash(
             super().__hash__(),
         )
 
+    # ------------------------------------------------------------------------------------------------------------------
+    # Properties
+    # ------------------------------------------------------------------------------------------------------------------
+
+    @property
+    def c(self) -> float:
+        """The regularization strength. Lower values imply stronger regularization."""
+        return self._c
+
     # ------------------------------------------------------------------------------------------------------------------
     # Template methods
     # ------------------------------------------------------------------------------------------------------------------
 
     def _clone(self) -> LogisticClassifier:
-        return LogisticClassifier()
+        return LogisticClassifier(c=self.c)
 
     def _get_sklearn_model(self) -> ClassifierMixin:
         from sklearn.linear_model import LogisticRegression as SklearnLogisticRegression
 
@@ -38,4 +61,5 @@ def _get_sklearn_model(self) -> ClassifierMixin:
         return SklearnLogisticRegression(
             random_state=_get_random_seed(),
             n_jobs=-1,
+            C=self.c,
         )
diff --git a/tests/safeds/ml/classical/classification/test_logistic_classifier.py b/tests/safeds/ml/classical/classification/test_logistic_classifier.py
new file mode 100644
index 000000000..92bd4b8da
--- /dev/null
+++ b/tests/safeds/ml/classical/classification/test_logistic_classifier.py
@@ -0,0 +1,33 @@
+import pytest
+from safeds.data.labeled.containers import TabularDataset
+from safeds.data.tabular.containers import Table
+from safeds.exceptions import OutOfBoundsError
+from safeds.ml.classical.classification import LogisticClassifier
+
+
+@pytest.fixture()
+def training_set() -> TabularDataset:
+    table = Table({"col1": [1, 2, 3, 4], "col2": [1, 2, 3, 4]})
+    return table.to_tabular_dataset(target_name="col1")
+
+
+class TestC:
+    def test_should_be_passed_to_fitted_model(self, training_set: TabularDataset) -> None:
+        fitted_model = LogisticClassifier(c=2).fit(training_set)
+        assert fitted_model.c == 2
+
+    def test_should_be_passed_to_sklearn(self, training_set: TabularDataset) -> None:
+        fitted_model = LogisticClassifier(c=2).fit(training_set)
+        assert fitted_model._wrapped_model is not None
+        assert fitted_model._wrapped_model.C == 2
+
+    def test_clone(self, training_set: TabularDataset) -> None:
+        fitted_model = LogisticClassifier(c=2).fit(training_set)
+        cloned_classifier = fitted_model._clone()
+        assert isinstance(cloned_classifier, LogisticClassifier)
+        assert cloned_classifier.c == fitted_model.c
+
+    @pytest.mark.parametrize("c", [-1.0, 0.0], ids=["minus_one", "zero"])
+    def test_should_raise_if_less_than_or_equal_to_0(self, c: float) -> None:
+        with pytest.raises(OutOfBoundsError):
+            LogisticClassifier(c=c)