diff --git a/src/test_benchmark.py b/src/test_benchmark.py index 97821642..cc5e7d07 100644 --- a/src/test_benchmark.py +++ b/src/test_benchmark.py @@ -11,7 +11,7 @@ import pytest from iterpy.iter import Iter from timeseriesflattener.aggregators import Aggregator, MaxAggregator, MeanAggregator -from timeseriesflattener.feature_specs.meta import LookDistance, ValueFrame +from timeseriesflattener.feature_specs.meta import ValueFrame from timeseriesflattener.feature_specs.prediction_times import PredictionTimeFrame from timeseriesflattener.feature_specs.predictor import PredictorSpec from timeseriesflattener.flattener import Flattener @@ -50,7 +50,7 @@ def _generate_benchmark_dataset( n_features: int, n_observations_per_pred_time: int, aggregations: Sequence[Literal["max", "mean"]], - lookbehinds: Sequence[LookDistance | tuple[LookDistance, LookDistance]], + lookbehinds: Sequence[dt.timedelta | tuple[dt.timedelta, dt.timedelta]], ) -> BenchmarkDataset: pred_time_df = PredictionTimeFrame( init_df=pl.LazyFrame( diff --git a/src/timeseriesflattener/_intermediary_frames.py b/src/timeseriesflattener/_intermediary_frames.py index 319f25d8..81d1aeb4 100644 --- a/src/timeseriesflattener/_intermediary_frames.py +++ b/src/timeseriesflattener/_intermediary_frames.py @@ -6,17 +6,11 @@ import polars as pl from ._frame_validator import _validate_col_name_columns_exist -from .feature_specs.default_column_names import ( - default_prediction_time_uuid_col_name, - default_timestamp_col_name, -) from .frame_utilities.anyframe_to_lazyframe import _anyframe_to_lazyframe if TYPE_CHECKING: from collections.abc import Sequence - from .feature_specs.meta import ValueType - if TYPE_CHECKING: import datetime as dt @@ -27,8 +21,8 @@ class TimeMaskedFrame: init_df: pl.LazyFrame value_col_names: Sequence[str] - timestamp_col_name: str = default_timestamp_col_name - prediction_time_uuid_col_name: str = default_prediction_time_uuid_col_name + timestamp_col_name: str = "timestamp" + 
prediction_time_uuid_col_name: str = "prediction_time_uuid" validate_cols_exist: bool = True def __post_init__(self): @@ -47,12 +41,12 @@ def collect(self) -> pl.DataFrame: class AggregatedValueFrame: df: pl.LazyFrame value_col_name: str - prediction_time_uuid_col_name: str = default_prediction_time_uuid_col_name + prediction_time_uuid_col_name: str = "prediction_time_uuid" def __post_init__(self): _validate_col_name_columns_exist(obj=self) - def fill_nulls(self, fallback: ValueType) -> AggregatedValueFrame: + def fill_nulls(self, fallback: int | float | str | None) -> AggregatedValueFrame: filled = self.df.with_columns( pl.col(self.value_col_name) .fill_null(fallback) @@ -76,7 +70,7 @@ class TimeDeltaFrame: df: pl.LazyFrame value_col_names: Sequence[str] value_timestamp_col_name: str - prediction_time_uuid_col_name: str = default_prediction_time_uuid_col_name + prediction_time_uuid_col_name: str = "prediction_time_uuid" timedelta_col_name: str = "time_from_prediction_to_value" def __post_init__(self): diff --git a/src/timeseriesflattener/feature_specs/default_column_names.py b/src/timeseriesflattener/feature_specs/default_column_names.py deleted file mode 100644 index ae387204..00000000 --- a/src/timeseriesflattener/feature_specs/default_column_names.py +++ /dev/null @@ -1,6 +0,0 @@ -from __future__ import annotations - -default_entity_id_col_name = "entity_id" -default_prediction_time_uuid_col_name = "prediction_time_uuid" -default_pred_time_col_name = "pred_timestamp" -default_timestamp_col_name = "timestamp" diff --git a/src/timeseriesflattener/feature_specs/meta.py b/src/timeseriesflattener/feature_specs/meta.py index 5eb32012..964e9ccb 100644 --- a/src/timeseriesflattener/feature_specs/meta.py +++ b/src/timeseriesflattener/feature_specs/meta.py @@ -1,31 +1,15 @@ from __future__ import annotations import datetime as dt -from collections.abc import Sequence from dataclasses import InitVar, dataclass -from typing import TYPE_CHECKING, Literal, Union +from typing 
import Literal import pandas as pd import polars as pl -from timeseriesflattener.feature_specs.default_column_names import default_entity_id_col_name - from .._frame_validator import _validate_col_name_columns_exist from ..frame_utilities.anyframe_to_lazyframe import _anyframe_to_lazyframe -if TYPE_CHECKING: - from typing_extensions import TypeAlias - - -ValueType = Union[int, float, str, None] -InitDF_T = Union[pl.LazyFrame, pl.DataFrame, pd.DataFrame] - - -LookDistance = dt.timedelta - - -LookDistances: TypeAlias = Sequence[Union[LookDistance, tuple[LookDistance, LookDistance]]] - @dataclass class ValueFrame: @@ -37,12 +21,14 @@ class ValueFrame: Additional columns containing the values of the time series. The name of the columns will be used for feature naming. """ - init_df: InitVar[InitDF_T] - entity_id_col_name: str = default_entity_id_col_name + init_df: InitVar[pl.LazyFrame | pl.DataFrame | pd.DataFrame] + entity_id_col_name: str = "entity_id" value_timestamp_col_name: str = "timestamp" coerce_to_lazy: InitVar[bool] = True - def __post_init__(self, init_df: InitDF_T, coerce_to_lazy: bool): + def __post_init__( + self, init_df: pl.LazyFrame | pl.DataFrame | pd.DataFrame, coerce_to_lazy: bool + ): if coerce_to_lazy: self.df = _anyframe_to_lazyframe(init_df) else: @@ -63,8 +49,8 @@ def collect(self) -> pl.DataFrame: @dataclass(frozen=True) class LookPeriod: - first: LookDistance - last: LookDistance + first: dt.timedelta + last: dt.timedelta def __post_init__(self): if self.first >= self.last: @@ -74,11 +60,11 @@ def __post_init__(self): def _lookdistance_to_normalised_lookperiod( - lookdistance: LookDistance | tuple[LookDistance, LookDistance], + lookdistance: dt.timedelta | tuple[dt.timedelta, dt.timedelta], direction: Literal["ahead", "behind"], ) -> LookPeriod: is_ahead = direction == "ahead" - if isinstance(lookdistance, LookDistance): + if isinstance(lookdistance, dt.timedelta): return LookPeriod( first=dt.timedelta(days=0) if is_ahead else 
-lookdistance, last=lookdistance if is_ahead else dt.timedelta(0), diff --git a/src/timeseriesflattener/feature_specs/outcome.py b/src/timeseriesflattener/feature_specs/outcome.py index 77f67d26..37806aab 100644 --- a/src/timeseriesflattener/feature_specs/outcome.py +++ b/src/timeseriesflattener/feature_specs/outcome.py @@ -1,12 +1,13 @@ from __future__ import annotations +import datetime as dt from dataclasses import InitVar, dataclass from typing import TYPE_CHECKING import polars as pl from .._frame_validator import _validate_col_name_columns_exist -from .meta import LookDistances, ValueFrame, ValueType, _lookdistance_to_normalised_lookperiod +from .meta import ValueFrame, _lookdistance_to_normalised_lookperiod if TYPE_CHECKING: from collections.abc import Sequence @@ -20,12 +21,14 @@ class OutcomeSpec: """Specification for an outcome. If your outcome is binary/boolean, you can use BooleanOutcomeSpec instead.""" value_frame: ValueFrame - lookahead_distances: InitVar[LookDistances] + lookahead_distances: InitVar[Sequence[dt.timedelta | tuple[dt.timedelta, dt.timedelta]]] aggregators: Sequence[Aggregator] - fallback: ValueType + fallback: int | float | str | None column_prefix: str = "outc" - def __post_init__(self, lookahead_distances: LookDistances): + def __post_init__( + self, lookahead_distances: Sequence[dt.timedelta | tuple[dt.timedelta, dt.timedelta]] + ): self.normalised_lookperiod = [ _lookdistance_to_normalised_lookperiod(lookdistance=lookdistance, direction="ahead") for lookdistance in lookahead_distances @@ -47,7 +50,7 @@ class BooleanOutcomeSpec: """ init_frame: InitVar[TimestampValueFrame] - lookahead_distances: LookDistances + lookahead_distances: Sequence[dt.timedelta | tuple[dt.timedelta, dt.timedelta]] aggregators: Sequence[Aggregator] output_name: str column_prefix: str = "outc" diff --git a/src/timeseriesflattener/feature_specs/prediction_times.py b/src/timeseriesflattener/feature_specs/prediction_times.py index 6605d411..77b2c6fa 100644 --- 
a/src/timeseriesflattener/feature_specs/prediction_times.py +++ b/src/timeseriesflattener/feature_specs/prediction_times.py @@ -3,21 +3,15 @@ from dataclasses import InitVar, dataclass from typing import TYPE_CHECKING +import pandas as pd import polars as pl from .._frame_validator import _validate_col_name_columns_exist from ..frame_utilities.anyframe_to_lazyframe import _anyframe_to_lazyframe -from .default_column_names import ( - default_entity_id_col_name, - default_pred_time_col_name, - default_prediction_time_uuid_col_name, -) if TYPE_CHECKING: from collections.abc import Sequence - from .meta import InitDF_T - @dataclass class PredictionTimeFrame: @@ -28,13 +22,15 @@ class PredictionTimeFrame: timestamp_col_name: The name of the column containing the timestamps for when to make a prediction. """ - init_df: InitVar[InitDF_T] - entity_id_col_name: str = default_entity_id_col_name - timestamp_col_name: str = default_pred_time_col_name - prediction_time_uuid_col_name: str = default_prediction_time_uuid_col_name + init_df: InitVar[pl.LazyFrame | pl.DataFrame | pd.DataFrame] + entity_id_col_name: str = "entity_id" + timestamp_col_name: str = "pred_timestamp" + prediction_time_uuid_col_name: str = "prediction_time_uuid" coerce_to_lazy: InitVar[bool] = True - def __post_init__(self, init_df: InitDF_T, coerce_to_lazy: bool): + def __post_init__( + self, init_df: pl.LazyFrame | pl.DataFrame | pd.DataFrame, coerce_to_lazy: bool + ): if coerce_to_lazy: self.df = _anyframe_to_lazyframe(init_df) else: diff --git a/src/timeseriesflattener/feature_specs/predictor.py b/src/timeseriesflattener/feature_specs/predictor.py index d6191e46..c5cef219 100644 --- a/src/timeseriesflattener/feature_specs/predictor.py +++ b/src/timeseriesflattener/feature_specs/predictor.py @@ -1,10 +1,11 @@ from __future__ import annotations +import datetime as dt from dataclasses import InitVar, dataclass from typing import TYPE_CHECKING from .._frame_validator import _validate_col_name_columns_exist 
-from .meta import LookDistances, ValueFrame, ValueType, _lookdistance_to_normalised_lookperiod +from .meta import ValueFrame, _lookdistance_to_normalised_lookperiod if TYPE_CHECKING: from collections.abc import Sequence @@ -25,12 +26,14 @@ class PredictorSpec: """ value_frame: ValueFrame - lookbehind_distances: InitVar[LookDistances] + lookbehind_distances: InitVar[Sequence[dt.timedelta | tuple[dt.timedelta, dt.timedelta]]] aggregators: Sequence[Aggregator] - fallback: ValueType + fallback: int | float | str | None column_prefix: str = "pred" - def __post_init__(self, lookbehind_distances: LookDistances): + def __post_init__( + self, lookbehind_distances: Sequence[dt.timedelta | tuple[dt.timedelta, dt.timedelta]] + ): self.normalised_lookperiod = [ _lookdistance_to_normalised_lookperiod(lookdistance=lookdistance, direction="behind") for lookdistance in lookbehind_distances diff --git a/src/timeseriesflattener/feature_specs/static.py b/src/timeseriesflattener/feature_specs/static.py index 9dd6fa0d..4b6bc725 100644 --- a/src/timeseriesflattener/feature_specs/static.py +++ b/src/timeseriesflattener/feature_specs/static.py @@ -1,25 +1,21 @@ from __future__ import annotations from dataclasses import InitVar, dataclass -from typing import TYPE_CHECKING +import pandas as pd import polars as pl from .._frame_validator import _validate_col_name_columns_exist from ..frame_utilities.anyframe_to_lazyframe import _anyframe_to_lazyframe -from .default_column_names import default_entity_id_col_name - -if TYPE_CHECKING: - from .meta import InitDF_T, ValueType @dataclass class StaticFrame: - init_df: InitVar[InitDF_T] + init_df: InitVar[pl.LazyFrame | pl.DataFrame | pd.DataFrame] - entity_id_col_name: str = default_entity_id_col_name + entity_id_col_name: str = "entity_id" - def __post_init__(self, init_df: InitDF_T): + def __post_init__(self, init_df: pl.LazyFrame | pl.DataFrame | pd.DataFrame): self.df = _anyframe_to_lazyframe(init_df) _validate_col_name_columns_exist(obj=self) 
self.value_col_names = [col for col in self.df.columns if col != self.entity_id_col_name] @@ -41,4 +37,4 @@ class StaticSpec: value_frame: StaticFrame column_prefix: str - fallback: ValueType + fallback: int | float | str | None diff --git a/src/timeseriesflattener/feature_specs/timedelta.py b/src/timeseriesflattener/feature_specs/timedelta.py index 1b6c54d5..ddf99aed 100644 --- a/src/timeseriesflattener/feature_specs/timedelta.py +++ b/src/timeseriesflattener/feature_specs/timedelta.py @@ -4,7 +4,7 @@ from typing import TYPE_CHECKING, Literal from .._frame_validator import _validate_col_name_columns_exist -from .meta import ValueFrame, ValueType +from .meta import ValueFrame if TYPE_CHECKING: import polars as pl @@ -15,7 +15,7 @@ @dataclass class TimeDeltaSpec: init_frame: TimestampValueFrame - fallback: ValueType + fallback: int | float | str | None output_name: str column_prefix: str = "pred" time_format: Literal["seconds", "minutes", "hours", "days", "years"] = "days" diff --git a/src/timeseriesflattener/feature_specs/timestamp_frame.py b/src/timeseriesflattener/feature_specs/timestamp_frame.py index 6d956f9f..376ecba1 100644 --- a/src/timeseriesflattener/feature_specs/timestamp_frame.py +++ b/src/timeseriesflattener/feature_specs/timestamp_frame.py @@ -1,16 +1,12 @@ from __future__ import annotations from dataclasses import InitVar, dataclass -from typing import TYPE_CHECKING +import pandas as pd import polars as pl from .._frame_validator import _validate_col_name_columns_exist from ..frame_utilities.anyframe_to_lazyframe import _anyframe_to_lazyframe -from .default_column_names import default_entity_id_col_name - -if TYPE_CHECKING: - from .meta import InitDF_T @dataclass @@ -22,11 +18,11 @@ class TimestampValueFrame: value_timestamp_col_name: The name of the column containing the timestamps. Must be a string, and the column's values must be datetimes. 
""" - init_df: InitVar[InitDF_T] - entity_id_col_name: str = default_entity_id_col_name + init_df: InitVar[pl.LazyFrame | pl.DataFrame | pd.DataFrame] + entity_id_col_name: str = "entity_id" value_timestamp_col_name: str = "timestamp" - def __post_init__(self, init_df: InitDF_T): + def __post_init__(self, init_df: pl.LazyFrame | pl.DataFrame | pd.DataFrame): self.df = _anyframe_to_lazyframe(init_df) _validate_col_name_columns_exist(obj=self) diff --git a/src/timeseriesflattener/frame_utilities/anyframe_to_lazyframe.py b/src/timeseriesflattener/frame_utilities/anyframe_to_lazyframe.py index 3d7f9b61..99fa5cc5 100644 --- a/src/timeseriesflattener/frame_utilities/anyframe_to_lazyframe.py +++ b/src/timeseriesflattener/frame_utilities/anyframe_to_lazyframe.py @@ -1,15 +1,11 @@ from __future__ import annotations -from typing import TYPE_CHECKING import pandas as pd import polars as pl -if TYPE_CHECKING: - from ..feature_specs.meta import InitDF_T - -def _anyframe_to_lazyframe(init_df: InitDF_T) -> pl.LazyFrame: +def _anyframe_to_lazyframe(init_df: pl.LazyFrame | pl.DataFrame | pd.DataFrame) -> pl.LazyFrame: if isinstance(init_df, pl.LazyFrame): return init_df if isinstance(init_df, pl.DataFrame): @@ -19,5 +15,5 @@ def _anyframe_to_lazyframe(init_df: InitDF_T) -> pl.LazyFrame: raise ValueError(f"Unsupported type: {type(init_df)}.") -def _anyframe_to_eagerframe(init_df: InitDF_T) -> pl.DataFrame: +def _anyframe_to_eagerframe(init_df: pl.LazyFrame | pl.DataFrame | pd.DataFrame) -> pl.DataFrame: return _anyframe_to_lazyframe(init_df).collect() diff --git a/src/timeseriesflattener/process_spec.py b/src/timeseriesflattener/process_spec.py index 3a230f8c..2f07f0b1 100644 --- a/src/timeseriesflattener/process_spec.py +++ b/src/timeseriesflattener/process_spec.py @@ -1,6 +1,6 @@ from __future__ import annotations -from typing import TYPE_CHECKING, Union +from typing import TYPE_CHECKING import datetime as dt from .feature_specs.static import StaticSpec diff --git 
a/src/timeseriesflattener/spec_processors/temporal.py b/src/timeseriesflattener/spec_processors/temporal.py index 679758a3..aeeb8709 100644 --- a/src/timeseriesflattener/spec_processors/temporal.py +++ b/src/timeseriesflattener/spec_processors/temporal.py @@ -8,17 +8,17 @@ from iterpy.iter import Iter from .._intermediary_frames import ProcessedFrame, TimeDeltaFrame, TimeMaskedFrame +from ..feature_specs.meta import ValueFrame from ..feature_specs.outcome import BooleanOutcomeSpec, OutcomeSpec +from ..feature_specs.prediction_times import PredictionTimeFrame from ..feature_specs.predictor import PredictorSpec from ..frame_utilities._horisontally_concat import horizontally_concatenate_dfs -from ..feature_specs.prediction_times import PredictionTimeFrame -from ..feature_specs.meta import ValueFrame, InitDF_T if TYPE_CHECKING: from collections.abc import Sequence from ..aggregators import Aggregator - from ..feature_specs.meta import LookPeriod, ValueType + from ..feature_specs.meta import LookPeriod def _get_timedelta_frame( @@ -105,7 +105,9 @@ def _mask_outside_lookperiod( def _aggregate_masked_frame( - masked_frame: TimeMaskedFrame, aggregators: Sequence[Aggregator], fallback: ValueType + masked_frame: TimeMaskedFrame, + aggregators: Sequence[Aggregator], + fallback: int | float | str | None, ) -> pl.LazyFrame: aggregator_expressions = [ aggregator(value_col_name) diff --git a/src/timeseriesflattener/v1/feature_specs/group_specs.py b/src/timeseriesflattener/v1/feature_specs/group_specs.py index fcb1e2ca..e9ec383c 100644 --- a/src/timeseriesflattener/v1/feature_specs/group_specs.py +++ b/src/timeseriesflattener/v1/feature_specs/group_specs.py @@ -1,3 +1,4 @@ +from __future__ import annotations import itertools from dataclasses import dataclass from typing import Dict, List, Sequence, Tuple, Union @@ -18,10 +19,10 @@ class NamedDataframe: class V1PGSProtocol(Protocol): - lookbehind_days: Sequence[Union[float, Tuple[float, float]]] + lookbehind_days: Sequence[float | 
Tuple[float, float]] named_dataframes: Sequence[NamedDataframe] aggregation_fns: Sequence[AggregationFunType] - fallback: Sequence[Union[int, float, str]] + fallback: Sequence[int | float | str] prefix: str = "pred" @@ -109,8 +110,8 @@ def create_combinations(self) -> List[OutcomeSpec]: def create_feature_combinations_from_dict( - dictionary: Dict[str, Union[str, list]], -) -> List[Dict[str, Union[str, float, int]]]: + dictionary: Dict[str, str | list], +) -> List[Dict[str, str | float | int]]: """Create feature combinations from a dictionary of feature specifications. Only unpacks the top level of lists. Args: diff --git a/src/timeseriesflattener/v1/feature_specs/single_specs.py b/src/timeseriesflattener/v1/feature_specs/single_specs.py index 378b944d..bb1a8b2e 100644 --- a/src/timeseriesflattener/v1/feature_specs/single_specs.py +++ b/src/timeseriesflattener/v1/feature_specs/single_specs.py @@ -1,3 +1,4 @@ +from __future__ import annotations from dataclasses import dataclass from typing import Tuple, Union @@ -21,7 +22,7 @@ def __post_init__(self): @dataclass(frozen=True) class CoercedFloats: lookperiod: LookPeriod - fallback: Union[float, int] + fallback: float | int def can_be_coerced_losslessly_to_int(value: float) -> bool: @@ -79,7 +80,7 @@ def get_temporal_col_name( feature_base_name: str, lookperiod: LookPeriod, aggregation_fn: AggregationFunType, - fallback: Union[float, int], + fallback: float | int, ) -> str: """Get the column name for the temporal feature.""" coerced = coerce_floats(lookperiod=lookperiod, fallback=fallback) diff --git a/src/timeseriesflattener/v1/flattened_dataset.py b/src/timeseriesflattener/v1/flattened_dataset.py index 6ce455db..05ebbbfa 100644 --- a/src/timeseriesflattener/v1/flattened_dataset.py +++ b/src/timeseriesflattener/v1/flattened_dataset.py @@ -724,7 +724,7 @@ def _check_that_spec_df_timestamp_col_is_correctly_formatted(self, spec: Tempora f"{spec.feature_base_name}: Minimum timestamp is {min_timestamp} - perhaps ints were 
coerced to timestamps?" ) - def add_spec(self, spec: Union[Sequence[AnySpec], AnySpec]): + def add_spec(self, spec: Union[AnySpec, Sequence[AnySpec]]): """Add a specification to the flattened dataset. This adds it to a queue of unprocessed specs, which are not processed diff --git a/src/timeseriesflattener/v1/logger.py b/src/timeseriesflattener/v1/logger.py index d93a46ec..7e3088a5 100644 --- a/src/timeseriesflattener/v1/logger.py +++ b/src/timeseriesflattener/v1/logger.py @@ -2,7 +2,8 @@ +from __future__ import annotations import logging from pathlib import Path -from typing import Optional, Union +from typing import Optional import coloredlogs @@ -12,7 +12,7 @@ def setup_logger( name: str, level: int = logging.DEBUG, - log_file_path: Optional[Union[str, Path]] = None, + log_file_path: Optional[str | Path] = None, fmt: str = "%(asctime)s [%(levelname)s] %(message)s", ) -> logging.Logger: """ diff --git a/src/timeseriesflattener/v1/misc_utils.py b/src/timeseriesflattener/v1/misc_utils.py index 49c793d2..6d086d57 100644 --- a/src/timeseriesflattener/v1/misc_utils.py +++ b/src/timeseriesflattener/v1/misc_utils.py @@ -2,12 +2,13 @@ utilities. If this file grows, consider splitting it up. """ +from __future__ import annotations import functools import logging import os from pathlib import Path -from typing import Any, Callable, Dict, Hashable, List, Union +from typing import Any, Callable, Dict, Hashable, List import catalogue import pandas as pd @@ -84,7 +85,7 @@ def format_dict_for_printing(d: dict) -> str: ) -def load_dataset_from_file(file_path: Path, nrows: Union[int, None] = None) -> pd.DataFrame: +def load_dataset_from_file(file_path: Path, nrows: int | None = None) -> pd.DataFrame: """Load dataset from file. Handles csv and parquet files based on suffix. 
Args: diff --git a/src/timeseriesflattener/v1/testing/utils_for_testing.py b/src/timeseriesflattener/v1/testing/utils_for_testing.py index 9392e636..95d9a849 100644 --- a/src/timeseriesflattener/v1/testing/utils_for_testing.py +++ b/src/timeseriesflattener/v1/testing/utils_for_testing.py @@ -1,6 +1,9 @@ """Utilities for testing.""" +from __future__ import annotations + from io import StringIO -from typing import Any, List, Optional, Sequence, Union +from typing import Any, List, Optional, Sequence + import numpy as np import pandas as pd @@ -82,7 +85,7 @@ def str_to_df( return df.loc[:, ~df.columns.str.contains("^Unnamed")] -def _get_value_cols_based_on_spec(df: pd.DataFrame, spec: AnySpec) -> Union[str, List[str]]: +def _get_value_cols_based_on_spec(df: pd.DataFrame, spec: AnySpec) -> str | List[str]: """Get value columns based on spec. Checks if multiple value columns are present.""" feature_name = spec.feature_base_name value_cols = df.columns[df.columns.str.contains(feature_name)].tolist() @@ -94,7 +97,7 @@ def _get_value_cols_based_on_spec(df: pd.DataFrame, spec: AnySpec) -> Union[str, def assert_flattened_data_as_expected( - prediction_times_df: Union[pd.DataFrame, str], + prediction_times_df: pd.DataFrame | str, output_spec: AnySpec, expected_df: Optional[pd.DataFrame] = None, expected_values: Optional[Sequence[Any]] = None,