From 144b9c6e272719c42f6a8b53a4905eb04faecb55 Mon Sep 17 00:00:00 2001 From: Kristof Van Tricht Date: Mon, 14 Oct 2024 18:31:39 +0200 Subject: [PATCH 1/9] Default training pipeline added --- scripts/spark/train_catboost.py | 528 ++++++++++++++++++++++++++++ scripts/spark/train_catboost.sh | 45 +++ src/worldcereal/train/__init__.py | 549 ++++++++++++++++++++++++++++++ 3 files changed, 1122 insertions(+) create mode 100644 scripts/spark/train_catboost.py create mode 100644 scripts/spark/train_catboost.sh create mode 100644 src/worldcereal/train/__init__.py diff --git a/scripts/spark/train_catboost.py b/scripts/spark/train_catboost.py new file mode 100644 index 00000000..d0657231 --- /dev/null +++ b/scripts/spark/train_catboost.py @@ -0,0 +1,528 @@ +import json +import sys +from pathlib import Path + +import matplotlib.pyplot as plt +import numpy as np +import pandas as pd +from catboost import Pool +from loguru import logger +from sklearn.metrics import ConfusionMatrixDisplay, confusion_matrix + +from worldcereal.train import get_training_data + +# from worldcereal.classification.models import WorldCerealCatBoostModel +# from worldcereal.classification.weights import get_refid_weight +from worldcereal.utils.spark import get_spark_context + +MODELVERSION = "005-ft-cropland-logloss" + + +class Trainer: + def __init__(self, settings, modeldir, detector, modelversion): + self.settings = settings + self.bands = settings["bands"] + self.outputlabel = settings["outputlabel"] + self.modeldir = modeldir + self.model = None + self.detector = detector + self.minsamples = settings.get("minsamples", 3000) + self.modelversion = modelversion + + # gpu = True if spark else False + self.gpu = False + + # Create the model directory + Path(modeldir).mkdir(parents=True, exist_ok=True) + + # Create and save the config + self.create_config() + + # Input parameters + trainingfile = settings.get("trainingfile") + + # Log to a file as well + self.sink = logger.add( + modeldir / "logfile.log", + level="DEBUG", + mode="w", + ) + + logger.info("-" * 50) + logger.info("Initializing CatBoost trainer ...") + logger.info("-" * 50) + logger.info(f"Training file: {trainingfile}") + + # Load a test dataframe to derive some information + test_df = self._load_df_partial(trainingfile) + + # Get the list of features present in the training data + self.present_features = Trainer._get_training_df_features(test_df) + + def _train(self, sc, **kwargs): + """Function to train the base model""" + + # Setup the output directory + outputdir = self.modeldir + + # Check if all required features are present + self._check_features() + + # Setup the model + self.model = self._setup_model() + + # Get and check trainingdata + logger.info("Preparing training data ...") + cal_data, val_data, test_data = self._get_trainingdata( + outputdir, + settings=self.settings, + **kwargs, + ) + self._check_trainingdata(cal_data, val_data, outputdir) + + # Remap to string labels + label_mapping = self.settings["classes"] + cal_data["label"] = cal_data["output"].map(label_mapping) + val_data["label"] = val_data["output"].map(label_mapping) + test_data["label"] = test_data["output"].map(label_mapping) + + # Save processed data to disk for debugging + logger.info("Saving processed data ...") + cal_data.to_parquet(Path(outputdir) / "processed_calibration_df.parquet") + + # Save ref_id counts to config + self.config["ref_id_counts"] = {} + self.config["ref_id_counts"]["CAL"] = ( + cal_data["ref_id"].value_counts().to_dict() + ) + self.config["ref_id_counts"]["VAL"] = ( + val_data["ref_id"].value_counts().to_dict() + ) + self.config["ref_id_counts"]["TEST"] = ( + test_data["ref_id"].value_counts().to_dict() + ) + self.save_config() + + # Train the model. If on spark -> run training on executor with multiple cores + logger.info("Starting training ...") + + def _fit_helper(model, cal_data, val_data): + logger.info("Start training ...") + # Setup datapools for training + calibration_data, eval_data = self._setup_datapools(cal_data, val_data) + model.fit( + calibration_data, + eval_set=eval_data, + verbose=50, + ) + logger.info("Finished training ...") + + return model + + # Remove logger to file because otherwise + # we get serialization issues on spark + logger.remove(self.sink) + + if sc is None: + self.model = _fit_helper(self.model, cal_data, val_data) + else: + logger.info("Running training on executor ...") + cal_data_bc = sc.broadcast(cal_data) + val_data_bc = sc.broadcast(val_data) + rdd = sc.parallelize([0], numSlices=1) + self.model = rdd.map( + lambda _: _fit_helper(self.model, cal_data_bc.value, val_data_bc.value) + ).collect()[0] + + cal_data_bc.unpersist() + val_data_bc.unpersist() + + # Add the logger again + self.sink = logger.add( + self.modeldir / "logfile.log", + level="DEBUG", + ) + + # Save the model + modelname = f"PrestoDownstreamCatBoost_{self.detector}_v{self.modelversion}" + self.save_model(self.model, outputdir, modelname) + + # Test the model + self.evaluate(self.model, test_data, outputdir) + + # Plot feature importances + self._plot_feature_importance(self.model, outputdir) + + logger.success("Base model trained!") + + def train(self, sc=None, **kwargs): + # train model + self._train(sc, minsamples=self.minsamples, **kwargs) + + def _load_df(self, file): + df = pd.read_parquet(Path(file) / f"training_df_{self.outputlabel}.parquet") + + return df + + def _load_df_partial(self, infile, num_rows=100): + import pyarrow.dataset as ds + + dataset = ds.dataset(infile, format="parquet", partitioning="hive") + scanner = dataset.to_batches(batch_size=num_rows) + + # Extract rows from the scanner + rows = [] + for batch in scanner: + rows.append(batch.to_pandas()) + if ( + len(rows[0]) >= num_rows + ): # Stop once we reach the desired number of rows + break + df = pd.concat(rows) + + return df + + def evaluate(self, model, testdata, outdir, pattern=""): + from sklearn.metrics import ( + accuracy_score, + f1_score, + precision_score, + recall_score, + ) + + logger.info("Getting test results ...") + + # In test mode, all valid samples are equal + idxvalid = testdata["weight"] > 0 + inputs = testdata.loc[idxvalid, self.bands] + outputs = testdata.loc[idxvalid, "label"] + orig_outputs = testdata.loc[idxvalid, "orig_output"] + + # Run evaluation + predictions = model.predict(inputs) + + # Make sure predictions are now 1D + predictions = predictions.squeeze() + + # Convert labels to the same type + outputs = outputs.astype(str) + predictions = predictions.astype(str) + + # Make absolute confusion matrix + cm = confusion_matrix(outputs, predictions, labels=np.unique(outputs)) + disp = ConfusionMatrixDisplay(cm, display_labels=np.unique(outputs)) + _, ax = plt.subplots(figsize=(10, 10)) + disp.plot(ax=ax, cmap=plt.cm.Blues, colorbar=False) + plt.tight_layout() + plt.savefig(str(Path(outdir) / f"{pattern}CM_abs.png")) + plt.close() + + # Make relative confusion matrix + cm = confusion_matrix( + outputs, predictions, normalize="true", labels=np.unique(outputs) + ) + disp = ConfusionMatrixDisplay(cm, display_labels=np.unique(outputs)) + _, ax = plt.subplots(figsize=(10, 10)) + disp.plot(ax=ax, cmap=plt.cm.Blues, values_format=".1f", colorbar=False) + plt.tight_layout() + plt.savefig(str(Path(outdir) / f"{pattern}CM_norm.png")) + plt.close() + + # Compute evaluation metrics + metrics = {} + if len(np.unique(outputs)) == 2: + metrics["OA"] = np.round(accuracy_score(outputs, predictions), 3) + metrics["F1"] = np.round( + f1_score(outputs, predictions, pos_label=self.settings["classes"][1]), 3 + ) + metrics["Precision"] = np.round( + precision_score( + outputs, predictions, pos_label=self.settings["classes"][1] + ), + 3, + ) + metrics["Recall"] = np.round( + recall_score( + outputs, predictions, pos_label=self.settings["classes"][1] + ), + 3, + ) + else: + metrics["OA"] = np.round(accuracy_score(outputs, predictions), 3) + metrics["F1"] = np.round(f1_score(outputs, predictions, average="macro"), 3) + metrics["Precision"] = np.round( + precision_score(outputs, predictions, average="macro"), 3 + ) + metrics["Recall"] = np.round( + recall_score(outputs, predictions, average="macro"), 3 + ) + + # Write metrics to disk + with open(str(Path(outdir) / f"{pattern}metrics.txt"), "w") as f: + f.write("Test results:\n") + for key in metrics.keys(): + f.write(f"{key}: {metrics[key]}\n") + logger.info(f"{key} = {metrics[key]}") + + cm = confusion_matrix(outputs, predictions) + outputlabels = list(np.unique(outputs)) + predictlabels = list(np.unique(predictions)) + outputlabels.extend(predictlabels) + outputlabels = list(dict.fromkeys(outputlabels)) + outputlabels.sort() + cm_df = pd.DataFrame(data=cm, index=outputlabels, columns=outputlabels) + outfile = Path(outdir) / f"{pattern}confusion_matrix.txt" + cm_df.to_csv(outfile) + + datadict = { + "ori": orig_outputs.values, + "pred": predictions, + } + data = pd.DataFrame.from_dict(datadict) + count = data.groupby(["ori", "pred"]).size() + result = count.to_frame(name="count").reset_index() + outfile = Path(outdir) / f"{pattern}confusion_matrix_original_labels.txt" + result.to_csv(outfile, index=False) + + return metrics + + def save_model(self, model, outputdir, modelname): + # Both as cbm and onnx + model.save_model(Path(outputdir) / (modelname + ".cbm")) + model.save_model( + f"{Path(outputdir) / (modelname + '.onnx')}", + format="onnx", + export_parameters={ + "onnx_domain": "ai.catboost", + "onnx_model_version": 1, + "onnx_doc_string": f"Default {self.detector} model using CatBoost", + "onnx_graph_name": f"CatBoostModel_for_{self.detector}", + }, + ) + + def create_config(self): + import copy + + config = copy.deepcopy(self.settings) + config["trainingfile"] = str(config["trainingfile"]) + self.config = config + self.save_config() + + def save_config(self): + configpath = Path(self.modeldir) / "config.json" + with open(configpath, "w") as f: + json.dump(self.config, f, indent=4) + + @staticmethod + def _get_training_df_features(df): + present_features = df.columns.tolist() + + return present_features + + def _setup_model(self): + # Setup the model + from catboost import CatBoostClassifier + + logger.info("Setting up model ...") + + # Manually control class name order! + class_names = [ + self.settings["classes"][class_nr] + for class_nr in range(len(self.settings["classes"])) + ] + + model = CatBoostClassifier( + iterations=8000, + depth=8, + class_names=class_names, + random_seed=1234, + learning_rate=0.05, + early_stopping_rounds=50, + l2_leaf_reg=3, + eval_metric="Logloss", + train_dir=self.modeldir, + ) + + # Print a summary of the model + model_params = model.get_params() + model_params["train_dir"] = str(model_params["train_dir"]) + self.config["model_params"] = model_params + self.save_config() + logger.info(model_params) + + return model + + def _get_trainingdata(self, outputdir, minsamples=500, settings=None, **kwargs): + settings = self.settings if settings is None else settings + + # Get the data + cal_data, val_data, test_data = get_training_data( + self.detector, + settings, + self.bands, + logdir=outputdir, + minsamples=minsamples, + **kwargs, + ) + + return cal_data, val_data, test_data + + def _setup_datapools(self, cal_data, val_data): + # Setup dataset Pool + calibration_data = Pool( + data=cal_data[self.bands], + label=cal_data["label"], + weight=cal_data["weight"], + ) + eval_data = Pool( + data=val_data[self.bands], + label=val_data["label"], + weight=val_data["weight"], + ) + + return calibration_data, eval_data + + def _check_trainingdata(self, cal_data, val_data, outputdir): + # Run some checks + plt.hist(val_data[self.bands].values.ravel(), 100) + plt.savefig(Path(outputdir) / ("inputdist_val.png")) + plt.close() + plt.hist(cal_data[self.bands].values.ravel(), 100) + plt.savefig(Path(outputdir) / ("inputdist_cal.png")) + plt.close() + logger.info(f"Unique CAL outputs: {np.unique(cal_data['output'])}") + logger.info(f"Unique VAL outputs: {np.unique(val_data['output'])}") + logger.info(f"Unique CAL weights: {np.unique(cal_data['weight'])}") + logger.info(f"Unique VAL weights: {np.unique(val_data['weight'])}") + logger.info( + f"Mean Pos. weight: " + f"{np.mean(cal_data['weight'][cal_data['output'] == 1])}" + ) + logger.info( + f"Mean Neg. weight: " + f"{np.mean(cal_data['weight'][cal_data['output'] == 0])}" + ) + ratio_pos = np.sum(cal_data["output"] == 1) / cal_data["output"].size + logger.info(f"Ratio pos/neg outputs: {ratio_pos}") + + def _check_features(self): + present_features = [ft for ft in self.present_features] + + for band in self.bands: + if band not in present_features: + raise RuntimeError(f"Feature `{band}` not found in features.") + + def _plot_feature_importance(self, model, outputdir): + # Save feature importance plot + logger.info("Plotting feature importance ...") + ft_imp = model.get_feature_importance() + sorting = np.argsort(np.array(ft_imp))[::-1] + + f, ax = plt.subplots(1, 1, figsize=(20, 8)) + ax.bar(np.array(self.bands)[sorting], np.array(ft_imp)[sorting]) + ax.set_xticklabels(np.array(self.bands)[sorting], rotation=90) + plt.tight_layout() + plt.savefig(str(Path(outputdir) / "feature_importance.png")) + + @staticmethod + def write_log_df(metrics, aez, modelname, cal_data, outputdir, parentmetrics=None): + outfile = Path(outputdir) / "log_df.csv" + + nr_cal_samples = cal_data[0].shape[0] + logdata = { + "model": [modelname], + "aez": [aez], + "cal_samples": [nr_cal_samples], + "OA": [metrics["OA"]], + "OA_parent": [np.nan], + "F1": [metrics["F1"]], + "F1_parent": [np.nan], + "Precision": [metrics["Precision"]], + "Precision_parent": [np.nan], + "Recall": [metrics["Recall"]], + "Recall_parent": [np.nan], + } + + if parentmetrics is not None: + logdata["OA_parent"] = [parentmetrics["OA"]] + logdata["F1_parent"] = [parentmetrics["F1"]] + logdata["Precision_parent"] = [parentmetrics["Precision"]] + logdata["Recall_parent"] = [parentmetrics["Recall"]] + + log_df = pd.DataFrame.from_dict(logdata).set_index("model") + log_df.to_csv(outfile) + + @staticmethod + def load_log_df(outputdir): + outfile = Path(outputdir) / "log_df.csv" + if not outfile.is_file(): + raise FileNotFoundError(f"Logfile `{outfile}` not found.") + + log_df = pd.read_csv(outfile, index_col=0) + + return log_df + + +def main(detector, trainingsettings, outdir_base, MODELVERSION, sc=None): + # Plot without display + plt.switch_backend("Agg") + + logger.info(f'Training on bands: {trainingsettings["bands"]}') + + # Get path to output model directory + modeldir = Path(outdir_base) + + # Initialize trainer + trainer = Trainer(trainingsettings, modeldir, detector, MODELVERSION) + + # Train the model; + trainer.train(sc=sc) + + logger.success("Model trained!") + + +if __name__ == "__main__": + spark = False + localspark = False + + if spark: + logger.info("Setting up spark ...") + sc = get_spark_context(localspark=localspark) + else: + sc = None + + # Supress debug messages + logger.remove() + logger.add(sys.stderr, level="INFO") + + # Get the trainingsettings + BANDS_CROPLAND_PRESTO = [f"presto_ft_{i}" for i in range(128)] + trainingdir = Path( + "/vitodata/worldcereal/features/preprocessedinputs-monthly-nointerp" + ) + + trainingsettings = { + "trainingfile": trainingdir + / "training_df_presto-ss-wc-ft-ct_cropland_CROPLAND2_30D_random_time-token=none_balance=True_augment=True_presto-worldcereal.parquet", + "outputlabel": "LANDCOVER_LABEL", + "targetlabels": [11], + "ignorelabels": [10], + "focuslabels": [12, 13, 20, 30, 50, 999], + "focusmultiplier": 3, + "filter_worldcover": True, + "classes": {0: "other", 1: "cropland"}, + "bands": BANDS_CROPLAND_PRESTO, + "pos_neg_ratio": 0.45, + "minsamples": 500, + } + + # Output parameters + detector = "cropland" + outdir = ( + "/vitodata/worldcereal/models/" + f"PrestoDownstreamCatBoost/{detector}_detector_" + f"PrestoDownstreamCatBoost" + f"_v{MODELVERSION}" + ) + + main(detector, trainingsettings, outdir, MODELVERSION, sc=sc) diff --git a/scripts/spark/train_catboost.sh b/scripts/spark/train_catboost.sh new file mode 100644 index 00000000..3cc88024 --- /dev/null +++ b/scripts/spark/train_catboost.sh @@ -0,0 +1,45 @@ +#!/usr/bin/env bash +# shellcheck disable=SC2140 + +export SPARK_HOME=/opt/spark3_2_0/ +export PATH="$SPARK_HOME/bin:$PATH" +export PYTHONPATH=wczip + +cd src || exit +zip -r ../dist/worldcereal.zip worldcereal +cd .. + +EX_JAVAMEM='8g' +EX_PYTHONMEM='16g' +DR_JAVAMEM='8g' +DR_PYTHONMEM='16g' + +PYSPARK_PYTHON=./ewocenv/bin/python \ +${SPARK_HOME}/bin/spark-submit \ +--conf spark.yarn.appMasterEnv.PYSPARK_PYTHON="./ewocenv/bin/python" \ +--conf spark.yarn.appMasterEnv.PYSPARK_DRIVER_PYTHON="./ewocenv/bin/python" \ +--conf spark.executorEnv.LD_LIBRARY_PATH="./ewocenv/lib" \ +--conf spark.yarn.appMasterEnv.LD_LIBRARY_PATH="./ewocenv/lib" \ +--conf spark.yarn.appMasterEnv.PYTHONPATH=$PYTHONPATH \ +--conf spark.executorEnv.PYTHONPATH=$PYTHONPATH \ +--executor-memory ${EX_JAVAMEM} --driver-memory ${DR_JAVAMEM} \ +--conf spark.yarn.appMasterEnv.PYTHON_EGG_CACHE=./ \ +--conf spark.executorEnv.GDAL_CACHEMAX=128 \ +--conf spark.yarn.appMasterEnv.XDG_CACHE_HOME=.cache \ +--conf spark.executorEnv.XDG_CACHE_HOME=.cache \ +--conf spark.rpc.message.maxSize=1024 \ +--conf spark.speculation=false \ +--conf spark.executor.instances=1 \ +--conf spark.driver.cores=16 \ +--conf spark.executor.cores=16 \ +--conf spark.task.cpus=16 \ +--conf spark.sql.broadcastTimeout=500000 \ +--conf spark.driver.memoryOverhead=${DR_PYTHONMEM} --conf spark.executor.memoryOverhead=${EX_PYTHONMEM} \ +--conf spark.memory.fraction=0.2 \ +--conf spark.shuffle.service.enabled=false --conf spark.dynamicAllocation.enabled=false \ +--conf spark.yarn.am.waitTime=500s \ +--conf spark.driver.maxResultSize=0 \ +--master yarn --deploy-mode cluster --queue default \ +--conf spark.app.name="worldcereal-trainmodels" \ +--archives "dist/worldcereal.zip#wczip","hdfs:///tapdata/worldcereal/worldcereal_python38.tar.gz#ewocenv" \ +scripts/spark/train_catboost.py \ diff --git a/src/worldcereal/train/__init__.py b/src/worldcereal/train/__init__.py new file mode 100644 index 00000000..b7fa2389 --- /dev/null +++ b/src/worldcereal/train/__init__.py @@ -0,0 +1,549 @@ +from pathlib import Path + +import numpy as np +import pandas as pd +from loguru import logger +from sklearn.model_selection import train_test_split +from sklearn.utils import class_weight + + +class NotEnoughSamplesError(Exception): + pass + + +def filter_perennial(df): + """Label 12: + CT label + 7900 -> LANDCOVER_LABEL 10 + 7910 -> LANDCOVER_LABEL 10 + 7920 -> LANDCOVER_LABEL 10 + 9520 -> LANDCOVER_LABEL 10 + """ + + nr_perennial_before = (df["LANDCOVER_LABEL"] == 12).sum() + logger.info(f"Perennial cropland samples before filtering: {nr_perennial_before}") + + df.loc[df["CROPTYPE_LABEL"].isin([7900, 7910, 7920, 9520]), "LANDCOVER_LABEL"] = 10 + + nr_perennial_after = (df["LANDCOVER_LABEL"] == 12).sum() + + logger.info( + ( + "Perennial cropland samples switched to annual " + f"cropland (LANDCOVER_LABEL 10): {nr_perennial_before - nr_perennial_after}" + ) + ) + + return df + + +def filter_cropland_11(df): + """Helper function to remove outliers from cropland class""" + + nr_cropland = (df["LANDCOVER_LABEL"] == 11).sum() + + logger.info(f"Annual cropland samples before filtering: {nr_cropland}") + + ignore_types = [9120, 9110, 7910, 7920, 7900, 9100] + + # Cropland should not have crop type from ignore list + df = df[ + ~((df["LANDCOVER_LABEL"] == 11) & (df["CROPTYPE_LABEL"].isin(ignore_types))) + ] + + nr_cropland = (df["LANDCOVER_LABEL"] == 11).sum() + + logger.info(f"Annual cropland samples after filtering: {nr_cropland}") + + return df + + +def filter_croptype(df): + """Helper function to remove outliers for croptype processors""" + + nr_samples = df.shape[0] + logger.info(f"Crop type samples before filtering: {nr_samples}") + + # Cropland should not have corner reflector signal in SAR + df = df[~(df["SAR-VV-p90-20m"] > 0)] + + # Cropland should not have a NDVI p90 which is very low + df = df[~(df["OPTICAL-ndvi-p90-10m"] < 0.4)] + + nr_samples = df.shape[0] + + logger.info(f"Crop type samples after filtering: {nr_samples}") + + return df + + +def filter_grassland(df): + + nr_grassland = (df["LANDCOVER_LABEL"] == 13).sum() + + logger.info(f"Grassland samples before filtering: {nr_grassland}") + + # Filter out grassland which is cropland according to worldcover + # AND potapov + df = df[ + ~( + (df["LANDCOVER_LABEL"] == 13) + & (df["WORLDCOVER-LABEL-10m"] == 40) + & (df["POTAPOV-LABEL-10m"] == 1) + ) + ] + + nr_grassland = (df["LANDCOVER_LABEL"] == 13).sum() + + logger.info(f"Grassland samples after filtering: {nr_grassland}") + + return df + + +def filter_ewoco(df): + + nr_crop = (df["LANDCOVER_LABEL"] == 11).sum() + + logger.info(f"Cropland samples before EWOCO filtering: {nr_crop}") + + # Filter out crop from ewoco + # AND potapov + df = df[~((df["LANDCOVER_LABEL"] == 11) & (df["ref_id"].str.contains("ewoco")))] + + nr_crop = (df["LANDCOVER_LABEL"] == 11).sum() + + logger.info(f"Cropland samples after EWOCO filtering: {nr_crop}") + + return df + + +def filter_otherlayers(df, detector): + """Method to filter out training labels + that do not correspond to some other values in + other land cover layers + """ + + label = "LANDCOVER_LABEL" + + # 01 - remove pixels that are urban according to worldcover + agri_labels = [10, 11, 12] + ewoc_ignore = [50] + if "annual" in detector or "cropland" in detector: + remove_idx = (df[label].isin(agri_labels)) & ( + df["WORLDCOVER-LABEL-10m"].isin(ewoc_ignore) + ) + else: # crop type case + remove_idx = df["WORLDCOVER-LABEL-10m"].isin(ewoc_ignore) + + df = df[~remove_idx].copy() + logger.info( + f"Removed {remove_idx.sum()} urban samples" + f" according to worldcover information." + ) + + # 02 - remove pixels that are not crop [10, 11] according to BOTH + # WorldCover and Potapov crop layer + remove_idx = ( + (df[label].isin([10, 11])) + & (df["WORLDCOVER-LABEL-10m"] != 40) + & (df["POTAPOV-LABEL-10m"] != 1) + ) + df = df[~remove_idx].copy() + logger.info( + f"Removed {remove_idx.sum()} crop samples" + " that are not crop according " + "to worldcover and potapov." + ) + + # 03 - remove pixels that are crop according to BOTH + # WorldCover and Potapov crop layer but not in the label + remove_idx = ( + (~df[label].isin([10, 11, 12])) + & (df["WORLDCOVER-LABEL-10m"] == 40) + & (df["POTAPOV-LABEL-10m"] == 1) + ) + df = df[~remove_idx].copy() + logger.info( + f"Removed {remove_idx.sum()} non-crop samples" + " that are crop according " + "to both worldcover and potapov." + ) + + +def get_sample_weight(sample, outputlabel, options): + + # Default weight + if sample[outputlabel] == 0: + weight = 0 + elif sample[outputlabel] in options.get("ignorelabels", []): + weight = 0 + else: + weight = 1 + + # Multiply focuslabel weight with provided multiplier + if sample[outputlabel] in options.get("focuslabels", []): + weight *= options.get("focusmultiplier", 1) + + return float(weight) + + +def process_training_data( + df, + inputfeatures, + detector, + options, + minsamples=1000, + filter_worldcover=False, + logdir=None, + outputlabel="LANDCOVER_LABEL", +): + """Function that returns inputs/outputs from training DataFrame + + Args: + df (pd.DataFrame): dataframe containing input/output features + inputfeatures (list[str]): list of inputfeatures to be used + detector (str): detector name + options (dict): dictionary containing options + minsamples (int, optional): minimum number of samples. Defaults to 500. + logdir (str, optional): output path for logs/figures. Defaults to None. + filter_worldcover (bool, optional): whether or not to + remove outliers based on worldcover data. Defaults to False. + + Raises: + NotEnoughSamplesError: obviously when not enough samples were found + + Returns: + inputs, outputs, weights: arrays to use in training + """ + + # ---------------------------------------------------------------------- + # PART II: Check sample sizes and remove labels to be ignored + + # Check if we have enough samples at all to start with + # that's at least 2X minsamples (binary classification) + if df.shape[0] < 2 * minsamples: + errormessage = ( + f"Got less than {2 * minsamples} " + f"in total for this dataset. " + "Cannot continue!" + ) + logger.error(errormessage) + raise NotEnoughSamplesError(errormessage) + + # Remove the unknown and ignore labels + if len(options.get("ignorelabels", [])) > 0: + remove_idx = (df[outputlabel] == 0) | ( + df[outputlabel].isin(options["ignorelabels"]) + ) + else: + remove_idx = df[outputlabel] == 0 + + df = df[~remove_idx].copy() + logger.info(f"Removed {remove_idx.sum()} unknown/ignore samples.") + + # ---------------------------------------------------------------------- + # PART IV: Apply various thematic filters + + if "annual" not in detector and "cropland" not in detector: + # If not looking at cropland + # need to get rid of samples that are not part of cropland or grassland + # We include grassland to make model aware of what grass looks like + # in case there is grass commission inside cropland product. + logger.info(("Removing samples that are not cropland or grassland")) + df = df[df["LANDCOVER_LABEL"].isin([10, 11, 12, 13])].copy() + + logger.info(f'Unique LANDCOVER_LABEL values: {df["LANDCOVER_LABEL"].unique()}') + + if "annual" in detector or "cropland" in detector: + # Rule 1: filter out dirty perennials + # NOTE: NEEDS TO GO FIRST ALWAYS! + df = filter_perennial(df) + + # Rule 2: filter the unkwown cropland + # df = filter_cropland_10(df) + + # # Rule 3: filter the annual cropland + df = filter_cropland_11(df) + + # Rule 4: filter the grassland + df = filter_grassland(df) + + # Rule 5: remove crop from ewoco + # df = filter_ewoco(df) + + else: + # Filter out obvious no-crop for the crop type processors + df = filter_croptype(df) + + # Apply some filters based on worldcover and potapov layers + if filter_worldcover: + filter_otherlayers(df, detector) + + # Remove corrupt rows + remove_idx = ((df.isnull())).sum(axis=1) > int(df.shape[1] * 0.75) + df = df[~remove_idx].copy() + + # do intermediate check of number of samples + if df.shape[0] < 2 * minsamples: + errormessage = ( + f"Got less than {2 * minsamples} " + f"in total for this dataset. " + "Cannot continue!" + ) + logger.error(errormessage) + raise NotEnoughSamplesError(errormessage) + + # ---------------------------------------------------------------------- + # PART VII: Select features we need + required_columns = list( + set(inputfeatures + [outputlabel, "ref_id", "location_id", "sample_id"]) + ) + df = df[required_columns] + + # ---------------------------------------------------------------------- + # PART VIII: + if "annual" in detector or "cropland" in detector: + # Remove rows with NaN + beforenan = df.shape[0] + df = df.dropna() + afternan = df.shape[0] + logger.info(f"Removed {beforenan - afternan} samples with NaN values.") + + # ---------------------------------------------------------------------- + # PART IX: Compute weights + # This happens in two parts: + # 1. Determine class weights in order to reach requested ratio + # 2. Adjust sample-specific weights based on label + + # Compute class weight to get balanced samples + binaryoutputs = df[outputlabel].copy() + binaryoutputs[binaryoutputs.isin(options["targetlabels"])] = 1 + binaryoutputs[binaryoutputs != 1] = 0 + + # check whether we still have samples from each class + if len(np.unique(binaryoutputs)) == 1: + singleclass = int(np.unique(binaryoutputs)[0]) + # Only one class present, no use to continue + raise NotEnoughSamplesError(f"Only class {singleclass} found, aborting!") + + # Compute class weights that would balance the two classes + balanced_classweight = class_weight.compute_class_weight( + class_weight="balanced", classes=np.array([0, 1]), y=binaryoutputs.values + ) + + # Adjust balanced weight according to requested ratio + pos_neg_ratio = options.get("pos_neg_ratio", 0.5) + logger.info(f"Using pos-neg-ratio of {pos_neg_ratio}") + negative_classweight = (1 - pos_neg_ratio) / 0.5 * balanced_classweight[0] + positive_classweight = pos_neg_ratio / 0.5 * balanced_classweight[1] + + # Clamp max weight to avoid excesses + MAX_WEIGHT = 10 + positive_classweight = min(positive_classweight, MAX_WEIGHT) + negative_classweight = min(negative_classweight, MAX_WEIGHT) + + weights = np.ones((binaryoutputs.shape[0])) + weights[binaryoutputs.values == 0] = negative_classweight + weights[binaryoutputs.values == 1] = positive_classweight + df["sampleweight"] = weights + + # Get the sample-specific weights + sample_weights = df.apply( + lambda row: get_sample_weight(row, outputlabel, options), axis=1 + ).values + + # Adjust sample weights by the class weights and assign as final weights + sample_weights *= df["sampleweight"].values + df["sampleweight"] = sample_weights + + # Balancing by ref_id + logger.info("Balancing for ref_ids ...") + + ref_id_classweights = class_weight.compute_class_weight( + class_weight="balanced", classes=np.unique(df["ref_id"]), y=df["ref_id"] + ) + ref_id_classweights = { + k: v for k, v in zip(np.unique(np.unique(df["ref_id"])), ref_id_classweights) + } + for ref_id in ref_id_classweights.keys(): + ref_id_classweights[ref_id] = min(ref_id_classweights[ref_id], MAX_WEIGHT) + df.loc[df["ref_id"] == ref_id, "sampleweight"] *= ref_id_classweights[ref_id] + + # ---------------------------------------------------------------------- + # PART XI: Log various things on the situation as of now + + # Log the number of samples still present + # per ref_id + ref_id_counts = ( + df.groupby("ref_id")[outputlabel].value_counts().unstack().fillna(0).astype(int) + ) + if logdir is not None: + if not (Path(logdir) / "sample_counts.csv").is_file(): + logger.info("Saving sample counts ...") + ref_id_counts.to_csv(Path(logdir) / "sample_counts.csv") + ref_id_counts = ref_id_counts.sum(axis=1).to_dict() + + if logdir is not None: + import matplotlib.pyplot as plt + + if not (Path(logdir) / "output_distribution.png").is_file(): + # Plot histogram of original outputs + outputs = df[outputlabel].copy() + counts = outputs.value_counts() + labels = counts.index.astype(int) + plt.bar(range(len(labels)), counts.values) + plt.xticks(range(len(labels)), labels, rotation=90) + plt.xlabel("Class") + plt.ylabel("Amounts") + plt.title("Output label distribution") + outfile = Path(logdir) / "output_distribution.png" + outfile.parent.mkdir(exist_ok=True) + plt.savefig(outfile) + plt.close() + + # ---------------------------------------------------------------------- + # PART XII: Extract input/output as numpy arrays, binarize the outputs + # and do one more check if we still have enough samples + + # Get the inputs + inputs = df[inputfeatures].values + + # Get the outputs AFTER the weights + outputs = df[outputlabel].copy() + + # Get the final sample weights + weights = df["sampleweight"].values + + # Get the output labels + origoutputs = np.copy(outputs.values) # For use in evaluation + + # Make classification binary + outputs[outputs.isin(options["targetlabels"])] = 1 + outputs[outputs != 1] = 0 + outputs = outputs.values + + # Now check if we have enough samples to proceed + for label in [0, 1]: + if np.sum(outputs == label) < minsamples: + errormessage = ( + f"Got less than {minsamples} " + f"`{label}` samples for this dataset. " + "Cannot continue!" + ) + logger.error(errormessage) + raise NotEnoughSamplesError(errormessage) + + # Log how many all-zero inputs (these are bad) + idx = np.where(np.sum(inputs, axis=1) == 0) + logger.info(f"#Rows with all-zero inputs: {len(idx[0])}") + + # ---------------------------------------------------------------------- + # PART XIII: Postprocessing before returning the results + + logger.info("Transforming inputs to pandas.DataFrame ...") + data = pd.DataFrame(data=inputs, columns=inputfeatures) + data["output"] = outputs + data["weight"] = weights + data["orig_output"] = origoutputs + data["ref_id"] = df["ref_id"].values + data["location_id"] = df["location_id"].values + data["sample_id"] = df["sample_id"].values + + # ---------------------------------------------------------------------- + # PART XIV: Shuffle the data + logger.info("Shuffling data ...") + data = data.sample(frac=1) + + return data + + +def get_training_data( + detector, + options, + inputfeatures, + minsamples=500, + **kwargs, +): + + dfs = ( + options["trainingfile"] + if isinstance(options["trainingfile"], list) + else [options["trainingfile"]] + ) + + df_data = pd.DataFrame() + + for current_df in dfs: + df_data = pd.concat([df_data, pd.read_parquet(current_df)]) + + # For annual cropland, need to throw out some unreliable datasets + if "cropland" in detector: + ignore_list = [ + "2017_", + "2018_AF", + "2018_SSD_WFP", + "2018_TZ_AFSIS", + "2018_TZ_RadiantEarth", + "2019_AF_OAF", + "2019_KEN_WAPOR", + "2019_TZA_CIMMYT", + "2019_TZA_OAF", + "2019_TZ_AFSIS", + "2020_RW_WAPOR-Muvu", + "2018_ES_SIGPAC-Andalucia", + "2019_ES_SIGPAC-Andalucia", + "2021_MOZ_WFP", + "2021_TZA_COPERNICUS-GEOGLAM", + ] + else: + ignore_list = [ + "2021_TZA_COPERNICUS-GEOGLAM", + "2021_UKR_sunflowermap", + "2021_EUR_EXTRACROPS", + "2017_", + ] + + logger.warning(f"Samples before removing refids: {df_data.shape[0]}") + for ignore in ignore_list: + df_data = df_data.loc[~df_data["ref_id"].str.contains(ignore)] + logger.warning(f"Samples after removing refids: {df_data.shape[0]}") + + # Get the training data using all provided options + data = process_training_data( + df_data, + inputfeatures, + detector, + options, + filter_worldcover=options.get("filter_worldcover", False), + minsamples=minsamples, + **kwargs, + ) + + # Train/test splitting should happen on location_id as we don't want + # a mix of (augmented) samples from the same location ending up in + # both training and validation/test sets + + # Step 1: Split the dataset into train + validation and test sets + samples_train, samples_test = train_test_split( + list(data["location_id"].unique()), + test_size=0.2, + random_state=42, + ) + + # Step 2: Further split the train + validation set into separate train and validation sets + samples_val, samples_test = train_test_split( + samples_test, + test_size=0.5, + random_state=42, + ) + + # Get the actual data using the splitted location_ids + data_train = data.set_index("location_id").loc[samples_train].reset_index() + data_val = data.set_index("location_id").loc[samples_val].reset_index() + data_test = data.set_index("location_id").loc[samples_test].reset_index() + + logger.info(f"Training on {data_train.shape[0]} samples.") + logger.info(f"Validating on {data_val.shape[0]} samples.") + logger.info(f"Testing on {data_test.shape[0]} samples.") + + return data_train, data_val, data_test From dc6052ffc931ebbf641099160d2a52703ae65b46 Mon Sep 17 00:00:00 2001 From: Kristof Van Tricht <66418796+kvantricht@users.noreply.github.com> Date: Tue, 15 Oct 2024 09:13:19 +0200 Subject: [PATCH 2/9] Updated comment --- src/worldcereal/train/__init__.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/worldcereal/train/__init__.py b/src/worldcereal/train/__init__.py index b7fa2389..6c0c02a6 100644 --- a/src/worldcereal/train/__init__.py +++ b/src/worldcereal/train/__init__.py @@ -476,7 +476,7 @@ def get_training_data( for current_df in dfs: df_data = pd.concat([df_data, pd.read_parquet(current_df)]) - # For annual cropland, need to throw out some unreliable datasets + # For cropland/croptype, we have to exclude some ref_ids if "cropland" in detector: ignore_list = [ "2017_", From dbf7a54d5ba57259e465cf9074b1388860f2fc6a Mon Sep 17 00:00:00 2001 From: Kristof Van Tricht Date: Tue, 15 Oct 2024 20:07:12 +0200 Subject: [PATCH 3/9] Allow to load model from file --- src/worldcereal/utils/models.py | 17 +++++++++++------ 1 file changed, 11 insertions(+), 6 deletions(-) diff --git a/src/worldcereal/utils/models.py b/src/worldcereal/utils/models.py index 14faa5be..dbd4e28c 100644 --- a/src/worldcereal/utils/models.py +++ b/src/worldcereal/utils/models.py @@ -2,19 +2,21 @@ import json from functools import lru_cache +from pathlib import Path +from typing import Union import onnxruntime as ort import requests @lru_cache(maxsize=2) -def load_model_onnx(model_url) -> ort.InferenceSession: - """Load an ONNX model from a URL. +def load_model_onnx(model_path: Union[str, Path]) -> ort.InferenceSession: + """Load an ONNX model from a file or URL. Parameters ---------- - model_url: str - URL to the ONNX model. + model_path: Union[str, Path] + path to the ONNX model, either a local path or a public URL. Returns ------- @@ -22,8 +24,11 @@ def load_model_onnx(model_url) -> ort.InferenceSession: ONNX model loaded with ONNX runtime. """ # Two minutes timeout to download the model - response = requests.get(model_url, timeout=120) - model = response.content + if str(model_path).startswith("http"): + response = requests.get(str(model_path), timeout=120) + model = response.content + else: + model = str(model_path) return ort.InferenceSession(model) From 4c7cf8416926c5923778c175f75479698c9de239 Mon Sep 17 00:00:00 2001 From: Kristof Van Tricht Date: Tue, 15 Oct 2024 20:08:04 +0200 Subject: [PATCH 4/9] Add masking to compute presto embeddings --- src/worldcereal/train/data.py | 156 +++++++++++++++++++++++++++++++++- 1 file changed, 153 insertions(+), 3 deletions(-) diff --git a/src/worldcereal/train/data.py b/src/worldcereal/train/data.py index c84d555d..1c23e639 100644 --- a/src/worldcereal/train/data.py +++ b/src/worldcereal/train/data.py @@ -1,16 +1,166 @@ -from typing import List +from dataclasses import dataclass +from random import choice, randint, random, sample +from typing import Any, List, Tuple +import numpy as np import pandas as pd import torch from loguru import logger +from presto.dataops import BANDS_GROUPS_IDX, NUM_TIMESTEPS, SRTM_INDEX, TIMESTEPS_IDX from presto.dataset import WorldCerealLabelledDataset -from presto.masking import MaskParamsNoDw +from presto.masking import BAND_EXPANSION, MASK_STRATEGIES from presto.presto import Presto from presto.utils import device from torch.utils.data import DataLoader from tqdm import tqdm +def make_mask_no_dw( + strategy: str, + mask_ratio: float, + existing_mask: np.ndarray, + num_timesteps: int = NUM_TIMESTEPS, +) -> np.ndarray: + """ + Make a mask for a given strategy and percentage of masked values. + Args: + strategy: The masking strategy to use. One of MASK_STRATEGIES + mask_ratio: The percentage of values to mask. Between 0 and 1. + """ + # we assume that topography is never "naturally" masked + mask = existing_mask.copy() + + srtm_mask = False + + if mask_ratio > 0.05: + actual_mask_ratio = np.random.uniform(0.05, mask_ratio) + else: + actual_mask_ratio = mask_ratio + + num_tokens_to_mask = int( + ((num_timesteps * (len(BANDS_GROUPS_IDX) - 1)) + 1) * actual_mask_ratio + ) + assert num_tokens_to_mask > 0, f"num_tokens_to_mask: {num_tokens_to_mask}" + + def mask_topography(srtm_mask, num_tokens_to_mask, actual_mask_ratio): + should_flip = random() < actual_mask_ratio + if should_flip: + srtm_mask = True + num_tokens_to_mask -= 1 + return srtm_mask, num_tokens_to_mask + + def random_masking(mask, num_tokens_to_mask: int): + if num_tokens_to_mask > 0: + # we set SRTM to be True - this way, it won't get randomly assigned. + # at the end of the function, it gets properly assigned + mask[:, SRTM_INDEX] = True + # then, we flatten the mask and dw arrays + all_tokens_mask = mask.flatten() + unmasked_tokens = all_tokens_mask == 0 + # unmasked_tokens = all_tokens_mask == False + idx = np.flatnonzero(unmasked_tokens) + np.random.shuffle(idx) + idx = idx[:num_tokens_to_mask] + all_tokens_mask[idx] = True + mask = all_tokens_mask.reshape((num_timesteps, len(BANDS_GROUPS_IDX))) + return mask + + # RANDOM BANDS + if strategy == "random_combinations": + srtm_mask, num_tokens_to_mask = mask_topography( + srtm_mask, num_tokens_to_mask, actual_mask_ratio + ) + mask = random_masking(mask, num_tokens_to_mask) + + elif strategy == "group_bands": + srtm_mask, num_tokens_to_mask = mask_topography( + srtm_mask, num_tokens_to_mask, actual_mask_ratio + ) + # next, we figure out how many tokens we can mask + num_band_groups_to_mask = int(num_tokens_to_mask / num_timesteps) + assert (num_tokens_to_mask - num_timesteps * num_band_groups_to_mask) >= 0 + num_tokens_masked = 0 + # tuple because of mypy, which thinks lists can only hold one type + band_groups: List[Any] = list(range(len(BANDS_GROUPS_IDX))) + band_groups.remove(SRTM_INDEX) + band_groups_to_mask = sample(band_groups, num_band_groups_to_mask) + for band_group in band_groups_to_mask: + num_tokens_masked += int( + len(mask[:, band_group]) - sum(mask[:, band_group]) + ) + mask[:, band_group] = True + num_tokens_to_mask -= num_tokens_masked + mask = random_masking(mask, num_tokens_to_mask) + + # RANDOM TIMESTEPS + elif strategy == "random_timesteps": + srtm_mask, num_tokens_to_mask = mask_topography( + srtm_mask, num_tokens_to_mask, actual_mask_ratio + ) + # -1 for SRTM + timesteps_to_mask = int(num_tokens_to_mask / (len(BANDS_GROUPS_IDX) - 1)) + max_tokens_masked = (len(BANDS_GROUPS_IDX) - 1) * timesteps_to_mask + timesteps = sample(TIMESTEPS_IDX, k=timesteps_to_mask) + if timesteps_to_mask > 0: + num_tokens_to_mask -= int(max_tokens_masked - sum(sum(mask[timesteps]))) + mask[timesteps] = True + mask = random_masking(mask, num_tokens_to_mask) + elif strategy == "chunk_timesteps": + srtm_mask, num_tokens_to_mask = mask_topography( + srtm_mask, num_tokens_to_mask, actual_mask_ratio + ) + # -1 for SRTM + timesteps_to_mask = int(num_tokens_to_mask / (len(BANDS_GROUPS_IDX) - 1)) + if timesteps_to_mask > 0: + max_tokens_masked = (len(BANDS_GROUPS_IDX) - 1) * timesteps_to_mask + start_idx = randint(0, num_timesteps - timesteps_to_mask) + num_tokens_to_mask -= int( + max_tokens_masked + - sum(sum(mask[start_idx : start_idx + timesteps_to_mask])) + ) + mask[start_idx : start_idx + timesteps_to_mask] = True # noqa + mask = random_masking(mask, num_tokens_to_mask) + else: + raise ValueError(f"Unknown strategy {strategy} not in {MASK_STRATEGIES}") + + mask[:, SRTM_INDEX] = srtm_mask + return np.repeat(mask, BAND_EXPANSION, axis=1) + + +@dataclass +class MaskParamsNoDw: + strategies: Tuple[str, ...] = ("NDVI",) + ratio: float = 0.5 + num_timesteps: int = NUM_TIMESTEPS + + def __post_init__(self): + for strategy in self.strategies: + assert strategy in [ + "group_bands", + "random_timesteps", + "chunk_timesteps", + "random_combinations", + ] + + def mask_data( + self, eo_data: np.ndarray, mask: np.ndarray, num_timesteps: int = NUM_TIMESTEPS + ): + strategy = choice(self.strategies) + + mask = make_mask_no_dw( + strategy=strategy, + mask_ratio=self.ratio, + existing_mask=mask, + num_timesteps=num_timesteps, + ) + + x = eo_data * ~mask + y = np.zeros(eo_data.shape).astype(np.float32) + y[mask] = eo_data[mask] + + return mask, x, y, strategy + + class WorldCerealTrainingDataset(WorldCerealLabelledDataset): FILTER_LABELS = [0] @@ -42,7 +192,7 @@ def __init__( "group_bands", "random_timesteps", "chunk_timesteps", - "random_combinations", + # "random_combinations", ), mask_ratio, ) From bc27b1e097e9ee980c1326c54fe98fb5bf9efc4b Mon Sep 17 00:00:00 2001 From: Kristof Van Tricht Date: Tue, 15 Oct 2024 20:08:51 +0200 Subject: [PATCH 5/9] Dont do weighting based on ref_id --- src/worldcereal/train/__init__.py | 15 ++------------- 1 file changed, 2 insertions(+), 13 deletions(-) diff --git a/src/worldcereal/train/__init__.py b/src/worldcereal/train/__init__.py index 6c0c02a6..89c8f0ed 100644 --- a/src/worldcereal/train/__init__.py +++ b/src/worldcereal/train/__init__.py @@ -315,6 +315,8 @@ def process_training_data( # 1. Determine class weights in order to reach requested ratio # 2. Adjust sample-specific weights based on label + # df["sampleweight"] = 1 + # Compute class weight to get balanced samples binaryoutputs = df[outputlabel].copy() binaryoutputs[binaryoutputs.isin(options["targetlabels"])] = 1 @@ -356,19 +358,6 @@ def process_training_data( sample_weights *= df["sampleweight"].values df["sampleweight"] = sample_weights - # Balancing by ref_id - logger.info("Balancing for ref_ids ...") - - ref_id_classweights = class_weight.compute_class_weight( - class_weight="balanced", classes=np.unique(df["ref_id"]), y=df["ref_id"] - ) - ref_id_classweights = { - k: v for k, v in zip(np.unique(np.unique(df["ref_id"])), ref_id_classweights) - } - for ref_id in ref_id_classweights.keys(): - ref_id_classweights[ref_id] = min(ref_id_classweights[ref_id], MAX_WEIGHT) - df.loc[df["ref_id"] == ref_id, "sampleweight"] *= ref_id_classweights[ref_id] - # ---------------------------------------------------------------------- # PART XI: Log various things on the situation as of now From a17c43d9f8d642496bd6d436f5cd3f25bbecc8c3 Mon Sep 17 00:00:00 2001 From: Kristof Van Tricht Date: Tue, 15 Oct 2024 20:09:11 +0200 Subject: [PATCH 6/9] Add masking to compute embeddings --- scripts/spark/compute_presto_features.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/scripts/spark/compute_presto_features.py b/scripts/spark/compute_presto_features.py index 88320c79..97e6cf14 100644 --- a/scripts/spark/compute_presto_features.py +++ b/scripts/spark/compute_presto_features.py @@ -170,13 +170,14 @@ def main( debug = False exclude_meteo = False sample_repeats = 1 + mask_ratio = 0.5 valid_date_as_token = False presto_dir = Path("/vitodata/worldcereal/presto/finetuning") presto_model = ( presto_dir / "presto-ss-wc-ft-ct_cropland_CROPLAND2_30D_random_time-token=none_balance=True_augment=True.pt" ) - identifier = "" + identifier = f"-maxmaskratio{mask_ratio}" if spark: from worldcereal.utils.spark import get_spark_context @@ -207,6 +208,7 @@ def main( sc=sc, debug=debug, sample_repeats=sample_repeats, + mask_ratio=mask_ratio, valid_date_as_token=valid_date_as_token, exclude_meteo=exclude_meteo, ) From f38a9a0f941d7ed90b3bbcff6a9d91447e55d965 Mon Sep 17 00:00:00 2001 From: Kristof Van Tricht Date: Tue, 15 Oct 2024 20:35:02 +0200 Subject: [PATCH 7/9] Revert arg name to avoid issues --- src/worldcereal/utils/models.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/src/worldcereal/utils/models.py b/src/worldcereal/utils/models.py index dbd4e28c..9c4f7b96 100644 --- a/src/worldcereal/utils/models.py +++ b/src/worldcereal/utils/models.py @@ -10,12 +10,12 @@ @lru_cache(maxsize=2) -def load_model_onnx(model_path: Union[str, Path]) -> ort.InferenceSession: +def load_model_onnx(model_url: Union[str, Path]) -> ort.InferenceSession: """Load an ONNX model from a file or URL. Parameters ---------- - model_path: Union[str, Path] + model_url: Union[str, Path] path to the ONNX model, either a local path or a public URL. Returns @@ -24,11 +24,11 @@ def load_model_onnx(model_path: Union[str, Path]) -> ort.InferenceSession: ONNX model loaded with ONNX runtime. """ # Two minutes timeout to download the model - if str(model_path).startswith("http"): - response = requests.get(str(model_path), timeout=120) + if str(model_url).startswith("http"): + response = requests.get(str(model_url), timeout=120) model = response.content else: - model = str(model_path) + model = str(model_url) return ort.InferenceSession(model) From 06ed51e9b4dea5d630034f1531b3a2fbb10ed771 Mon Sep 17 00:00:00 2001 From: Kristof Van Tricht Date: Tue, 15 Oct 2024 20:43:36 +0200 Subject: [PATCH 8/9] Updated model training --- scripts/spark/train_catboost.py | 13 ++++--------- 1 file changed, 4 insertions(+), 9 deletions(-) diff --git a/scripts/spark/train_catboost.py b/scripts/spark/train_catboost.py index d0657231..f7035318 100644 --- a/scripts/spark/train_catboost.py +++ b/scripts/spark/train_catboost.py @@ -10,12 +10,9 @@ from sklearn.metrics import ConfusionMatrixDisplay, confusion_matrix from worldcereal.train import get_training_data - -# from worldcereal.classification.models import WorldCerealCatBoostModel -# from worldcereal.classification.weights import get_refid_weight from worldcereal.utils.spark import get_spark_context -MODELVERSION = "005-ft-cropland-logloss" +MODELVERSION = "006-ft-cropland-maxmaskratio05" class Trainer: @@ -153,8 +150,6 @@ def _fit_helper(model, cal_data, val_data): # Plot feature importances self._plot_feature_importance(self.model, outputdir) - logger.success("Base model trained!") - def train(self, sc=None, **kwargs): # train model self._train(sc, minsamples=self.minsamples, **kwargs) @@ -339,7 +334,7 @@ def _setup_model(self): learning_rate=0.05, early_stopping_rounds=50, l2_leaf_reg=3, - eval_metric="Logloss", + eval_metric="F1", train_dir=self.modeldir, ) @@ -503,7 +498,7 @@ def main(detector, trainingsettings, outdir_base, MODELVERSION, sc=None): trainingsettings = { "trainingfile": trainingdir - / "training_df_presto-ss-wc-ft-ct_cropland_CROPLAND2_30D_random_time-token=none_balance=True_augment=True_presto-worldcereal.parquet", + / "training_df_presto-ss-wc-ft-ct_cropland_CROPLAND2_30D_random_time-token=none_balance=True_augment=True_presto-worldcereal-maxmaskratio0.5.parquet", "outputlabel": "LANDCOVER_LABEL", "targetlabels": [11], "ignorelabels": [10], @@ -512,7 +507,7 @@ def main(detector, trainingsettings, outdir_base, MODELVERSION, sc=None): "filter_worldcover": True, "classes": {0: "other", 1: "cropland"}, "bands": BANDS_CROPLAND_PRESTO, - "pos_neg_ratio": 0.45, + "pos_neg_ratio": 0.50, "minsamples": 500, } From 5add2a25b3508f423fb9003fcf4b57911814321e Mon Sep 17 00:00:00 2001 From: Kristof Van Tricht Date: Wed, 16 Oct 2024 18:30:26 +0200 Subject: [PATCH 9/9] Update docstring --- src/worldcereal/train/data.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/worldcereal/train/data.py b/src/worldcereal/train/data.py index 1c23e639..d489cebb 100644 --- a/src/worldcereal/train/data.py +++ b/src/worldcereal/train/data.py @@ -25,7 +25,7 @@ def make_mask_no_dw( Make a mask for a given strategy and percentage of masked values. Args: strategy: The masking strategy to use. One of MASK_STRATEGIES - mask_ratio: The percentage of values to mask. Between 0 and 1. + mask_ratio: The max percentage of values to mask. Between 0 and 1. """ # we assume that topography is never "naturally" masked mask = existing_mask.copy()