diff --git a/src/worldcereal/utils/refdata.py b/src/worldcereal/utils/refdata.py
index 15df662e..6fcb3e68 100644
--- a/src/worldcereal/utils/refdata.py
+++ b/src/worldcereal/utils/refdata.py
@@ -8,6 +8,7 @@
 import pandas as pd
 import requests
 from loguru import logger
+from openeo_gfmap import TemporalContext
 from shapely.geometry import Polygon
 
 from worldcereal.data import croptype_mappings
@@ -28,7 +29,10 @@ def get_class_mappings() -> Dict:
 
 
 def query_public_extractions(
-    bbox_poly: Polygon, buffer: int = 250000, filter_cropland: bool = True
+    bbox_poly: Polygon,
+    buffer: int = 250000,
+    filter_cropland: bool = True,
+    processing_period: TemporalContext = None,
 ) -> pd.DataFrame:
     """Function that queries the WorldCereal global database of pre-extracted input
     data for a given area.
@@ -41,13 +45,25 @@ def query_public_extractions(
         buffer (in meters) to apply to the requested area, by default 250000
     filter_cropland : bool, optional
         limit the query to samples on cropland only, by default True
+    processing_period : TemporalContext, optional
+        user-defined temporal extent to align the samples with, by default None,
+        in which case a 12-month processing window is aligned around each
+        sample's original valid_date.
 
     Returns
     -------
     pd.DataFrame
         DataFrame containing the extractions matching the request.
     """
-
+    from IPython.display import Markdown, display
+
+    nodata_helper_message = f"""
+### What to do?
+1. **Increase the buffer size**: Try increasing the buffer size by passing the `buffer` parameter to the `query_public_extractions` function (within reason).
+   *Current setting is: {buffer} m.*
+2. **Consult the WorldCereal Reference Data Module portal**: Assess data density in the selected region by visiting the [WorldCereal Reference Data Module portal](https://ewoc-rdm-ui.iiasa.ac.at/map).
+3. **Pick another area**: Consult the RDM portal (see above) to find areas with higher data density.
+4. **Contribute data**: Collect some data and contribute to our global database! 🌍🌾 [Learn how to contribute here.](https://worldcereal.github.io/worldcereal-documentation/rdm/upload.html)
+"""
     logger.info(f"Applying a buffer of {int(buffer/1000)} km to the selected area ...")
 
     bbox_poly = (
@@ -77,9 +93,12 @@ def query_public_extractions(
 
     if len(ref_ids_lst) == 0:
         logger.error(
-            "No datasets found in WorldCereal global extractions database that intersect with the selected area."
+            "No datasets found in the WorldCereal global extractions database that intersect with the selected area."
+        )
+        display(Markdown(nodata_helper_message))
+        raise ValueError(
+            "No datasets found in the WorldCereal global extractions database that intersect with the selected area."
         )
-        raise ValueError()
 
     logger.info(
         f"Found {len(ref_ids_lst)} datasets in WorldCereal global extractions database that intersect with the selected area."
@@ -131,13 +150,127 @@ def query_public_extractions(
 
     public_df_raw = db.sql(main_query).df()
 
+    if public_df_raw.empty:
+        logger.error(
+            f"No samples from the WorldCereal global extractions database fall within the selected area with a buffer of {int(buffer/1000)} km."
+        )
+        display(Markdown(nodata_helper_message))
+        raise ValueError(
+            "No samples from the WorldCereal global extractions database fall within the selected area."
+        )
+    if public_df_raw["CROPTYPE_LABEL"].nunique() == 1:
+        logger.error(
+            f"Queried data contains only one class: {public_df_raw['CROPTYPE_LABEL'].unique()[0]}. Cannot train a model with only one class."
+        )
+        display(Markdown(nodata_helper_message))
+        raise ValueError(
+            "Queried data contains only one class. Cannot train a model with only one class."
+        )
+
     # Process the parquet into the format we need for training
-    processed_public_df = process_parquet(public_df_raw)
+    processed_public_df = process_parquet(public_df_raw, processing_period)
 
     return processed_public_df
 
 
-def process_parquet(public_df_raw: pd.DataFrame) -> pd.DataFrame:
+def month_diff(month1: int, month2: int) -> int:
+    """Compute the difference between `month1` and `month2`, assuming
+    `month1` lies in the past relative to `month2`. The difference is
+    circular, so it always falls within the range of 0 to 11 months.
+
+    Parameters
+    ----------
+    month1 : int
+        The reference month (1-12).
+    month2 : int
+        The month to compare against (1-12).
+
+    Returns
+    -------
+    int
+        The difference between `month1` and `month2` in months.
+    """
+
+    return month2 - month1 if month2 >= month1 else 12 - month1 + month2
+
+
+def get_best_valid_date(row: pd.Series):
+    """Determine the best valid date for a given row based on forward and backward shifts.
+    This function checks whether shifting the valid date forward or backward by the
+    computed number of months still fits within the available extraction period. It
+    returns the shifted valid date, or NaN if neither shift is possible.
+
+    Parameters
+    ----------
+    row : pd.Series
+        A row of the raw flattened dataframe from the global database that contains the following columns:
+        - "sample_id" (str): The unique sample identifier.
+        - "valid_date" (pd.Timestamp): The original valid date.
+        - "valid_month_shift_forward" (int): Number of months to shift forward.
+        - "valid_month_shift_backward" (int): Number of months to shift backward.
+        - "start_date" (pd.Timestamp): The start date of the extraction period.
+        - "end_date" (pd.Timestamp): The end date of the extraction period.
+
+    Returns
+    -------
+    pd.Timestamp or np.nan
+        The shifted valid date, or NaN if the sample cannot be aligned.
+    """
+
+    from presto.dataops import MIN_EDGE_BUFFER, NUM_TIMESTEPS
+
+    # check if shift forward will fit into existing extractions
+    # allow a buffer of MIN_EDGE_BUFFER months at the start and end of the extraction period
+    temp_end_date = row["valid_date"] + pd.DateOffset(
+        months=row["valid_month_shift_forward"] + NUM_TIMESTEPS // 2 - MIN_EDGE_BUFFER
+    )
+    temp_start_date = temp_end_date - pd.DateOffset(months=NUM_TIMESTEPS)
+    shift_forward_ok = (temp_end_date <= row["end_date"]) and (
+        temp_start_date >= row["start_date"]
+    )
+
+    # check if shift backward will fit into existing extractions
+    # allow a buffer of MIN_EDGE_BUFFER months at the start and end of the extraction period
+    temp_start_date = row["valid_date"] - pd.DateOffset(
+        months=row["valid_month_shift_backward"] + NUM_TIMESTEPS // 2 - MIN_EDGE_BUFFER
+    )
+    temp_end_date = temp_start_date + pd.DateOffset(months=NUM_TIMESTEPS)
+    shift_backward_ok = (temp_end_date <= row["end_date"]) and (
+        temp_start_date >= row["start_date"]
+    )
+
+    if not shift_forward_ok and not shift_backward_ok:
+        return np.nan
+
+    if shift_forward_ok and not shift_backward_ok:
+        return row["valid_date"] + pd.DateOffset(
+            months=row["valid_month_shift_forward"]
+        )
+
+    if not shift_forward_ok and shift_backward_ok:
+        return row["valid_date"] - pd.DateOffset(
+            months=row["valid_month_shift_backward"]
+        )
+
+    # both shifts fit: prefer the backward shift, unless it is more than
+    # MIN_EDGE_BUFFER months larger than the forward shift
+    if (
+        row["valid_month_shift_backward"] - row["valid_month_shift_forward"]
+    ) <= MIN_EDGE_BUFFER:
+        return row["valid_date"] - pd.DateOffset(
+            months=row["valid_month_shift_backward"]
+        )
+    else:
+        return row["valid_date"] + pd.DateOffset(
+            months=row["valid_month_shift_forward"]
+        )
+
+
+def process_parquet(
+    public_df_raw: pd.DataFrame, processing_period: TemporalContext = None
+) -> pd.DataFrame:
     """Method to transform the raw parquet data into a format that can be used for
     training. Includes pivoting of the dataframe and mapping of the crop types.
@@ -154,7 +287,63 @@ def process_parquet(public_df_raw: pd.DataFrame) -> pd.DataFrame:
     from presto.utils import process_parquet as process_parquet_for_presto
 
     logger.info("Processing selected samples ...")
+
+    if processing_period is not None:
+        logger.info("Aligning the samples with the user-defined temporal extent ...")
+
+        # get the middle of the user-defined temporal extent
+        start_date, end_date = processing_period.to_datetime()
+        processing_period_middle_ts = start_date + pd.DateOffset(months=6)
+        processing_period_middle_month = processing_period_middle_ts.month
+
+        # get a lighter subset with only the necessary columns
+        sample_dates = (
+            public_df_raw[["sample_id", "start_date", "end_date", "valid_date"]]
+            .drop_duplicates()
+            .reset_index(drop=True)
+        )
+
+        # save the true valid_date for later
+        true_valid_date_map = sample_dates.set_index("sample_id")["valid_date"]
+
+        # calculate the shifts and assign a new valid date
+        sample_dates["true_valid_date_month"] = sample_dates["valid_date"].dt.month
+        sample_dates["proposed_valid_date_month"] = processing_period_middle_month
+        sample_dates["valid_month_shift_forward"] = sample_dates.apply(
+            lambda xx: month_diff(
+                xx["true_valid_date_month"], xx["proposed_valid_date_month"]
+            ),
+            axis=1,
+        )
+        sample_dates["valid_month_shift_backward"] = sample_dates.apply(
+            lambda xx: month_diff(
+                xx["proposed_valid_date_month"], xx["true_valid_date_month"]
+            ),
+            axis=1,
+        )
+        sample_dates["proposed_valid_date"] = sample_dates.apply(
+            get_best_valid_date, axis=1
+        )
+
+        # remove samples that cannot be aligned
+        invalid_samples = sample_dates.loc[
+            sample_dates["proposed_valid_date"].isna(), "sample_id"
+        ].values
+        public_df_raw = public_df_raw[~public_df_raw["sample_id"].isin(invalid_samples)]
+        public_df_raw["valid_date"] = public_df_raw["sample_id"].map(
+            sample_dates.set_index("sample_id")["proposed_valid_date"]
+        )
+        logger.warning(
+            f"Removed {invalid_samples.shape[0]} samples that do not fit into the selected temporal extent."
+        )
+
     public_df = process_parquet_for_presto(public_df_raw)
+
+    if processing_period is not None:
+        # put back the true valid_date
+        public_df["valid_date"] = public_df.index.map(true_valid_date_map)
+        public_df["valid_date"] = public_df["valid_date"].astype(str)
+
     public_df = map_croptypes(public_df)
     logger.info(
         f"Extracted and processed {public_df.shape[0]} samples from global database."
@@ -202,9 +391,9 @@ def map_croptypes(
     df["ewoc_code"] = df["CROPTYPE_LABEL"].map(
         wc2ewoc_map.set_index("croptype")["ewoc_code"]
     )
-    df["landcover_name"] = df["ewoc_code"].map(ewoc_map["landcover_name"])
-    df["cropgroup_name"] = df["ewoc_code"].map(ewoc_map["cropgroup_name"])
-    df["croptype_name"] = df["ewoc_code"].map(ewoc_map["croptype_name"])
+    df["label_level1"] = df["ewoc_code"].map(ewoc_map["cropland_name"])
+    df["label_level2"] = df["ewoc_code"].map(ewoc_map["landcover_name"])
+    df["label_level3"] = df["ewoc_code"].map(ewoc_map["croptype_name"])
 
     df["downstream_class"] = df["ewoc_code"].map(
         {int(k): v for k, v in get_class_mappings()[downstream_classes].items()}
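
Usage note (not part of the diff): a minimal sketch of how the new `processing_period` parameter is meant to be called. It assumes `TemporalContext` accepts ISO start/end date strings, as implied by the `processing_period.to_datetime()` call above; the area of interest and buffer value are made up for illustration.

```python
from openeo_gfmap import TemporalContext
from shapely.geometry import box

from worldcereal.utils.refdata import query_public_extractions

# hypothetical area of interest in lon/lat (minx, miny, maxx, maxy)
bbox_poly = box(4.0, 50.5, 4.5, 51.0)

# align all samples with a fixed 12-month processing window; samples whose
# valid_date cannot be shifted into this window are dropped with a warning
processing_period = TemporalContext("2021-01-01", "2021-12-31")

public_df = query_public_extractions(
    bbox_poly,
    buffer=250000,  # meters; increase if the query returns no samples
    filter_cropland=True,
    processing_period=processing_period,
)
```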
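
Reviewer note (not part of the diff): a self-contained walk-through of the month-shift arithmetic used for temporal alignment. The function body mirrors `month_diff` above; the example months are invented.

```python
def month_diff(month1: int, month2: int) -> int:
    # circular distance in months, going forward from month1 to month2 (0-11)
    return month2 - month1 if month2 >= month1 else 12 - month1 + month2

# a sample's valid_date falls in October (10), while the user-defined
# processing window is centered on July (7)
true_month, proposed_month = 10, 7

shift_forward = month_diff(true_month, proposed_month)   # 9: Oct -> next Jul
shift_backward = month_diff(proposed_month, true_month)  # 3: Oct -> previous Jul

# both shifts land the valid_date on the proposed month
assert (true_month + shift_forward) % 12 == proposed_month % 12
assert (true_month - shift_backward) % 12 == proposed_month % 12
```

When both shifted dates fit within a sample's extraction period, `get_best_valid_date` prefers the backward shift unless it exceeds the forward shift by more than `MIN_EDGE_BUFFER` months; here 3 < 9, so the valid_date would move back to July.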
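
And the feasibility check itself, written out with assumed constants (`NUM_TIMESTEPS = 12`, `MIN_EDGE_BUFFER = 2`; the real values come from `presto.dataops`), continuing the example above:

```python
import pandas as pd

NUM_TIMESTEPS = 12   # assumed; actual value comes from presto.dataops
MIN_EDGE_BUFFER = 2  # assumed; actual value comes from presto.dataops

# sample extracted from 2020-01 to 2021-06, original valid_date in October 2020
start_date = pd.Timestamp("2020-01-01")
end_date = pd.Timestamp("2021-06-01")
valid_date = pd.Timestamp("2020-10-15")

# forward shift (9 months, October -> next July): the shifted window may extend
# NUM_TIMESTEPS // 2 - MIN_EDGE_BUFFER months past the new valid_date and must
# stay within [start_date, end_date]
temp_end = valid_date + pd.DateOffset(months=9 + NUM_TIMESTEPS // 2 - MIN_EDGE_BUFFER)
temp_start = temp_end - pd.DateOffset(months=NUM_TIMESTEPS)
print((temp_end <= end_date) and (temp_start >= start_date))
# False: 2021-11-15 > 2021-06-01, the forward-shifted window does not fit

# backward shift (3 months, October -> previous July) does fit
temp_start = valid_date - pd.DateOffset(months=3 + NUM_TIMESTEPS // 2 - MIN_EDGE_BUFFER)
temp_end = temp_start + pd.DateOffset(months=NUM_TIMESTEPS)
print((temp_end <= end_date) and (temp_start >= start_date))
# True: this sample's valid_date would be shifted back to 2020-07-15
```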