diff --git a/src/actinet/models.py b/src/actinet/models.py index 1f88f0b..263fc68 100644 --- a/src/actinet/models.py +++ b/src/actinet/models.py @@ -76,7 +76,7 @@ def fit( t_splits = [] if n_splits < 3: - splitter = GroupShuffleSplit(n_splits=n_splits) + splitter = GroupShuffleSplit(n_splits=n_splits, random_state=42) split_iterator = splitter.split(X, Y, groups) else: splitter = StratifiedGroupKFold(n_splits)