Skip to content

Commit 1c18b01

Browse files
author
Tristan Nixon
committed
solving some type-check issues
1 parent 88bb03b commit 1c18b01

File tree

2 files changed: 17 additions and 19 deletions.

python/tempo/tsdf.py

Lines changed: 3 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -47,6 +47,7 @@ def __init__(
47 47
if ts_schema:
48 48
self.ts_schema = ts_schema
49 49
else:
50 +
assert ts_col is not None
50 51
self.ts_schema = TSSchema.fromDFSchema(self.df.schema, ts_col, series_ids)
51 52
# validate that this schema works for this DataFrame
52 53
self.ts_schema.validate(df.schema)
@@ -114,7 +115,7 @@ def fromSubsequenceCol(
114 115
df: DataFrame,
115 116
ts_col: str,
116 117
subsequence_col: str,
117 -
series_ids: Collection[str] = None,
118 +
series_ids: Optional[Collection[str]] = None,
118 119
) -> "TSDF":
119 120
# construct a struct with the ts_col and subsequence_col
120 121
struct_col_name = cls.__DEFAULT_TS_IDX_COL
@@ -132,7 +133,7 @@ def fromTimestampString(
132 133
cls,
133 134
df: DataFrame,
134 135
ts_col: str,
135 -
series_ids: Collection[str] = None,
136 +
series_ids: Optional[Collection[str]] = None,
136 137
ts_fmt: str = "YYYY-MM-DDThh:mm:ss[.SSSSSS]",
137 138
) -> "TSDF":
138 139
pass

python/tempo/tsschema.py

Lines changed: 14 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -123,7 +123,7 @@ def orderByExpr(self, reverse: bool = False) -> Union[Column, List[Column]]:
123 123
"""
124 124

125 125
@abstractmethod
126 -
def rangeExpr(self, reverse: bool = False) -> Column:
126 +
def rangeExpr(self, reverse: bool = False) -> Union[Column, List[Column]]:
127 127
"""
128 128
Gets an expression appropriate for performing range operations on the :class:`TSDF` records.
129 129
@@ -176,7 +176,7 @@ def renamed(self, new_name: str) -> "TSIndex":
176 176
self.__name = new_name
177 177
return self
178 178

179 -
def orderByExpr(self, reverse: bool = False) -> Column:
179 +
def orderByExpr(self, reverse: bool = False) -> Union[Column, List[Column]]:
180 180
expr = sfn.col(self.colname)
181 181
return self._reverseOrNot(expr, reverse)
182 182

@@ -211,7 +211,7 @@ def __init__(self, ts_idx: StructField) -> None:
211 211
def unit(self) -> Optional[TimeUnits]:
212 212
return None
213 213

214 -
def rangeExpr(self, reverse: bool = False) -> Column:
214 +
def rangeExpr(self, reverse: bool = False) -> Union[Column, List[Column]]:
215 215
return self.orderByExpr(reverse)
216 216

217 217

@@ -231,7 +231,7 @@ def __init__(self, ts_idx: StructField) -> None:
231 231
def unit(self) -> Optional[TimeUnits]:
232 232
return TimeUnits.SECONDS
233 233

234 -
def rangeExpr(self, reverse: bool = False) -> Column:
234 +
def rangeExpr(self, reverse: bool = False) -> Union[Column, List[Column]]:
235 235
# cast timestamp to double (fractional seconds since epoch)
236 236
expr = sfn.col(self.colname).cast("double")
237 237
return self._reverseOrNot(expr, reverse)
@@ -253,7 +253,7 @@ def __init__(self, ts_idx: StructField) -> None:
253 253
def unit(self) -> Optional[TimeUnits]:
254 254
return TimeUnits.DAYS
255 255

256 -
def rangeExpr(self, reverse: bool = False) -> Column:
256 +
def rangeExpr(self, reverse: bool = False) -> Union[Column, List[Column]]:
257 257
# convert date to number of days since the epoch
258 258
expr = sfn.datediff(sfn.col(self.colname), sfn.lit("1970-01-01").cast("date"))
259 259
return self._reverseOrNot(expr, reverse)
@@ -350,12 +350,12 @@ def ts_component(self, component_index: int) -> str:
350 350
"""
351 351
return self.component(self.ts_components[component_index].colname)
352 352

353 -
def orderByExpr(self, reverse: bool = False) -> Column:
353 +
def orderByExpr(self, reverse: bool = False) -> Union[Column, List[Column]]:
354 354
# build an expression for each TS component, in order
355 355
exprs = [sfn.col(self.component(comp.colname)) for comp in self.ts_components]
356 356
return self._reverseOrNot(exprs, reverse)
357 357

358 -
def rangeExpr(self, reverse: bool = False) -> Column:
358 +
def rangeExpr(self, reverse: bool = False) -> Union[Column, List[Column]]:
359 359
return self.primary_ts_idx.rangeExpr(reverse)
360 360

361 361

@@ -366,7 +366,7 @@ class ParsedTSIndex(CompositeTSIndex, ABC):
366 366
"""
367 367

368 368
def __init__(self, ts_idx: StructField, src_str_col: str, parsed_col: str) -> None:
369 -
super().__init__(ts_idx, primary_ts_col=parsed_col)
369 +
super().__init__(ts_idx, parsed_col)
370 370
src_str_field = self.struct[src_str_col]
371 371
if not isinstance(src_str_field.dataType, StringType):
372 372
raise TypeError(
@@ -390,9 +390,8 @@ def validate(self, df_schema: StructType) -> None:
390 390
composite_idx_type: StructType = cast(
391 391
StructType, df_schema[self.colname].dataType
392 392
)
393 -
assert (
394 -
self.__src_str_col in composite_idx_type
395 -
), f"The src_str_col column {self.src_str_col} does not exist in the composite field {composite_idx_type}"
393 +
assert (self.__src_str_col in composite_idx_type.fieldNames()), \
394 +
f"The src_str_col column {self.src_str_col} does not exist in the composite field {composite_idx_type}"
396 395
# make sure it's StringType
397 396
src_str_field_type = composite_idx_type[self.__src_str_col].dataType
398 397
assert isinstance(
@@ -412,7 +411,7 @@ def __init__(self, ts_idx: StructField, src_str_col: str, parsed_col: str) -> No
412 411
f"ParsedTimestampIndex must be of TimestampType, but given ts_col {self.primary_ts_idx.colname} has type {self.primary_ts_idx.dataType}"
413 412
)
414 413

415 -
def rangeExpr(self, reverse: bool = False) -> Column:
414 +
def rangeExpr(self, reverse: bool = False) -> Union[Column, List[Column]]:
416 415
# cast timestamp to double (fractional seconds since epoch)
417 416
expr = sfn.col(self.primary_ts_col).cast("double")
418 417
return self._reverseOrNot(expr, reverse)
@@ -430,7 +429,7 @@ def __init__(self, ts_idx: StructField, src_str_col: str, parsed_col: str) -> No
430 429
f"ParsedDateIndex must be of DateType, but given ts_col {self.primary_ts_idx.colname} has type {self.primary_ts_idx.dataType}"
431 430
)
432 431

433 -
def rangeExpr(self, reverse: bool = False) -> Column:
432 +
def rangeExpr(self, reverse: bool = False) -> Union[Column, List[Column]]:
434 433
# convert date to number of days since the epoch
435 434
expr = sfn.datediff(
436 435
sfn.col(self.primary_ts_col), sfn.lit("1970-01-01").cast("date")
@@ -522,7 +521,7 @@ class TSSchema(WindowBuilder):
522 521
Schema type for a :class:`TSDF` class.
523 522
"""
524 523

525 -
def __init__(self, ts_idx: TSIndex, series_ids: Collection[str] = None) -> None:
524 +
def __init__(self, ts_idx: TSIndex, series_ids: Optional[Collection[str]]) -> None:
526 525
self.__ts_idx = ts_idx
527 526
if series_ids:
528 527
self.__series_ids = list(series_ids)
@@ -558,9 +557,7 @@ def __str__(self) -> str:
558 557
Series IDs: {self.series_ids}"""
559 558

560 559
@classmethod
561 -
def fromDFSchema(
562 -
cls, df_schema: StructType, ts_col: str, series_ids: Collection[str] = None
563 -
) -> "TSSchema":
560 +
def fromDFSchema(cls, df_schema: StructType, ts_col: str, series_ids: Optional[Collection[str]]) -> "TSSchema":
564 561
# construct a TSIndex for the given ts_col
565 562
ts_idx = SimpleTSIndex.fromTSCol(df_schema[ts_col])
566 563
return cls(ts_idx, series_ids)

0 commit comments

Comments
 (0)