Skip to content

Commit

Permalink
Formatting fixes
Browse files Browse the repository at this point in the history
  • Loading branch information
kvantricht committed Oct 11, 2024
1 parent f96f78d commit 2ef8dde
Showing 5 changed files with 58 additions and 32 deletions.
4 changes: 1 addition & 3 deletions presto/dataset.py
Original file line number Diff line number Diff line change
@@ -818,9 +818,7 @@ def _subset_array_temporally(inarr: xr.DataArray) -> xr.DataArray:
return inarr

@classmethod
def nc_to_arrays(
cls, filepath: Path
) -> Tuple[
def nc_to_arrays(cls, filepath: Path) -> Tuple[
np.ndarray,
np.ndarray,
np.ndarray,
29 changes: 18 additions & 11 deletions presto/eval.py
Original file line number Diff line number Diff line change
@@ -18,12 +18,19 @@
from torch.utils.data import DataLoader, TensorDataset
from tqdm import tqdm

from .dataset import (NORMED_BANDS, WorldCerealInferenceDataset,
WorldCerealLabelled10DDataset,
WorldCerealLabelledDataset)
from .dataset import (
NORMED_BANDS,
WorldCerealInferenceDataset,
WorldCerealLabelled10DDataset,
WorldCerealLabelledDataset,
)
from .hierarchical_classification import CatBoostClassifierWrapper
from .presto import (Presto, PrestoFineTuningModel,
get_sinusoid_encoding_table, param_groups_lrd)
from .presto import (
Presto,
PrestoFineTuningModel,
get_sinusoid_encoding_table,
param_groups_lrd,
)
from .utils import DEFAULT_SEED, device, get_class_mappings, prep_dataframe

MIN_SAMPLES_PER_CLASS = 3
@@ -105,12 +112,12 @@ def __init__(
self.num_outputs = len(train_classes)

# use classes obtained from train to trim val and test classes
self.val_df.loc[
~self.val_df[class_column].isin(train_classes), class_column
] = "other_crop"
self.test_df.loc[
~self.test_df[class_column].isin(train_classes), class_column
] = "other_crop"
self.val_df.loc[~self.val_df[class_column].isin(train_classes), class_column] = (
"other_crop"
)
self.test_df.loc[~self.test_df[class_column].isin(train_classes), class_column] = (
"other_crop"
)

# create one-hot representation from obtained labels
# one-hot is needed for finetuning,
15 changes: 9 additions & 6 deletions presto/masking.py
Original file line number Diff line number Diff line change
@@ -5,8 +5,13 @@

import numpy as np

from .dataops import (BAND_EXPANSION, BANDS_GROUPS_IDX, NUM_TIMESTEPS,
SRTM_INDEX, TIMESTEPS_IDX)
from .dataops import (
BAND_EXPANSION,
BANDS_GROUPS_IDX,
NUM_TIMESTEPS,
SRTM_INDEX,
TIMESTEPS_IDX,
)

MASK_STRATEGIES = (
"group_bands",
@@ -45,10 +50,8 @@ def make_mask_no_dw(
mask = existing_mask.copy()

srtm_mask = False

num_tokens_to_mask = int(
((num_timesteps * (len(BANDS_GROUPS_IDX) - 1)) + 1) * mask_ratio
)

num_tokens_to_mask = int(((num_timesteps * (len(BANDS_GROUPS_IDX) - 1)) + 1) * mask_ratio)
assert num_tokens_to_mask > 0

def mask_topography(srtm_mask, num_tokens_to_mask, mask_ratio):
23 changes: 16 additions & 7 deletions presto/utils.py
Original file line number Diff line number Diff line change
@@ -12,11 +12,22 @@
import pandas as pd
import torch
import xarray as xr

from presto.dataops import NUM_TIMESTEPS

from .dataops import (BANDS, ERA5_BANDS, MIN_EDGE_BUFFER, NODATAVALUE,
NORMED_BANDS, REMOVED_BANDS, S1_BANDS, S1_S2_ERA5_SRTM,
S2_BANDS, SRTM_BANDS, DynamicWorld2020_2021)
from .dataops import (
BANDS,
ERA5_BANDS,
MIN_EDGE_BUFFER,
NODATAVALUE,
NORMED_BANDS,
REMOVED_BANDS,
S1_BANDS,
S1_S2_ERA5_SRTM,
S2_BANDS,
SRTM_BANDS,
DynamicWorld2020_2021,
)

# plt = None

@@ -581,6 +592,7 @@ def plot_spatial(
import matplotlib.colors as mcolors
import matplotlib.patches as mpatches
import matplotlib.pyplot as plt
from mpl_toolkits.axes_grid1 import make_axes_locatable

CLASS_MAPPINGS = get_class_mappings()

@@ -638,10 +650,7 @@ def plot_spatial(
cmap.set_bad(color="whitesmoke")

ax4.imshow(prediction_0, cmap=cmap)
patches = [
mpatches.Patch(color=colors[ii], label=values[ii])
for ii in range(len(values))
]
patches = [mpatches.Patch(color=colors[ii], label=values[ii]) for ii in range(len(values))]
ax4.legend(
handles=patches,
bbox_to_anchor=(1.25, 0.65),
19 changes: 14 additions & 5 deletions train_finetuned.py
Original file line number Diff line number Diff line change
@@ -12,15 +12,24 @@
import requests
import torch
import xarray as xr
from tqdm.auto import tqdm

from presto.dataops import NODATAVALUE
from presto.dataset import WorldCerealBase, filter_remove_noncrops
from presto.eval import WorldCerealEval
from presto.presto import Presto
from presto.utils import (DEFAULT_SEED, config_dir, data_dir,
default_model_path, device, initialize_logging,
plot_spatial, process_parquet, seed_everything,
timestamp_dirname)
from tqdm.auto import tqdm
from presto.utils import (
DEFAULT_SEED,
config_dir,
data_dir,
default_model_path,
device,
initialize_logging,
plot_spatial,
process_parquet,
seed_everything,
timestamp_dirname,
)

logger = logging.getLogger("__main__")

0 comments on commit 2ef8dde

Please sign in to comment.