diff --git a/cyclops/data/aggregate.py b/cyclops/data/aggregate.py
index 3dd4110b4..e0b4bade3 100644
--- a/cyclops/data/aggregate.py
+++ b/cyclops/data/aggregate.py
@@ -6,6 +6,7 @@
 
 import numpy as np
 import pandas as pd
+from sklearn.base import TransformerMixin
 
 from cyclops.data.clean import dropna_rows
 from cyclops.data.constants import ALL, FIRST, LAST, MEAN, MEDIAN
@@ -30,7 +31,7 @@
 TIMESTEP = "timestep"
 
 
-class Aggregator:
+class Aggregator(TransformerMixin):  # type: ignore
     """Equal-spaced aggregation, or binning, of time-series data.
 
     Computing aggregation metadata is expensive and should be done sparingly.
@@ -50,13 +51,20 @@ class Aggregator:
     timestep_size: float
         Time in hours for a single timestep, or bin.
     window_duration: float or None
-        Time duration to consider after the start of a timestep.
+        Time in hours for the aggregation window. If None, the latest timestamp
+        for each time_by group is used as the window stop time.
+    window_start_time: pd.DataFrame or None
+        An optionally provided window start time for each time_by group.
+    window_stop_time: pd.DataFrame or None
+        An optionally provided window stop time for each time_by group.
     agg_meta_for: list of str or None
         Columns for which to compute aggregation metadata.
     window_times: pd.DataFrame or None
         The start/stop time windows used to aggregate the data.
     imputer: AggregatedImputer or None
         An imputer to perform imputation.
+    num_timesteps: int or None
+        The number of timesteps in the aggregation window.
 
     """
 
@@ -68,6 +76,8 @@ def __init__(
         agg_by: Union[str, List[str]],
         timestep_size: Optional[int] = None,
         window_duration: Optional[int] = None,
+        window_start_time: Optional[pd.DataFrame] = None,
+        window_stop_time: Optional[pd.DataFrame] = None,
         imputer: Optional[AggregatedImputer] = None,
         agg_meta_for: Optional[List[str]] = None,
     ):
@@ -81,6 +91,8 @@ def __init__(
         self.agg_meta_for = to_list_optional(agg_meta_for)
         self.timestep_size = timestep_size
         self.window_duration = window_duration
+        self.window_start_time = window_start_time
+        self.window_stop_time = window_stop_time
         self.window_times = pd.DataFrame()  # Calculated when given the data
         self.imputer = imputer
         # Parameter checking
@@ -92,8 +104,13 @@ def __init__(
             )
 
         if window_duration is not None and timestep_size is not None:
             divided = window_duration / timestep_size
+            self.num_timesteps = int(divided)
             if divided != int(divided):
                 raise ValueError("Window duration must be divisible by timestep size.")
+        elif timestep_size is not None:
+            self.num_timesteps = None  # type: ignore
+        else:
+            self.num_timesteps = 1
 
     def _process_aggfuncs(
         self,
@@ -315,8 +332,6 @@ def _compute_window_stop(
     def _compute_window_times(
         self,
         data: pd.DataFrame,
-        window_start_time: Optional[pd.DataFrame] = None,
-        window_stop_time: Optional[pd.DataFrame] = None,
     ) -> pd.DataFrame:
         """Compute the start/stop timestamps for each time_by window.
 
         Parameters
         ----------
         data: pandas.DataFrame
             Data before aggregation.
-        window_start_time: pd.DataFrame, optional
-            An optionally provided window start time.
-        window_stop_time: pd.DataFrame, optional
-            An optionally provided window stop time.
 
         Returns
         -------
@@ -338,13 +349,13 @@ def _compute_window_times(
         # Compute window start time
         window_start_time = self._compute_window_start(
             data,
-            window_start_time=window_start_time,
+            window_start_time=self.window_start_time,
         )
         # Compute window stop time
         window_stop_time = self._compute_window_stop(
             data,
             window_start_time,
-            window_stop_time=window_stop_time,
+            window_stop_time=self.window_stop_time,
         )
         # Combine and compute additional information
         window_start_time = window_start_time.rename(
@@ -485,67 +496,6 @@ def _aggregate(
 
         return aggregated.set_index(self.agg_by + [TIMESTEP])
 
-    @time_function
-    def __call__(
-        self,
-        data: pd.DataFrame,
-        window_start_time: Optional[pd.DataFrame] = None,
-        window_stop_time: Optional[pd.DataFrame] = None,
-        include_timestep_start: bool = True,
-    ) -> pd.DataFrame:
-        """Aggregate.
-
-        The window start and stop times can be used to cut short the timeseries.
-
-        By default, the start time of a time_by group will be the earliest
-        recorded timestamp in said group. Otherwise, a window_start_time
-        can be provided by the user to override this default.
-
-        The end time of a time_by group work similarly, but with the additional
-        option of specifying a window_duration.
-
-        Parameters
-        ----------
-        data: pandas.DataFrame
-            Input data.
-        window_start_time: pd.DataFrame, optional
-            An optionally provided window start time.
-        window_stop_time: pd.DataFrame, optional
-            An optionally provided window stop time. This cannot be provided if
-            window_duration was set.
-        include_timestep_start: bool, default = True
-            Whether to include the window start timestamps for each timestep.
-
-        Returns
-        -------
-        pandas.DataFrame
-            The aggregated data.
-
-        """
-        # Parameter checking
-        if not isinstance(data, pd.DataFrame):
-            raise ValueError("Data to aggregate must be a pandas.DataFrame.")
-        has_columns(
-            data,
-            list(set([self.timestamp_col] + self.time_by + self.agg_by)),
-            raise_error=True,
-        )
-        if has_columns(data, TIMESTEP):
-            raise ValueError(f"Input data cannot have a column called {TIMESTEP}.")
-        # Ensure the timestamp column is a timestamp. Drop null times (NaT).
-        is_timestamp_series(data[self.timestamp_col], raise_error=True)
-        data = dropna_rows(data, self.timestamp_col)
-        # Compute start/stop timestamps
-        self.window_times = self._compute_window_times(
-            data,
-            window_start_time=window_start_time,
-            window_stop_time=window_stop_time,
-        )
-        # Restrict the data according to the start/stop
-        data = self._restrict_by_timestamp(data)
-
-        return self._aggregate(data, include_timestep_start=include_timestep_start)
-
     @time_function
     def vectorize(self, aggregated: pd.DataFrame) -> Vectorized:
         """Vectorize aggregated data.
@@ -602,26 +552,41 @@ def vectorize(self, aggregated: pd.DataFrame) -> Vectorized:
             axis_names=["aggfuncs"] + self.agg_by + [TIMESTEP],
         )
 
-    def aggregate_values(
+    def fit(
         self,
         data: pd.DataFrame,
-        window_start_time: Optional[pd.DataFrame] = None,
-        window_stop_time: Optional[pd.DataFrame] = None,
-    ) -> pd.DataFrame:
-        """Aggregate temporal values.
+    ) -> None:
+        """Fit the aggregator.
+
+        Parameters
+        ----------
+        data: pandas.DataFrame
+            Input data.
+
+        """
+        # Parameter checking
+        if not isinstance(data, pd.DataFrame):
+            raise ValueError("Data to aggregate must be a DataFrame.")
+        self.window_times = self._compute_window_times(
+            data,
+        )
 
-        The temporal values are restricted by start/stop and then aggregated.
-        No timestep is created.
+    def transform(
+        self,
+        data: pd.DataFrame,
+        y: None = None,
+        include_timestep_start: bool = True,
+    ) -> pd.DataFrame:
+        """Transform the data by aggregating.
 
         Parameters
         ----------
         data: pandas.DataFrame
             Input data.
-        window_start_time: pd.DataFrame, optional
-            An optionally provided window start time.
-        window_stop_time: pd.DataFrame, optional
-            An optionally provided window stop time. This cannot be provided if
-            window_duration was set.
+        y: None
+            Placeholder for sklearn compatibility.
+        include_timestep_start: bool, default = True
+            Whether to include the window start timestamps for each timestep.
 
         Returns
         -------
@@ -634,19 +599,42 @@ def aggregate_values(
             list(set([self.timestamp_col] + self.time_by + self.agg_by)),
             raise_error=True,
         )
+        if has_columns(data, TIMESTEP):
+            raise ValueError(f"Input data cannot have a column called {TIMESTEP}.")
         # Ensure the timestamp column is a timestamp. Drop null times (NaT).
         is_timestamp_series(data[self.timestamp_col], raise_error=True)
         data = dropna_rows(data, self.timestamp_col)
-        self.window_times = self._compute_window_times(
-            data,
-            window_start_time=window_start_time,
-            window_stop_time=window_stop_time,
-        )
         # Restrict the data according to the start/stop
         data = self._restrict_by_timestamp(data)
         grouped = data.groupby(self.agg_by, sort=False)
 
-        return grouped.agg(self.aggfuncs)
+        if self.num_timesteps == 1:
+            return grouped.agg(self.aggfuncs)
+        if self.num_timesteps is None or self.num_timesteps > 1:
+            return self._aggregate(data, include_timestep_start=include_timestep_start)
+
+        raise ValueError("num_timesteps must be greater than 0.")
+
+    def fit_transform(
+        self,
+        data: pd.DataFrame,
+    ) -> pd.DataFrame:
+        """Fit the aggregator and transform the data by aggregating.
+
+        Parameters
+        ----------
+        data: pandas.DataFrame
+            Input data.
+
+        Returns
+        -------
+        pandas.DataFrame
+            The aggregated data.
+ + """ + self.fit(data) + + return self.transform(data) def tabular_as_aggregated( diff --git a/docs/source/tutorials/mimiciv/mortality_prediction.ipynb b/docs/source/tutorials/mimiciv/mortality_prediction.ipynb index 84d2f6269..09fa6649f 100644 --- a/docs/source/tutorials/mimiciv/mortality_prediction.ipynb +++ b/docs/source/tutorials/mimiciv/mortality_prediction.ipynb @@ -141,7 +141,7 @@ " patients = querier.patients()\n", " encounters = querier.mimiciv_hosp.admissions()\n", " drop_op = qo.Drop(\n", - " [\"insurance\", \"language\", \"marital_status\", \"edregtime\", \"edouttime\"],\n", + " [\"language\", \"marital_status\", \"edregtime\", \"edouttime\"],\n", " )\n", " encounters = encounters.ops(drop_op)\n", " patient_encounters = patients.join(encounters, on=\"subject_id\")\n", @@ -168,6 +168,8 @@ " \"gender\",\n", " \"anchor_year_difference\",\n", " \"admission_location\",\n", + " \"admission_type\",\n", + " \"insurance\",\n", " \"hospital_expire_flag\",\n", " ]\n", " ]\n", @@ -240,6 +242,7 @@ " \"valuenum\": \"mean\",\n", " },\n", " window_duration=M * 24,\n", + " window_start_time=start_timestamps,\n", " timestamp_col=\"charttime\",\n", " time_by=\"hadm_id\",\n", " agg_by=[\"hadm_id\", \"label\"],\n", @@ -250,9 +253,8 @@ " labevents_batch,\n", " patient_encounters,\n", " )\n", - " means = mean_aggregator.aggregate_values(\n", + " means = mean_aggregator.fit_transform(\n", " labevents_batch,\n", - " window_start_time=start_timestamps,\n", " )\n", " means = means.reset_index()\n", " means = means.pivot(index=\"hadm_id\", columns=\"label\", values=\"valuenum\")\n", diff --git a/tests/cyclops/data/test_aggregate.py b/tests/cyclops/data/test_aggregate.py index ffa1e653e..a7ca6e6ae 100644 --- a/tests/cyclops/data/test_aggregate.py +++ b/tests/cyclops/data/test_aggregate.py @@ -91,7 +91,7 @@ def test_aggregate_events( timestep_size=1, agg_meta_for=EVENT_VALUE, ) - res = aggregator(data) + res = aggregator.fit_transform(data) assert res.index.names == [ENCOUNTER_ID, EVENT_NAME, TIMESTEP] assert res.loc[(2, "eventA", 1)][EVENT_VALUE] == 19 @@ -116,8 +116,8 @@ def test_aggregate_window_duration( timestep_size=1, window_duration=12, ) + res = aggregator.fit_transform(data) - res = aggregator(data) res = res.reset_index() assert (res[TIMESTEP] < 2).all() @@ -134,13 +134,10 @@ def test_aggregate_start_stop_windows( time_by=ENCOUNTER_ID, agg_by=[ENCOUNTER_ID, EVENT_NAME], timestep_size=1, - ) - - res = aggregator( - data, window_start_time=window_start_time, window_stop_time=window_stop_time, ) + res = aggregator.fit_transform(data) assert res.loc[(2, "eventA", 0)][START_TIMESTEP] == DATE2 @@ -154,9 +151,10 @@ def test_aggregate_start_stop_windows( agg_by=[ENCOUNTER_ID, EVENT_NAME], timestep_size=1, window_duration=10, + window_stop_time=window_stop_time, ) try: - res = aggregator(data, window_stop_time=window_stop_time) + res = aggregator.fit_transform(data) raise ValueError( """Should have raised an error that window_duration cannot be set when window_stop_time is specified.""", @@ -190,7 +188,9 @@ def test_aggregate_strings( window_duration=20, ) - assert aggregator_str(data).equals(aggregator_fn(data)) + assert aggregator_str.fit_transform(data).equals( + aggregator_fn.fit_transform(data), + ) with contextlib.suppress(ValueError): aggregator_str = Aggregator( @@ -223,7 +223,7 @@ def test_aggregate_multiple( window_duration=20, ) - res = aggregator(data) + res = aggregator.fit_transform(data) res = res.reset_index() assert res["event_value2"].equals(res[EVENT_VALUE] * 2) @@ -318,7 +318,7 @@ 
         window_duration=15,
     )
 
-    aggregated = aggregator(data)
+    aggregated = aggregator.fit_transform(data)
     vectorized_obj = aggregator.vectorize(aggregated)
     vectorized, indexes = vectorized_obj.data, vectorized_obj.indexes
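
Usage sketch (not part of the patch): the snippet below shows how the new sklearn-style entry points introduced here fit together. The toy DataFrame and its column names ("encounter_id", "event_name", "event_value", "event_timestamp") are invented for illustration; the constructor keywords mirror the tutorial cell above.

    import pandas as pd

    from cyclops.data.aggregate import Aggregator

    # Toy events table; the column names are hypothetical.
    data = pd.DataFrame(
        {
            "encounter_id": [1, 1, 1, 2],
            "event_name": ["hr", "hr", "hr", "hr"],
            "event_value": [80.0, 90.0, 85.0, 70.0],
            "event_timestamp": pd.to_datetime(
                [
                    "2022-01-01 00:30",
                    "2022-01-01 01:15",
                    "2022-01-01 02:45",
                    "2022-01-01 00:10",
                ],
            ),
        },
    )

    aggregator = Aggregator(
        aggfuncs={"event_value": "mean"},  # mean of event_value per bin
        timestamp_col="event_timestamp",
        time_by="encounter_id",  # windows are computed per encounter
        agg_by=["encounter_id", "event_name"],
        timestep_size=1,  # 1-hour bins
        window_duration=4,  # 4-hour window, so num_timesteps == 4
    )

    # One-step equivalent of aggregator.fit(data) then aggregator.transform(data);
    # this replaces the removed aggregator(data) and aggregate_values(...) calls.
    means = aggregator.fit_transform(data)

Because window_start_time and window_stop_time now live on the constructor, transform() takes only the data, which is what lets Aggregator slot into sklearn pipelines via TransformerMixin.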