Skip to content

Commit

Permalink
fix: add missing Optional type hint
Browse files Browse the repository at this point in the history
  • Loading branch information
IamGianluca committed Mar 24, 2024
1 parent 529d97b commit dbf342d
Showing 1 changed file with 21 additions and 9 deletions.
30 changes: 21 additions & 9 deletions blazingai/tabular/model_selection.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import Dict, List
from typing import Any, Dict, Generator, List, Optional

import numpy as np
import pandas as pd
Expand Down Expand Up @@ -53,29 +53,41 @@ def validate_method(method):
else:
return method

def split(self, X: pd.DataFrame, y: np.ndarray = None, groups: str = None):
def split(
self,
X: pd.DataFrame,
y: Optional[np.ndarray] = None,
groups: Optional[str] = None,
) -> Generator[
tuple[
np.ndarray,
np.ndarray,
],
None,
None,
]:
"""
Args:
X (pandas.DataFrame): Input data.
method (str): Either `sliding` or `expanding`.
y: to make the function compatible with sklearn cross validation.
groups: to make the function compatible with sklearn cross validation.
y: Unused, but required for compatibility with sklearn cross validation API.
groups: Unused, but required for compatibility with sklearn cross validation API.
Returns:
(generator) Indexes for the training and Validation set from the
data passed
"""
self._is_valid_df(X)
ar = self._df_to_array(X)
self._has_enough_days(ar)
first = ar.min()
arr = self._df_to_array(X)
self._has_enough_days(arr)
first = arr.min()

# yield indexes for train and valid sets
for step in range(self.n_splits):
train_start, train_end, valid_start, valid_end = self._get_dates(
first=first, step=step
)
train_mask = (ar >= train_start) & (ar <= train_end)
valid_mask = (ar >= valid_start) & (ar <= valid_end)
train_mask = (arr >= train_start) & (arr <= train_end)
valid_mask = (arr >= valid_start) & (arr <= valid_end)

self.train_date_ranges["period_" + str(step)] = [
train_start.date().isoformat(),
Expand Down

0 comments on commit dbf342d

Please sign in to comment.