12
12
TMP_SPLIT_COL = "__tmp_split_col"
13
13
TMP_GAP_COL = "__tmp_gap_row"
14
14
15
+
15
16
class TimeSeriesCrossValidator (CrossValidator ):
16
17
# some additional parameters
17
18
timeSeriesCol : Param [str ] = Param (
18
19
Params ._dummy (),
19
20
"timeSeriesCol" ,
20
21
"The name of the time series column" ,
21
- typeConverter = TypeConverters .toString
22
+ typeConverter = TypeConverters .toString ,
22
23
)
23
24
seriesIdCols : Param [List [str ]] = Param (
24
25
Params ._dummy (),
25
26
"seriesIdCols" ,
26
27
"The name of the series id columns" ,
27
- typeConverter = TypeConverters .toListString
28
+ typeConverter = TypeConverters .toListString ,
28
29
)
29
30
gap : Param [int ] = Param (
30
31
Params ._dummy (),
31
32
"gap" ,
32
33
"The gap between training and test set" ,
33
- typeConverter = TypeConverters .toInt
34
+ typeConverter = TypeConverters .toInt ,
34
35
)
35
36
36
- def __init__ (self ,
37
- timeSeriesCol : str = "event_ts" ,
38
- seriesIdCols : List [str ] = [],
39
- gap : int = 0 ,
40
- ** other_kwargs ) -> None :
37
+ def __init__ (
38
+ self ,
39
+ timeSeriesCol : str = "event_ts" ,
40
+ seriesIdCols : List [str ] = [],
41
+ gap : int = 0 ,
42
+ ** other_kwargs
43
+ ) -> None :
41
44
super (TimeSeriesCrossValidator , self ).__init__ (** other_kwargs )
42
45
self ._setDefault (timeSeriesCol = "event_ts" , seriesIdCols = [], gap = 0 )
43
46
self ._set (timeSeriesCol = timeSeriesCol , seriesIdCols = seriesIdCols , gap = gap )
@@ -72,19 +75,24 @@ def _get_split_win(self, desc: bool = False) -> WindowSpec:
72
75
73
76
def _kFold (self , dataset : DataFrame ) -> List [Tuple [DataFrame , DataFrame ]]:
74
77
nFolds = self .getOrDefault (self .numFolds )
75
- nSplits = nFolds + 1
78
+ nSplits = nFolds + 1
76
79
77
80
# split the data into nSplits subsets by timeseries order
78
- split_df = dataset .withColumn (TMP_SPLIT_COL ,
79
- sfn .ntile (nSplits ).over (self ._get_split_win ()))
80
- all_splits = [split_df .filter (sfn .col (TMP_SPLIT_COL ) == i ).drop (TMP_SPLIT_COL )
81
- for i in range (1 , nSplits + 1 )]
81
+ split_df = dataset .withColumn (
82
+ TMP_SPLIT_COL , sfn .ntile (nSplits ).over (self ._get_split_win ())
83
+ )
84
+ all_splits = [
85
+ split_df .filter (sfn .col (TMP_SPLIT_COL ) == i ).drop (TMP_SPLIT_COL )
86
+ for i in range (1 , nSplits + 1 )
87
+ ]
82
88
assert len (all_splits ) == nSplits
83
89
84
90
# compose the k folds by including all previous splits in the training set,
85
91
# and the next split in the test set
86
- kFolds = [(reduce (lambda a , b : a .union (b ), all_splits [:i + 1 ]), all_splits [i + 1 ])
87
- for i in range (nFolds )]
92
+ kFolds = [
93
+ (reduce (lambda a , b : a .union (b ), all_splits [: i + 1 ]), all_splits [i + 1 ])
94
+ for i in range (nFolds )
95
+ ]
88
96
assert len (kFolds ) == nFolds
89
97
for tv in kFolds :
90
98
assert len (tv ) == 2
@@ -94,13 +102,21 @@ def _kFold(self, dataset: DataFrame) -> List[Tuple[DataFrame, DataFrame]]:
94
102
if gap > 0 :
95
103
order_cols = self .getSeriesIdCols () + [self .getTimeSeriesCol ()]
96
104
# trim each training dataset by the specified gap
97
- kFolds = [((train_df .withColumn (TMP_GAP_COL ,
98
- sfn .row_number ().over (self ._get_split_win (desc = True )))
105
+ kFolds = [
106
+ (
107
+ (
108
+ train_df .withColumn (
109
+ TMP_GAP_COL ,
110
+ sfn .row_number ().over (self ._get_split_win (desc = True )),
111
+ )
99
112
.where (sfn .col (TMP_GAP_COL ) > gap )
100
113
.drop (TMP_GAP_COL )
101
- .orderBy (* order_cols )),
102
- test_df )
103
- for (train_df , test_df ) in kFolds ]
114
+ .orderBy (* order_cols )
115
+ ),
116
+ test_df ,
117
+ )
118
+ for (train_df , test_df ) in kFolds
119
+ ]
104
120
105
121
# return the k folds (training, test) datasets
106
122
return kFolds
0 commit comments