diff --git a/src/worldcereal/utils/refdata.py b/src/worldcereal/utils/refdata.py
index bd9067f..a6a29d2 100644
--- a/src/worldcereal/utils/refdata.py
+++ b/src/worldcereal/utils/refdata.py
@@ -168,7 +168,9 @@ def query_public_extractions(
     )
 
     # Process the parquet into the format we need for training
-    processed_public_df = process_parquet(public_df_raw, processing_period)
+    processed_public_df = process_public_extractions_df(
+        public_df_raw, processing_period
+    )
 
     return processed_public_df
 
@@ -191,7 +193,7 @@ def month_diff(month1: int, month2: int) -> int:
         The difference between `month1` and `month2`.
     """
 
-    return month2 - month1 if month2 >= month1 else 12 - month1 + month2
+    return (month2 - month1) % 12
 
 
 def get_best_valid_date(row: pd.Series):
@@ -219,57 +221,59 @@ def get_best_valid_date(row: pd.Series):
 
     from presto.dataops import MIN_EDGE_BUFFER, NUM_TIMESTEPS
 
-    # check if shift forward will fit into existing extractions
-    # allow buffer of MIN_EDGE_BUFFER months at the start and end of the extraction period
-    temp_end_date = row["valid_date"] + pd.DateOffset(
-        months=row["valid_month_shift_forward"] + NUM_TIMESTEPS // 2 - MIN_EDGE_BUFFER
+    def is_within_period(proposed_date, start_date, end_date):
+        return (proposed_date - pd.DateOffset(months=MIN_EDGE_BUFFER) >= start_date) & (
+            proposed_date + pd.DateOffset(months=MIN_EDGE_BUFFER) <= end_date
+        )
+
+    def check_shift(proposed_date, valid_date, start_date, end_date):
+        proposed_start_date = proposed_date - pd.DateOffset(
+            months=(NUM_TIMESTEPS // 2 - 1)
+        )
+        proposed_end_date = proposed_date + pd.DateOffset(months=(NUM_TIMESTEPS // 2))
+        return (
+            is_within_period(proposed_date, start_date, end_date)
+            & (valid_date >= proposed_start_date)
+            & (valid_date <= proposed_end_date)
+        )
+
+    valid_date = row["valid_date"]
+    start_date = row["start_date"]
+    end_date = row["end_date"]
+
+    proposed_valid_date_fwd = valid_date + pd.DateOffset(
+        months=row["valid_month_shift_forward"]
+    )
+    proposed_valid_date_bwd = valid_date - pd.DateOffset(
+        months=row["valid_month_shift_backward"]
     )
-    temp_start_date = temp_end_date - pd.DateOffset(months=NUM_TIMESTEPS)
-    if (temp_end_date <= row["end_date"]) & (temp_start_date >= row["start_date"]):
-        shift_forward_ok = True
-    else:
-        shift_forward_ok = False
-    # check if shift backward will fit into existing extractions
-    # allow buffer of MIN_EDGE_BUFFER months at the start and end of the extraction period
-    temp_start_date = row["valid_date"] - pd.DateOffset(
-        months=row["valid_month_shift_backward"] + NUM_TIMESTEPS // 2 - MIN_EDGE_BUFFER
+    shift_forward_ok = check_shift(
+        proposed_valid_date_fwd, valid_date, start_date, end_date
+    )
+    shift_backward_ok = check_shift(
+        proposed_valid_date_bwd, valid_date, start_date, end_date
    )
-    temp_end_date = temp_start_date + pd.DateOffset(months=NUM_TIMESTEPS)
-    if (temp_end_date <= row["end_date"]) & (temp_start_date >= row["start_date"]):
-        shift_backward_ok = True
-    else:
-        shift_backward_ok = False
-    if (not shift_forward_ok) & (not shift_backward_ok):
+    if not shift_forward_ok and not shift_backward_ok:
         return np.nan
-
-    if shift_forward_ok & (not shift_backward_ok):
-        return row["valid_date"] + pd.DateOffset(
-            months=row["valid_month_shift_forward"]
+    if shift_forward_ok and not shift_backward_ok:
+        return proposed_valid_date_fwd
+    if not shift_forward_ok and shift_backward_ok:
+        return proposed_valid_date_bwd
+    if shift_forward_ok and shift_backward_ok:
+        return (
+            proposed_valid_date_bwd
+            if (row["valid_month_shift_backward"] - row["valid_month_shift_forward"])
+            <= MIN_EDGE_BUFFER
+            else proposed_valid_date_fwd
         )
-    if (not shift_forward_ok) & shift_backward_ok:
-        return row["valid_date"] - pd.DateOffset(
-            months=row["valid_month_shift_backward"]
-        )
-    if shift_forward_ok & shift_backward_ok:
-        # if shift backward is not too much bigger than shift forward, choose backward
-        if (
-            row["valid_month_shift_backward"] - row["valid_month_shift_forward"]
-        ) <= MIN_EDGE_BUFFER:
-            return row["valid_date"] - pd.DateOffset(
-                months=row["valid_month_shift_backward"]
-            )
-        else:
-            return row["valid_date"] + pd.DateOffset(
-                months=row["valid_month_shift_forward"]
-            )
-
-
-def process_parquet(
-    public_df_raw: pd.DataFrame, processing_period: TemporalContext = None
+
+
+def process_public_extractions_df(
+    public_df_raw: pd.DataFrame,
+    processing_period: TemporalContext = None,
+    freq: str = "MS",
 ) -> pd.DataFrame:
     """Method to transform the raw parquet data into a format that can be used for
     training. Includes pivoting of the dataframe and mapping of the crop types.
@@ -278,13 +282,19 @@ def process_parquet(
     ----------
     public_df_raw : pd.DataFrame
         Input raw flattened dataframe from the global database.
-
-    Returns
-    -------
-    pd.DataFrame
-        processed dataframe with the necessary columns for training.
+    processing_period: TemporalContext, optional
+        User-defined temporal extent to align the samples with, by default None,
+        which means that a 12-month processing window will be aligned around each sample's original valid_date.
+        If provided, the processing window will be aligned with the middle of the user-defined temporal extent, according to the
+        following principles:
+        - the original valid_date of the sample should remain within the processing window
+        - the center of the user-defined temporal extent should not be closer than MIN_EDGE_BUFFER (by default 2 months)
+          to the start or end of the extraction period
+    freq : str, optional
+        Frequency of the time series, by default "MS". The provided frequency alias should be compatible with pandas:
+        https://pandas.pydata.org/pandas-docs/stable/user_guide/timeseries.html#timeseries-offset-aliases
     """
-    from presto.utils import process_parquet as process_parquet_for_presto
+    from presto.utils import process_parquet
 
     logger.info("Processing selected samples ...")
 
@@ -293,7 +303,16 @@ def process_parquet(
 
         # get the middle of the user-defined temporal extent
        start_date, end_date = processing_period.to_datetime()
-        processing_period_middle_ts = start_date + pd.DateOffset(months=6)
+
+        # sanity check to make sure freq is supported by Presto
+        if freq not in ["MS", "10D"]:
+            raise ValueError(
+                f"Unsupported frequency alias: {freq}. Please use 'MS' or '10D'."
+            )
+
+        date_range = pd.date_range(start=start_date, end=end_date, freq=freq)
+        middle_index = len(date_range) // 2 - 1
+        processing_period_middle_ts = date_range[middle_index]
         processing_period_middle_month = processing_period_middle_ts.month
 
         # get a lighter subset with only the necessary columns
@@ -309,13 +328,13 @@ def process_parquet(
         # calculate the shifts and assign new valid date
         sample_dates["true_valid_date_month"] = public_df_raw["valid_date"].dt.month
         sample_dates["proposed_valid_date_month"] = processing_period_middle_month
-        sample_dates["valid_month_shift_forward"] = sample_dates.apply(
+        sample_dates["valid_month_shift_backward"] = sample_dates.apply(
             lambda xx: month_diff(
                 xx["proposed_valid_date_month"], xx["true_valid_date_month"]
             ),
             axis=1,
         )
-        sample_dates["valid_month_shift_backward"] = sample_dates.apply(
+        sample_dates["valid_month_shift_forward"] = sample_dates.apply(
             lambda xx: month_diff(
                 xx["true_valid_date_month"], xx["proposed_valid_date_month"]
             ),
             axis=1,
@@ -342,7 +361,7 @@ def process_parquet(
             f"Removed {invalid_samples.shape[0]} samples that do not fit into selected temporal extent."
         )
 
-    public_df = process_parquet_for_presto(public_df_raw)
+    public_df = process_parquet(public_df_raw)
 
     if processing_period is not None:
         # put back the true valid_date
diff --git a/tests/worldcerealtests/test_refdata.py b/tests/worldcerealtests/test_refdata.py
index c5767bc..e390e62 100644
--- a/tests/worldcerealtests/test_refdata.py
+++ b/tests/worldcerealtests/test_refdata.py
@@ -1,6 +1,11 @@
+import pandas as pd
 from shapely.geometry import Polygon
 
-from worldcereal.utils.refdata import query_public_extractions
+from worldcereal.utils.refdata import (
+    get_best_valid_date,
+    month_diff,
+    query_public_extractions,
+)
 
 
 def test_query_public_extractions():
@@ -14,3 +19,100 @@ def test_query_public_extractions():
 
     # Check if dataframe has samples
     assert not df.empty
+
+
+def test_get_best_valid_date():
+    def process_test_case(test_case: pd.Series) -> pd.DataFrame:
+        test_case_res = []
+        for processing_period_middle_month in range(1, 13):
+            test_case["true_valid_date_month"] = test_case["valid_date"].month
+            test_case["proposed_valid_date_month"] = processing_period_middle_month
+            test_case["valid_month_shift_backward"] = month_diff(
+                test_case["proposed_valid_date_month"],
+                test_case["true_valid_date_month"],
+            )
+            test_case["valid_month_shift_forward"] = month_diff(
+                test_case["true_valid_date_month"],
+                test_case["proposed_valid_date_month"],
+            )
+            proposed_valid_date = get_best_valid_date(test_case)
+            test_case_res.append([processing_period_middle_month, proposed_valid_date])
+        return pd.DataFrame(
+            test_case_res, columns=["proposed_valid_month", "resulting_valid_date"]
+        )
+
+    test_case1 = pd.Series(
+        {
+            "start_date": pd.to_datetime("2019-01-01"),
+            "end_date": pd.to_datetime("2019-12-01"),
+            "valid_date": pd.to_datetime("2019-06-01"),
+        }
+    )
+    test_case2 = pd.Series(
+        {
+            "start_date": pd.to_datetime("2019-01-01"),
+            "end_date": pd.to_datetime("2019-12-01"),
+            "valid_date": pd.to_datetime("2019-10-01"),
+        }
+    )
+    test_case3 = pd.Series(
+        {
+            "start_date": pd.to_datetime("2019-01-01"),
+            "end_date": pd.to_datetime("2019-12-01"),
+            "valid_date": pd.to_datetime("2019-03-01"),
+        }
+    )
+
+    # Process test cases
+    test_case1_res = process_test_case(test_case1)
+    test_case2_res = process_test_case(test_case2)
+    test_case3_res = process_test_case(test_case3)
+
+    # Asserts are valid for default MIN_EDGE_BUFFER and NUM_TIMESTEPS values
+    # Assertions for test case 1
+    assert (
+        test_case1_res[test_case1_res["proposed_valid_month"].isin([1, 2, 11, 12])][
+            "resulting_valid_date"
+        ]
+        .isna()
+        .all()
+    )
+    assert (
+        test_case1_res[test_case1_res["proposed_valid_month"].isin(range(3, 11))][
+            "resulting_valid_date"
+        ]
+        .notna()
+        .all()
+    )
+
+    # Assertions for test case 2
+    assert (
+        test_case2_res[test_case2_res["proposed_valid_month"].isin([1, 2, 3, 11, 12])][
+            "resulting_valid_date"
+        ]
+        .isna()
+        .all()
+    )
+    assert (
+        test_case2_res[test_case2_res["proposed_valid_month"].isin(range(4, 11))][
+            "resulting_valid_date"
+        ]
+        .notna()
+        .all()
+    )
+
+    # Assertions for test case 3
+    assert (
+        test_case3_res[
+            test_case3_res["proposed_valid_month"].isin([1, 2, 9, 10, 11, 12])
+        ]["resulting_valid_date"]
+        .isna()
+        .all()
+    )
+    assert (
+        test_case3_res[test_case3_res["proposed_valid_month"].isin(range(3, 9))][
+            "resulting_valid_date"
+        ]
+        .notna()
+        .all()
+    )
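
Note on the revised shift logic (illustration, not part of the patch): month_diff now returns the forward, wrap-around distance between two calendar months, and process_public_extractions_df derives the backward and forward window shifts from it before get_best_valid_date picks one. A minimal sketch of that interplay, using made-up month values and the 2-month MIN_EDGE_BUFFER default mentioned in the docstring:

    MIN_EDGE_BUFFER = 2  # default buffer named in the docstring

    def month_diff(month1: int, month2: int) -> int:
        # forward distance from month1 to month2, wrapping around the year end
        return (month2 - month1) % 12

    # Illustrative case: a sample with valid_date in June, window centred on October.
    true_month, proposed_month = 6, 10
    shift_backward = month_diff(proposed_month, true_month)  # 8 months back
    shift_forward = month_diff(true_month, proposed_month)   # 4 months forward
    # get_best_valid_date keeps the backward shift only if it is at most
    # MIN_EDGE_BUFFER months larger than the forward one; here 8 - 4 > 2,
    # so the forward shift is preferred (provided the shifted window still
    # fits inside the sample's extraction period).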
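
Also for reference: the middle of the user-defined processing period is no longer hard-coded as start_date plus 6 months but taken from the actual timestep range, which keeps the logic valid for the dekadal "10D" frequency as well. A small sketch mirroring the patched computation, with hypothetical dates chosen only for illustration:

    import pandas as pd

    # hypothetical 12-month processing period
    start_date, end_date = pd.Timestamp("2021-01-01"), pd.Timestamp("2021-12-01")
    freq = "MS"  # month-start timesteps; "10D" would yield dekadal steps

    date_range = pd.date_range(start=start_date, end=end_date, freq=freq)
    middle_index = len(date_range) // 2 - 1
    processing_period_middle_ts = date_range[middle_index]
    print(processing_period_middle_ts)  # 2021-06-01, the sixth of twelve timesteps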