@@ -79,26 +79,22 @@ def interpolate(
79
79
80
80
# identify transitions between segments
81
81
seg_trans_col = "__tmp_seg_transition"
82
- all_win = Window . partitionBy ( "symbol" ). orderBy ( "timestamp" )
82
+ all_win = tsdf . baseWindow ( )
83
83
segments = segments .withColumn (seg_trans_col ,
84
84
sfn .lag (needs_intpl_col , 1 ).over (all_win )
85
85
!= sfn .col (needs_intpl_col ))
86
86
87
87
# assign a group number to each segment
88
88
seg_group_col = "__tmp_seg_group"
89
- all_prev_win = Window .partitionBy ("symbol" )\
90
- .orderBy ("timestamp" )\
91
- .rowsBetween (Window .unboundedPreceding , Window .currentRow )
89
+ all_prev_win = tsdf .allBeforeWindow ()
92
90
segments = segments .withColumn (seg_group_col ,
93
91
sfn .count_if (seg_trans_col ).over (all_prev_win ))
94
92
95
93
# build margins around intepolation segments
96
94
if leading_margin > 0 or lagging_margin > 0 :
97
95
# collect the group number of each segment with a margin
98
96
margin_col = "__tmp_group_with_margin"
99
- margin_win = Window .partitionBy ("symbol" )\
100
- .orderBy ("timestamp" )\
101
- .rowsBetween (- leading_margin , lagging_margin )
97
+ margin_win = tsdf .rowsBetweenWindow (- leading_margin , lagging_margin )
102
98
segments = segments .withColumn (margin_col ,
103
99
sfn .when (~ sfn .col (needs_intpl_col ),
104
100
sfn .collect_set (seg_group_col )
0 commit comments