Skip to content

Commit 11f51cd

Browse files
author
Tristan Nixon
committed
applying black formatting
1 parent e17ffce commit 11f51cd

File tree

1 file changed

+36
-20
lines changed

1 file changed

+36
-20
lines changed

python/tempo/ml.py

Lines changed: 36 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -12,32 +12,35 @@
1212
TMP_SPLIT_COL = "__tmp_split_col"
1313
TMP_GAP_COL = "__tmp_gap_row"
1414

15+
1516
class TimeSeriesCrossValidator(CrossValidator):
1617
# some additional parameters
1718
timeSeriesCol: Param[str] = Param(
1819
Params._dummy(),
1920
"timeSeriesCol",
2021
"The name of the time series column",
21-
typeConverter=TypeConverters.toString
22+
typeConverter=TypeConverters.toString,
2223
)
2324
seriesIdCols: Param[List[str]] = Param(
2425
Params._dummy(),
2526
"seriesIdCols",
2627
"The name of the series id columns",
27-
typeConverter=TypeConverters.toListString
28+
typeConverter=TypeConverters.toListString,
2829
)
2930
gap: Param[int] = Param(
3031
Params._dummy(),
3132
"gap",
3233
"The gap between training and test set",
33-
typeConverter=TypeConverters.toInt
34+
typeConverter=TypeConverters.toInt,
3435
)
3536

36-
def __init__(self,
37-
timeSeriesCol: str = "event_ts",
38-
seriesIdCols: List[str] = [],
39-
gap: int = 0,
40-
**other_kwargs) -> None:
37+
def __init__(
38+
self,
39+
timeSeriesCol: str = "event_ts",
40+
seriesIdCols: List[str] = [],
41+
gap: int = 0,
42+
**other_kwargs
43+
) -> None:
4144
super(TimeSeriesCrossValidator, self).__init__(**other_kwargs)
4245
self._setDefault(timeSeriesCol="event_ts", seriesIdCols=[], gap=0)
4346
self._set(timeSeriesCol=timeSeriesCol, seriesIdCols=seriesIdCols, gap=gap)
@@ -72,19 +75,24 @@ def _get_split_win(self, desc: bool = False) -> WindowSpec:
7275

7376
def _kFold(self, dataset: DataFrame) -> List[Tuple[DataFrame, DataFrame]]:
7477
nFolds = self.getOrDefault(self.numFolds)
75-
nSplits = nFolds+1
78+
nSplits = nFolds + 1
7679

7780
# split the data into nSplits subsets by timeseries order
78-
split_df = dataset.withColumn(TMP_SPLIT_COL,
79-
sfn.ntile(nSplits).over(self._get_split_win()))
80-
all_splits = [split_df.filter(sfn.col(TMP_SPLIT_COL) == i).drop(TMP_SPLIT_COL)
81-
for i in range(1, nSplits+1)]
81+
split_df = dataset.withColumn(
82+
TMP_SPLIT_COL, sfn.ntile(nSplits).over(self._get_split_win())
83+
)
84+
all_splits = [
85+
split_df.filter(sfn.col(TMP_SPLIT_COL) == i).drop(TMP_SPLIT_COL)
86+
for i in range(1, nSplits + 1)
87+
]
8288
assert len(all_splits) == nSplits
8389

8490
# compose the k folds by including all previous splits in the training set,
8591
# and the next split in the test set
86-
kFolds = [(reduce(lambda a, b: a.union(b), all_splits[:i+1]), all_splits[i+1])
87-
for i in range(nFolds)]
92+
kFolds = [
93+
(reduce(lambda a, b: a.union(b), all_splits[: i + 1]), all_splits[i + 1])
94+
for i in range(nFolds)
95+
]
8896
assert len(kFolds) == nFolds
8997
for tv in kFolds:
9098
assert len(tv) == 2
@@ -94,13 +102,21 @@ def _kFold(self, dataset: DataFrame) -> List[Tuple[DataFrame, DataFrame]]:
94102
if gap > 0:
95103
order_cols = self.getSeriesIdCols() + [self.getTimeSeriesCol()]
96104
# trim each training dataset by the specified gap
97-
kFolds = [((train_df.withColumn(TMP_GAP_COL,
98-
sfn.row_number().over(self._get_split_win(desc=True)))
105+
kFolds = [
106+
(
107+
(
108+
train_df.withColumn(
109+
TMP_GAP_COL,
110+
sfn.row_number().over(self._get_split_win(desc=True)),
111+
)
99112
.where(sfn.col(TMP_GAP_COL) > gap)
100113
.drop(TMP_GAP_COL)
101-
.orderBy(*order_cols)),
102-
test_df)
103-
for (train_df, test_df) in kFolds]
114+
.orderBy(*order_cols)
115+
),
116+
test_df,
117+
)
118+
for (train_df, test_df) in kFolds
119+
]
104120

105121
# return the k folds (training, test) datasets
106122
return kFolds

0 commit comments

Comments
 (0)