diff --git a/sklearn/model_selection/_split.py b/sklearn/model_selection/_split.py index c49a3ce6aea4e..bfd30dae5a34f 100644 --- a/sklearn/model_selection/_split.py +++ b/sklearn/model_selection/_split.py @@ -1163,6 +1163,9 @@ def get_n_splits(self, X=None, y=None, groups=None): **self.cvargs) return cv.get_n_splits(X, y, groups) * self.n_repeats + def __repr__(self): + return _build_repr(self) + class RepeatedKFold(_RepeatedSplits): """Repeated K-Fold cross validator. diff --git a/sklearn/model_selection/tests/test_repeated_split_repr.py b/sklearn/model_selection/tests/test_repeated_split_repr.py new file mode 100644 index 0000000000000..5ced8d93b9bc1 --- /dev/null +++ b/sklearn/model_selection/tests/test_repeated_split_repr.py @@ -0,0 +1,37 @@ +from sklearn import config_context +from sklearn.model_selection import ( + RepeatedKFold, + RepeatedStratifiedKFold, +) + + +def test_repeated_kfold_repr_default(): + with config_context(print_changed_only=False): + assert ( + repr(RepeatedKFold()) + == "RepeatedKFold(n_splits=5, n_repeats=10, random_state=None)" + ) + + +def test_repeated_stratified_kfold_repr_default(): + with config_context(print_changed_only=False): + assert ( + repr(RepeatedStratifiedKFold()) + == "RepeatedStratifiedKFold(n_splits=5, n_repeats=10, random_state=None)" + ) + + +def test_repeated_kfold_repr_nondefault(): + with config_context(print_changed_only=False): + assert ( + repr(RepeatedKFold(n_splits=3, n_repeats=2, random_state=42)) + == "RepeatedKFold(n_splits=3, n_repeats=2, random_state=42)" + ) + + +def test_repeated_stratified_kfold_repr_nondefault(): + with config_context(print_changed_only=False): + assert ( + repr(RepeatedStratifiedKFold(n_splits=4, n_repeats=3, random_state=0)) + == "RepeatedStratifiedKFold(n_splits=4, n_repeats=3, random_state=0)" + )