diff --git a/presto/dataset.py b/presto/dataset.py index affe16f..2e8c9b0 100644 --- a/presto/dataset.py +++ b/presto/dataset.py @@ -818,9 +818,7 @@ def _subset_array_temporally(inarr: xr.DataArray) -> xr.DataArray: return inarr @classmethod - def nc_to_arrays( - cls, filepath: Path - ) -> Tuple[ + def nc_to_arrays(cls, filepath: Path) -> Tuple[ np.ndarray, np.ndarray, np.ndarray, diff --git a/presto/eval.py b/presto/eval.py index 4fa9aa3..7600065 100644 --- a/presto/eval.py +++ b/presto/eval.py @@ -18,12 +18,19 @@ from torch.utils.data import DataLoader, TensorDataset from tqdm import tqdm -from .dataset import (NORMED_BANDS, WorldCerealInferenceDataset, - WorldCerealLabelled10DDataset, - WorldCerealLabelledDataset) +from .dataset import ( + NORMED_BANDS, + WorldCerealInferenceDataset, + WorldCerealLabelled10DDataset, + WorldCerealLabelledDataset, +) from .hierarchical_classification import CatBoostClassifierWrapper -from .presto import (Presto, PrestoFineTuningModel, - get_sinusoid_encoding_table, param_groups_lrd) +from .presto import ( + Presto, + PrestoFineTuningModel, + get_sinusoid_encoding_table, + param_groups_lrd, +) from .utils import DEFAULT_SEED, device, get_class_mappings, prep_dataframe MIN_SAMPLES_PER_CLASS = 3 @@ -105,12 +112,12 @@ def __init__( self.num_outputs = len(train_classes) # use classes obtained from train to trim val and test classes - self.val_df.loc[ - ~self.val_df[class_column].isin(train_classes), class_column - ] = "other_crop" - self.test_df.loc[ - ~self.test_df[class_column].isin(train_classes), class_column - ] = "other_crop" + self.val_df.loc[~self.val_df[class_column].isin(train_classes), class_column] = ( + "other_crop" + ) + self.test_df.loc[~self.test_df[class_column].isin(train_classes), class_column] = ( + "other_crop" + ) # create one-hot representation from obtained labels # one-hot is needed for finetuning, diff --git a/presto/masking.py b/presto/masking.py index 4e27b4d..9aa8383 100644 --- a/presto/masking.py +++ b/presto/masking.py @@ -5,8 +5,13 @@ import numpy as np -from .dataops import (BAND_EXPANSION, BANDS_GROUPS_IDX, NUM_TIMESTEPS, - SRTM_INDEX, TIMESTEPS_IDX) +from .dataops import ( + BAND_EXPANSION, + BANDS_GROUPS_IDX, + NUM_TIMESTEPS, + SRTM_INDEX, + TIMESTEPS_IDX, +) MASK_STRATEGIES = ( "group_bands", @@ -45,10 +50,8 @@ def make_mask_no_dw( mask = existing_mask.copy() srtm_mask = False - - num_tokens_to_mask = int( - ((num_timesteps * (len(BANDS_GROUPS_IDX) - 1)) + 1) * mask_ratio - ) + + num_tokens_to_mask = int(((num_timesteps * (len(BANDS_GROUPS_IDX) - 1)) + 1) * mask_ratio) assert num_tokens_to_mask > 0 def mask_topography(srtm_mask, num_tokens_to_mask, mask_ratio): diff --git a/presto/utils.py b/presto/utils.py index 8fb4b0d..8715b56 100644 --- a/presto/utils.py +++ b/presto/utils.py @@ -12,11 +12,22 @@ import pandas as pd import torch import xarray as xr + from presto.dataops import NUM_TIMESTEPS -from .dataops import (BANDS, ERA5_BANDS, MIN_EDGE_BUFFER, NODATAVALUE, - NORMED_BANDS, REMOVED_BANDS, S1_BANDS, S1_S2_ERA5_SRTM, - S2_BANDS, SRTM_BANDS, DynamicWorld2020_2021) +from .dataops import ( + BANDS, + ERA5_BANDS, + MIN_EDGE_BUFFER, + NODATAVALUE, + NORMED_BANDS, + REMOVED_BANDS, + S1_BANDS, + S1_S2_ERA5_SRTM, + S2_BANDS, + SRTM_BANDS, + DynamicWorld2020_2021, +) # plt = None @@ -581,6 +592,7 @@ def plot_spatial( import matplotlib.colors as mcolors import matplotlib.patches as mpatches import matplotlib.pyplot as plt + from mpl_toolkits.axes_grid1 import make_axes_locatable CLASS_MAPPINGS = get_class_mappings() @@ -638,10 +650,7 @@ def plot_spatial( cmap.set_bad(color="whitesmoke") ax4.imshow(prediction_0, cmap=cmap) - patches = [ - mpatches.Patch(color=colors[ii], label=values[ii]) - for ii in range(len(values)) - ] + patches = [mpatches.Patch(color=colors[ii], label=values[ii]) for ii in range(len(values))] ax4.legend( handles=patches, bbox_to_anchor=(1.25, 0.65), diff --git a/train_finetuned.py b/train_finetuned.py index 8b0e855..1e0fdf4 100644 --- a/train_finetuned.py +++ b/train_finetuned.py @@ -12,15 +12,24 @@ import requests import torch import xarray as xr +from tqdm.auto import tqdm + from presto.dataops import NODATAVALUE from presto.dataset import WorldCerealBase, filter_remove_noncrops from presto.eval import WorldCerealEval from presto.presto import Presto -from presto.utils import (DEFAULT_SEED, config_dir, data_dir, - default_model_path, device, initialize_logging, - plot_spatial, process_parquet, seed_everything, - timestamp_dirname) -from tqdm.auto import tqdm +from presto.utils import ( + DEFAULT_SEED, + config_dir, + data_dir, + default_model_path, + device, + initialize_logging, + plot_spatial, + process_parquet, + seed_everything, + timestamp_dirname, +) logger = logging.getLogger("__main__")