Skip to content

Commit

Permalink
Add split visualization
Browse files Browse the repository at this point in the history
  • Loading branch information
labdmitriy committed Nov 20, 2021
1 parent 7eb7f25 commit c744b5c
Show file tree
Hide file tree
Showing 4 changed files with 434 additions and 87 deletions.
151 changes: 71 additions & 80 deletions ml_lab/model_selection.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,18 +36,6 @@ def __init__(
window (str):
Type of the window. Defaults to 'rolling'.
"""
# Check if only one of train_size and n_splits is specified
if (train_size is None) and (n_splits is None):
raise ValueError('Either train_size or n_splits have to be defined')

# Check if window mode value is correctly specified
if window not in ['rolling', 'expanding']:
raise ValueError('Window can be either "rolling" or "expanding"')

# Check if train size is specified only with rolling window
if (train_size is not None) and (window == 'expanding'):
raise ValueError('Train size can be specified only with rolling window')

self.test_size = test_size
self.train_size = train_size
self.n_splits = n_splits
Expand All @@ -69,9 +57,11 @@ def split(self,
Yields:
Iterator[Tuple[np.ndarray, np.ndarray]]: Train/test dataset indices.
"""
train_size = self.train_size
self._check_split_params()

# train_size = self.train_size
test_size = self.test_size
n_splits = self.n_splits
# n_splits = self.n_splits
gap = self.gap
shift_size = self.shift_size

Expand All @@ -82,79 +72,24 @@ def split(self,
if groups is None:
raise ValueError('Groups must be specified')

# Check if groups are consecutive in dataset
# Check if groups are sorted in dataset
group_seqs = [group[0] for group in groupby(groups)]
unique_groups, group_starts_idx = np.unique(groups, return_index=True)
n_groups = _num_samples(unique_groups)
self.n_groups = n_groups

if group_seqs > sorted(unique_groups):
if group_seqs != sorted(unique_groups):
raise ValueError('Groups must be presorted in increasing order')

# Calculate number of splits if not specified
if n_splits is None:
n_splits = ((n_groups - train_size - gap - test_size) // shift_size) + 1
self.n_splits = n_splits

# Check if number of splits is positive
if n_splits <= 0:
raise ValueError(
(
f'Not enough data to split number of groups ({n_groups})'
f' for number splits ({n_splits})'
f' with train size ({train_size}),'
f' test size ({test_size}), gap ({gap}), shift_size ({shift_size})'
)
)
train_start_idx = n_groups - train_size - gap - test_size - (n_splits - 1) * shift_size
# Calculate train size if not specified
elif train_size is None:
train_size = int(n_groups - gap - (test_size + (n_splits - 1) * shift_size))
self.train_size = train_size

# Check if train size is not empty
if train_size <= 0:
raise ValueError(
(
f'Not enough data to split number of groups ({n_groups})'
f' for number splits ({n_splits})'
f' with train size ({train_size}),'
f' test size ({test_size}), gap ({gap}), shift_size ({shift_size})'
)
)
train_start_idx = 0
# Calculate train start index if train size and number of splits are specified
else:
train_start_idx = n_groups - train_size - gap - test_size - (n_splits - 1) * shift_size

# Check if data is enough to split
if train_start_idx < 0:
raise ValueError(
(
f'Not enough data to split number of groups ({n_groups})'
f' for number splits ({n_splits})'
f' with train size ({train_size}),'
f' test size ({test_size}), gap ({gap}), shift_size ({shift_size})'
)
)

# Calculate number of folds
n_folds = n_splits + 1

# Check if number of required folds is greater than groups
if n_folds > n_groups:
raise ValueError(
(
f'Cannot have number of folds ({n_folds}) greater than'
f' the number of groups ({n_groups})'
)
)

# Create mapping between groups and its start indices in array
groups_dict = dict(zip(unique_groups, group_starts_idx))

# Calculate number of samples
n_samples = _num_samples(X)

# Calculate remaining split params
train_size, n_splits, train_start_idx = self._calculate_split_params()

# Calculate start/end indices for initial train/test datasets
train_end_idx = train_start_idx + train_size
test_start_idx = train_end_idx + gap
Expand Down Expand Up @@ -186,15 +121,12 @@ def split(self,
test_end_idx = test_end_idx + shift_size

def get_n_splits(
self,
X: Optional[Iterable] = None,
y: Optional[Iterable] = None,
groups: Optional[Iterable] = None
self, X: Iterable, y: Optional[Iterable] = None, groups: Optional[Iterable] = None
) -> int:
"""Calculates number of splits given specified parameters.
Args:
X (Optional[Iterable], optional): Dataset with features. Defaults to None.
X (Iterable): Dataset with features.
y (Optional[Iterable], optional): Dataset with target. Defaults to None.
groups (Optional[Iterable], optional): Array with group numbers. Defaults to None.
Expand All @@ -205,3 +137,62 @@ def get_n_splits(
return self.n_splits
else:
raise ValueError('Number of splits is not defined')

def _check_split_params(self):
    """Validate the combination of train_size, n_splits and window.

    Reads ``self.train_size``, ``self.n_splits`` and ``self.window``;
    returns nothing when the configuration is acceptable.

    Raises:
        ValueError: If both train_size and n_splits are None, if window is
            neither 'rolling' nor 'expanding', or if train_size is set
            while window is 'expanding'.
    """
    has_train_size = self.train_size is not None

    # At least one of the two sizing parameters has to be provided
    if not has_train_size and self.n_splits is None:
        raise ValueError('Either train_size or n_splits have to be defined')

    window = self.window
    if window not in ('rolling', 'expanding'):
        raise ValueError('Window can be either "rolling" or "expanding"')

    # An explicit train size may not be combined with the expanding window
    if has_train_size and window == 'expanding':
        raise ValueError('Train size can be specified only with rolling window')

def _calculate_split_params(self):
    """Derive the split parameter that was left unspecified.

    Works in units of groups, using ``self.n_groups`` together with the
    configured ``test_size``, ``gap`` and ``shift_size``. Whichever of
    ``train_size`` / ``n_splits`` is None is computed from the other and
    stored back on the instance (before validation, so the stored value
    may be non-positive when an error is raised).

    NOTE(review): the derived train_size is persisted on ``self``; a later
    ``_check_split_params`` call with window == 'expanding' will then see a
    non-None train_size. Confirm whether reusing one splitter instance for
    several ``split`` calls is expected to work.

    Returns:
        Tuple of (train_size, n_splits, train_start_idx), where
        train_start_idx is the group index at which the first train
        window begins.

    Raises:
        ValueError: If the derived train_size or n_splits is not positive,
            or if the first train window would start before index 0.
    """
    n_groups = self.n_groups
    test_size = self.test_size
    gap = self.gap
    shift_size = self.shift_size
    train_size = self.train_size
    n_splits = self.n_splits

    def insufficient_data():
        # Built lazily so the message reflects the values at failure time
        message = (
            'Not enough data to split number of groups ({0})'
            ' for number splits ({1})'
            ' with train size ({2}),'
            ' test size ({3}), gap ({4}), shift_size ({5})'
        ).format(n_groups, n_splits, train_size, test_size, gap, shift_size)
        return ValueError(message)

    if train_size is None:
        # Train window takes every group not used by the gap, the first
        # test window and the subsequent shifts; it starts at index 0
        train_size = int(n_groups - gap - (test_size + (n_splits - 1) * shift_size))
        self.train_size = train_size
        if train_size <= 0:
            raise insufficient_data()
        return train_size, n_splits, 0

    if n_splits is None:
        # Number of shifts that fit after the first train/gap/test layout
        n_splits = (n_groups - train_size - gap - test_size) // shift_size + 1
        self.n_splits = n_splits
        if n_splits <= 0:
            raise insufficient_data()

    # Align the last test window with the end of the data and walk back
    first_train_idx = (
        n_groups - train_size - gap - test_size - (n_splits - 1) * shift_size
    )
    if first_train_idx < 0:
        raise insufficient_data()

    return train_size, n_splits, first_train_idx
Loading

0 comments on commit c744b5c

Please sign in to comment.