Commit a5cb399

WIP checkpoint commit - refactoring AsOfJoin logic into external factory classes

tnixon committed Sep 19, 2023
1 parent 1c84989 commit a5cb399

Showing 3 changed files with 74 additions and 44 deletions.
29 changes: 24 additions & 5 deletions python/tempo/as_of_join.py
@@ -9,6 +9,10 @@

# As-of join types


# Example usage (MyAsOfJoinerImpl stands in for a concrete subclass):
#   joiner = MyAsOfJoinerImpl()
#   joiner(left_tsdf, right_tsdf)

class AsOfJoiner(ABC):
"""
Abstract class for as-of join strategies
@@ -114,14 +118,16 @@ class BroadcastAsOfJoiner(AsOfJoiner):
def __init__(self,
spark: SparkSession,
left_prefix: str = "left",
right_prefix: str = "right"):
right_prefix: str = "right",
range_join_bin_size: int = 60):
super().__init__(left_prefix, right_prefix)
self.spark = spark
self.range_join_bin_size = range_join_bin_size

def _join(self, left: t_tsdf.TSDF, right: t_tsdf.TSDF) -> t_tsdf.TSDF:
# set the range join bin size (configurable via the constructor; defaults to 60 seconds)
self.spark.conf.set("spark.databricks.optimizer.rangeJoin.binSize",
"60")
str(self.range_join_bin_size))
# add a column containing the next (lead) timestamp in the right TSDF
w = right.baseWindow()
lead_colname = "lead_" + right.ts_col
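
A minimal usage sketch of the new range_join_bin_size knob (left_tsdf and right_tsdf are placeholder TSDFs, and the call form assumes joiner instances are invoked as callables, as in the module-level example above):

    joiner = BroadcastAsOfJoiner(spark, range_join_bin_size=300)
    joined = joiner(left_tsdf, right_tsdf)  # range join binned at 300 seconds
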
@@ -199,8 +205,17 @@ def _combine(self, left: t_tsdf.TSDF, right: t_tsdf.TSDF) -> t_tsdf.TSDF:
combined_ts = unioned.withColumn(_DEFAULT_COMBINED_TS_COLNAME,
sfn.coalesce(left.ts_col,
right.ts_col))
# add a column to indicate if the row is from the right-hand side
return self._addRightRowIndicator(combined_ts)
# add an indicator column: 1 if the row comes from the left-hand side,
# -1 if it comes from the right-hand side
r_row_ind = combined_ts.withColumn(_DEFAULT_RIGHT_ROW_COLNAME,
sfn.when(sfn.col(left.ts_col).isNotNull(), 1)
.otherwise(-1))
# combine all the ts columns into a single struct column
return t_tsdf.TSDF.makeCompositeIndexTSDF(r_row_ind.df,
[_DEFAULT_COMBINED_TS_COLNAME,
_DEFAULT_RIGHT_ROW_COLNAME],
combined_ts.series_ids,
[left.ts_col, right.ts_col])
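# Schematically (assuming the default composite index name), the returned
# TSDF's ts_idx struct has _DEFAULT_COMBINED_TS_COLNAME and
# _DEFAULT_RIGHT_ROW_COLNAME as its ordered components, carries the original
# left/right ts columns as extra index fields, and preserves the series ids.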

def _filterLastRightRow(self,
combined: t_tsdf.TSDF,
@@ -210,14 +225,18 @@ def _filterLastRightRow
"""
# find the last value for each column in the right-hand side
w = combined.allBeforeWindow()
filtered = reduce(
last_right_vals = reduce(
lambda cur_tsdf, col: cur_tsdf.withColumn(
col,
sfn.last(col, self.skipNulls).over(w),
),
right_cols,
combined,
)
# keep only the left-hand rows (indicator == 1), which now carry the most
# recent right-hand values, and drop the indicator column
r_row_ind: str = combined.ts_index.component(_DEFAULT_RIGHT_ROW_COLNAME)
filtered = last_right_vals.where(sfn.col(r_row_ind) == 1).drop(r_row_ind)
return filtered
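# Illustration on toy data (single series; ind is the right-row indicator
# component, price a right-hand column; skipNulls assumed True):
#
#   combined_ts  ind  price  last(price) over w  kept?
#   09:59        -1   10.5   10.5                no
#   10:00         1   null   10.5                yes (left row, as-of 09:59)
#   10:01        -1   10.7   10.7                no
#   10:02         1   null   10.7                yes (left row, as-of 10:01)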

def _toleranceFilter(self, as_of: t_tsdf.TSDF) -> t_tsdf.TSDF:
"""
63 changes: 37 additions & 26 deletions python/tempo/tsdf.py
@@ -16,6 +16,7 @@
from pyspark.sql.column import Column
from pyspark.sql.dataframe import DataFrame
from pyspark.sql.types import DataType, StructType
# NOTE: pyspark.sql._typing ships only as a type stub and is not importable
# at runtime, so guard the import for type checkers only
from typing import TYPE_CHECKING
if TYPE_CHECKING:
    from pyspark.sql._typing import ColumnOrName
from pyspark.sql.window import Window, WindowSpec

import tempo.interpol as t_interpolation
@@ -27,6 +28,24 @@

logger = logging.getLogger(__name__)

# default column name for constructed timeseries index struct columns
DEFAULT_TS_IDX_COL = "ts_idx"

def makeStructFromCols(df: DataFrame,
struct_col_name: str,
cols_to_move: Collection[str]) -> DataFrame:
"""
Transform a :class:`DataFrame` by moving certain columns into a struct
:param df: the :class:`DataFrame` to transform
:param struct_col_name: name of the struct column to create
:param cols_to_move: name of the columns to move into the struct
:return: the transformed :class:`DataFrame`
"""
return (df.withColumn(struct_col_name,
sfn.struct(*cols_to_move))
.drop(*cols_to_move))
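
# For instance, a quick sketch (illustrative column names):
#   df2 = makeStructFromCols(df, "ts_idx", ["event_ts", "seq_nr"])
# df2 then has a single struct column "ts_idx" with fields event_ts and
# seq_nr; the two original top-level columns are dropped.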

class TSDF(WindowBuilder):
"""
@@ -83,23 +102,24 @@ def __withStandardizedColOrder(self) -> TSDF:
return self.__withTransformedDF(self.df.select(std_ordered_cols))

@classmethod
def __makeStructFromCols(
cls, df: DataFrame, struct_col_name: str, cols_to_move: List[str]
) -> DataFrame:
def makeCompositeIndexTSDF(
cls,
df: DataFrame,
ts_index_cols: Collection[str],
series_ids: Collection[str] = None,
other_index_cols: Collection[str] = None,
composite_index_name: str = DEFAULT_TS_IDX_COL) -> "TSDF":
"""
Transform a :class:`DataFrame` by moving certain columns into a struct
:param df: the :class:`DataFrame` to transform
:param struct_col_name: name of the struct column to create
:param cols_to_move: name of the columns to move into the struct
:return: the transformed :class:`DataFrame`
Construct a TSDF with a composite index from the specified columns
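:param df: the :class:`DataFrame` to build the TSDF from
:param ts_index_cols: columns to use as the ordered components of the composite index
:param series_ids: columns that identify distinct series
:param other_index_cols: additional columns to fold into the index struct
:param composite_index_name: name of the struct column to create
:return: a :class:`TSDF` indexed by the composite index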
"""
return (df.withColumn(struct_col_name, sfn.struct(*cols_to_move))
.drop(*cols_to_move))

# default column name for constructed timeseries index struct columns
__DEFAULT_TS_IDX_COL = "ts_idx"
# move all the index columns into a struct
all_composite_cols = list(ts_index_cols) + [c for c in (other_index_cols or []) if c not in ts_index_cols]  # preserve order: a set would scramble the struct's field order
with_struct_df = makeStructFromCols(df, composite_index_name, all_composite_cols)
# construct an appropriate TSIndex
ts_idx_struct = with_struct_df.schema[composite_index_name]
comp_idx = CompositeTSIndex(ts_idx_struct, *ts_index_cols)
# construct & return the TSDF with appropriate schema
return TSDF(with_struct_df, ts_schema=TSSchema(comp_idx, series_ids))

@classmethod
def fromSubsequenceCol(
@@ -109,16 +129,7 @@ def fromSubsequenceCol(
subsequence_col: str,
series_ids: Collection[str] = None,
) -> "TSDF":
# construct a struct with the ts_col and subsequence_col
struct_col_name = cls.__DEFAULT_TS_IDX_COL
with_subseq_struct_df = cls.__makeStructFromCols(
df, struct_col_name, [ts_col, subsequence_col]
)
# construct an appropriate TSIndex
subseq_struct = with_subseq_struct_df.schema[struct_col_name]
subseq_idx = CompositeTSIndex(subseq_struct, ts_col, subsequence_col)
# construct & return the TSDF with appropriate schema
return TSDF(with_subseq_struct_df, ts_schema=TSSchema(subseq_idx, series_ids))
return cls.makeCompositeIndexTSDF(df, [ts_col, subsequence_col], series_ids)
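# A usage sketch (illustrative column names):
#   tsdf = TSDF.fromSubsequenceCol(df, "event_ts", "seq_nr", series_ids=["symbol"])
# which is equivalent to:
#   TSDF.makeCompositeIndexTSDF(df, ["event_ts", "seq_nr"], ["symbol"])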

@classmethod
def fromTimestampString(
@@ -396,7 +407,7 @@ def select(self, *cols: Union[str, Column]) -> TSDF:
selected_df = self.df.select(*cols)
return self.__withTransformedDF(selected_df)

def where(self, condition: Union[Column, str]) -> "TSDF":
def where(self, condition: "ColumnOrName") -> "TSDF":  # quoted: ColumnOrName only exists under TYPE_CHECKING
"""
Selects rows using the given condition.
Expand Down
26 changes: 13 additions & 13 deletions python/tempo/tsschema.py
@@ -154,10 +154,6 @@ def _indexAttributes(self) -> dict[str, Any]:
def colname(self):
return self.__name

@property
def ts_col(self) -> str:
return self.colname

def validate(self, df_schema: StructType) -> None:
# the ts column must exist
assert self.colname in df_schema.fieldNames(), (
@@ -189,7 +185,8 @@ def fromTSCol(cls, ts_col: StructField) -> "SimpleTSIndex":
return SimpleDateIndex(ts_col)
else:
raise TypeError(
f"A SimpleTSIndex must be a Numeric, Timestamp or Date type, but column {ts_col.name} is of type {ts_col.dataType}"
f"A SimpleTSIndex must be a Numeric, Timestamp or Date type, "
f"but column {ts_col.name} is of type {ts_col.dataType}"
)


@@ -299,10 +296,6 @@ def _indexAttributes(self) -> dict[str, Any]:
def colname(self) -> str:
return self.__name

@property
def ts_col(self) -> str:
return self.primary_ts_col

@property
def primary_ts_col(self) -> str:
return self.ts_component(0)
@@ -372,7 +365,9 @@ def __init__(
src_str_field = self.struct[src_str_col]
if not isinstance(src_str_field.dataType, StringType):
raise TypeError(
f"Source string column must be of StringType, but given column {src_str_field.name} is of type {src_str_field.dataType}"
f"Source string column must be of StringType, "
f"but given column {src_str_field.name} "
f"is of type {src_str_field.dataType}"
)
self.__src_str_col = src_str_col

@@ -389,7 +384,8 @@ def src_str_col(self):
def validate(self, df_schema: StructType) -> None:
super().validate(df_schema)
# make sure the parsed field exists
composite_idx_type: StructType = cast(StructType, df_schema[self.colname].dataType)
composite_idx_type: StructType = cast(StructType,
df_schema[self.colname].dataType)
assert self.__src_str_col in composite_idx_type.fieldNames(), (
f"The src_str_col column {self.src_str_col} "
f"does not exist in the composite field {composite_idx_type}")
@@ -411,7 +407,9 @@ def __init__(
super().__init__(ts_idx, src_str_col, parsed_col)
if not isinstance(self.primary_ts_idx.dataType, TimestampType):
raise TypeError(
f"ParsedTimestampIndex must be of TimestampType, but given ts_col {self.primary_ts_idx.colname} has type {self.primary_ts_idx.dataType}"
f"ParsedTimestampIndex must be of TimestampType, "
f"but given ts_col {self.primary_ts_idx.colname} "
f"has type {self.primary_ts_idx.dataType}"
)

def rangeExpr(self, reverse: bool = False) -> Column:
@@ -431,7 +429,9 @@ def __init__(
super().__init__(ts_idx, src_str_col, parsed_col)
if not isinstance(self.primary_ts_idx.dataType, DateType):
raise TypeError(
f"ParsedDateIndex must be of DateType, but given ts_col {self.primary_ts_idx.colname} has type {self.primary_ts_idx.dataType}"
f"ParsedDateIndex must be of DateType, "
f"but given ts_col {self.primary_ts_idx.colname} "
f"has type {self.primary_ts_idx.dataType}"
)

def rangeExpr(self, reverse: bool = False) -> Column:
