Skip to content

Commit

Permalink
Merge pull request #246 from WorldCereal/245-fix-querying-public-extr…
Browse files Browse the repository at this point in the history
…action-function

245 fix querying public extraction function
  • Loading branch information
cbutsko authored Jan 13, 2025
2 parents 2a27eb9 + de55b03 commit 99e1909
Show file tree
Hide file tree
Showing 2 changed files with 177 additions and 56 deletions.
129 changes: 74 additions & 55 deletions src/worldcereal/utils/refdata.py
Original file line number Diff line number Diff line change
Expand Up @@ -168,7 +168,9 @@ def query_public_extractions(
)

# Process the parquet into the format we need for training
processed_public_df = process_parquet(public_df_raw, processing_period)
processed_public_df = process_public_extractions_df(
public_df_raw, processing_period
)

return processed_public_df

Expand All @@ -191,7 +193,7 @@ def month_diff(month1: int, month2: int) -> int:
The difference between `month1` and `month2`.
"""

return month2 - month1 if month2 >= month1 else 12 - month1 + month2
return (month2 - month1) % 12


def get_best_valid_date(row: pd.Series):
Expand Down Expand Up @@ -219,57 +221,59 @@ def get_best_valid_date(row: pd.Series):

from presto.dataops import MIN_EDGE_BUFFER, NUM_TIMESTEPS

# check if shift forward will fit into existing extractions
# allow buffer of MIN_EDGE_BUFFER months at the start and end of the extraction period
temp_end_date = row["valid_date"] + pd.DateOffset(
months=row["valid_month_shift_forward"] + NUM_TIMESTEPS // 2 - MIN_EDGE_BUFFER
def is_within_period(proposed_date, start_date, end_date):
return (proposed_date - pd.DateOffset(months=MIN_EDGE_BUFFER) >= start_date) & (
proposed_date + pd.DateOffset(months=MIN_EDGE_BUFFER) <= end_date
)

def check_shift(proposed_date, valid_date, start_date, end_date):
proposed_start_date = proposed_date - pd.DateOffset(
months=(NUM_TIMESTEPS // 2 - 1)
)
proposed_end_date = proposed_date + pd.DateOffset(months=(NUM_TIMESTEPS // 2))
return (
is_within_period(proposed_date, start_date, end_date)
& (valid_date >= proposed_start_date)
& (valid_date <= proposed_end_date)
)

valid_date = row["valid_date"]
start_date = row["start_date"]
end_date = row["end_date"]

proposed_valid_date_fwd = valid_date + pd.DateOffset(
months=row["valid_month_shift_forward"]
)
proposed_valid_date_bwd = valid_date - pd.DateOffset(
months=row["valid_month_shift_backward"]
)
temp_start_date = temp_end_date - pd.DateOffset(months=NUM_TIMESTEPS)
if (temp_end_date <= row["end_date"]) & (temp_start_date >= row["start_date"]):
shift_forward_ok = True
else:
shift_forward_ok = False

# check if shift backward will fit into existing extractions
# allow buffer of MIN_EDGE_BUFFER months at the start and end of the extraction period
temp_start_date = row["valid_date"] - pd.DateOffset(
months=row["valid_month_shift_backward"] + NUM_TIMESTEPS // 2 - MIN_EDGE_BUFFER
shift_forward_ok = check_shift(
proposed_valid_date_fwd, valid_date, start_date, end_date
)
shift_backward_ok = check_shift(
proposed_valid_date_bwd, valid_date, start_date, end_date
)
temp_end_date = temp_start_date + pd.DateOffset(months=NUM_TIMESTEPS)
if (temp_end_date <= row["end_date"]) & (temp_start_date >= row["start_date"]):
shift_backward_ok = True
else:
shift_backward_ok = False

if (not shift_forward_ok) & (not shift_backward_ok):
if not shift_forward_ok and not shift_backward_ok:
return np.nan

if shift_forward_ok & (not shift_backward_ok):
return row["valid_date"] + pd.DateOffset(
months=row["valid_month_shift_forward"]
if shift_forward_ok and not shift_backward_ok:
return proposed_valid_date_fwd
if not shift_forward_ok and shift_backward_ok:
return proposed_valid_date_bwd
if shift_forward_ok and shift_backward_ok:
return (
proposed_valid_date_bwd
if (row["valid_month_shift_backward"] - row["valid_month_shift_forward"])
<= MIN_EDGE_BUFFER
else proposed_valid_date_fwd
)

if (not shift_forward_ok) & shift_backward_ok:
return row["valid_date"] - pd.DateOffset(
months=row["valid_month_shift_backward"]
)

if shift_forward_ok & shift_backward_ok:
# if shift backward is not too much bigger than shift forward, choose backward
if (
row["valid_month_shift_backward"] - row["valid_month_shift_forward"]
) <= MIN_EDGE_BUFFER:
return row["valid_date"] - pd.DateOffset(
months=row["valid_month_shift_backward"]
)
else:
return row["valid_date"] + pd.DateOffset(
months=row["valid_month_shift_forward"]
)


def process_parquet(
public_df_raw: pd.DataFrame, processing_period: TemporalContext = None
def process_public_extractions_df(
public_df_raw: pd.DataFrame,
processing_period: TemporalContext = None,
freq: str = "MS",
) -> pd.DataFrame:
"""Method to transform the raw parquet data into a format that can be used for
training. Includes pivoting of the dataframe and mapping of the crop types.
Expand All @@ -278,13 +282,19 @@ def process_parquet(
----------
public_df_raw : pd.DataFrame
Input raw flattened dataframe from the global database.
Returns
-------
pd.DataFrame
processed dataframe with the necessary columns for training.
processing_period: TemporalContext, optional
User-defined temporal extent to align the samples with, by default None,
which means that 12-month processing window will be aligned around each sample's original valid_date.
If provided, the processing window will be aligned with the middle of the user-defined temporal extent, according to the
following principles:
- the original valid_date of the sample should remain within the processing window
- the center of the user-defined temporal extent should be not closer than MIN_EDGE_BUFFER (by default 2 months)
to the start or end of the extraction period
freq : str, optional
Frequency of the time series, by default "MS". Provided frequency alias should be compatible with pandas.
https://pandas.pydata.org/pandas-docs/stable/user_guide/timeseries.html#timeseries-offset-aliases
"""
from presto.utils import process_parquet as process_parquet_for_presto
from presto.utils import process_parquet

logger.info("Processing selected samples ...")

Expand All @@ -293,7 +303,16 @@ def process_parquet(

# get the middle of the user-defined temporal extent
start_date, end_date = processing_period.to_datetime()
processing_period_middle_ts = start_date + pd.DateOffset(months=6)

# sanity check to make sure freq is not something we still don't support in Presto
if freq not in ["MS", "10D"]:
raise ValueError(
f"Unsupported frequency alias: {freq}. Please use 'MS' or '10D'."
)

date_range = pd.date_range(start=start_date, end=end_date, freq=freq)
middle_index = len(date_range) // 2 - 1
processing_period_middle_ts = date_range[middle_index]
processing_period_middle_month = processing_period_middle_ts.month

# get a lighter subset with only the necessary columns
Expand All @@ -309,13 +328,13 @@ def process_parquet(
# calculate the shifts and assign new valid date
sample_dates["true_valid_date_month"] = public_df_raw["valid_date"].dt.month
sample_dates["proposed_valid_date_month"] = processing_period_middle_month
sample_dates["valid_month_shift_forward"] = sample_dates.apply(
sample_dates["valid_month_shift_backward"] = sample_dates.apply(
lambda xx: month_diff(
xx["proposed_valid_date_month"], xx["true_valid_date_month"]
),
axis=1,
)
sample_dates["valid_month_shift_backward"] = sample_dates.apply(
sample_dates["valid_month_shift_forward"] = sample_dates.apply(
lambda xx: month_diff(
xx["true_valid_date_month"], xx["proposed_valid_date_month"]
),
Expand All @@ -342,7 +361,7 @@ def process_parquet(
f"Removed {invalid_samples.shape[0]} samples that do not fit into selected temporal extent."
)

public_df = process_parquet_for_presto(public_df_raw)
public_df = process_parquet(public_df_raw)

if processing_period is not None:
# put back the true valid_date
Expand Down
104 changes: 103 additions & 1 deletion tests/worldcerealtests/test_refdata.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,11 @@
import pandas as pd
from shapely.geometry import Polygon

from worldcereal.utils.refdata import query_public_extractions
from worldcereal.utils.refdata import (
get_best_valid_date,
month_diff,
query_public_extractions,
)


def test_query_public_extractions():
Expand All @@ -14,3 +19,100 @@ def test_query_public_extractions():

# Check if dataframe has samples
assert not df.empty


def test_get_best_valid_date():
    """Check which processing-period middle months `get_best_valid_date` rejects.

    For each sample, every calendar month (1..12) is proposed in turn as the
    middle month of the processing period. The forward/backward month shifts
    are derived with `month_diff`, and `get_best_valid_date` is expected to
    return NaN exactly for the months listed in the case table below and a
    usable date for every other month.

    The expected months are valid for the default MIN_EDGE_BUFFER and
    NUM_TIMESTEPS values.
    """

    def rejected_months(sample: pd.Series) -> set:
        """Return the proposed middle months for which no valid date is found."""
        # Work on a copy so the caller's Series is left untouched.
        row = sample.copy()
        # The true month does not depend on the proposed month: compute once.
        true_month = row["valid_date"].month
        row["true_valid_date_month"] = true_month

        rejected = set()
        for proposed_month in range(1, 13):
            row["proposed_valid_date_month"] = proposed_month
            row["valid_month_shift_backward"] = month_diff(
                proposed_month, true_month
            )
            row["valid_month_shift_forward"] = month_diff(
                true_month, proposed_month
            )
            if pd.isna(get_best_valid_date(row)):
                rejected.add(proposed_month)
        return rejected

    # All samples share the same 12-month extraction window.
    extraction_window = {
        "start_date": pd.to_datetime("2019-01-01"),
        "end_date": pd.to_datetime("2019-12-01"),
    }

    # (original valid_date, proposed middle months that must yield NaN);
    # every month not listed must yield a non-NaN date.
    cases = [
        ("2019-06-01", {1, 2, 11, 12}),
        ("2019-10-01", {1, 2, 3, 11, 12}),
        ("2019-03-01", {1, 2, 9, 10, 11, 12}),
    ]

    for valid_date, expected_rejected in cases:
        sample = pd.Series(
            {**extraction_window, "valid_date": pd.to_datetime(valid_date)}
        )
        rejected = rejected_months(sample)
        assert rejected == expected_rejected, (
            f"valid_date={valid_date}: expected NaN for months "
            f"{sorted(expected_rejected)}, got NaN for {sorted(rejected)}"
        )

0 comments on commit 99e1909

Please sign in to comment.