Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Updated WorldCerealInferenceDataset #103

Merged
merged 27 commits into from
Sep 9, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
27 commits
Select commit Hold shift + click to select a range
47034ac
Updated inference dataset
kvantricht Sep 3, 2024
d7d541e
Updated WorldCerealInferenceDataset for new files
kvantricht Sep 3, 2024
5d6fe88
Remove unused import
kvantricht Sep 3, 2024
f0b1c8c
Fix typing
kvantricht Sep 3, 2024
ce52f9c
Ignore this type check
kvantricht Sep 3, 2024
1ac7ea2
Black fixes
kvantricht Sep 3, 2024
906652e
More line length fixes
kvantricht Sep 3, 2024
1a19ec1
:facepalm
kvantricht Sep 3, 2024
8ee5b47
:facepalm:
kvantricht Sep 3, 2024
3f75632
Added ground truth labels
kvantricht Sep 3, 2024
9155516
Only subset of the file for faster tests
kvantricht Sep 3, 2024
5f6b58d
Fix gt selection and rearranged t subset
kvantricht Sep 3, 2024
794065e
test update
kvantricht Sep 3, 2024
817da45
More consistent handling of inference datasets
kvantricht Sep 4, 2024
666d05f
use h5netcdf instead of rioxarray
kvantricht Sep 4, 2024
7f14c8c
Dont import at the top to avoid dependency
kvantricht Sep 4, 2024
70015af
Add rioxarray and relax xarray version
kvantricht Sep 4, 2024
eab0569
Relax rioxarray library version
kvantricht Sep 4, 2024
3eea359
Test file update
kvantricht Sep 5, 2024
4f9a18f
Clarified comment
kvantricht Sep 5, 2024
6547114
Black fix
kvantricht Sep 5, 2024
8ac58df
Only create the grid before feeding to presto
kvantricht Sep 5, 2024
fab066c
Formatting
gabrieltseng Sep 5, 2024
6079001
Revert "Only create the grid before feeding to presto"
kvantricht Sep 5, 2024
2575d39
Merge branch 'updated-inferencedatasets' of github.com:WorldCereal/pr…
kvantricht Sep 5, 2024
be7a31b
Clarified variable naming
kvantricht Sep 5, 2024
13f8433
Black fix
kvantricht Sep 5, 2024
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Binary file not shown.
Binary file not shown.
287 changes: 216 additions & 71 deletions presto/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,16 +3,14 @@
from math import modf
from pathlib import Path
from random import sample
from typing import Callable, Dict, List, Optional, Tuple, cast
from typing import Callable, Dict, List, Optional, Tuple, Union

import geopandas as gpd
import numpy as np
import pandas as pd
import rioxarray
import xarray as xr
from einops import rearrange, repeat
from pyproj import Transformer
from sklearn.utils.class_weight import compute_class_weight
from einops import rearrange
from pyproj import CRS, Transformer
from torch.utils.data import Dataset

from .dataops import (
Expand Down Expand Up @@ -320,6 +318,8 @@ def __getitem__(self, idx):

@property
def class_weights(self) -> np.ndarray:
from sklearn.utils.class_weight import compute_class_weight

if self._class_weights is None:
ys = []
for _, row in self.df.iterrows():
Expand Down Expand Up @@ -368,23 +368,7 @@ def __getitem__(self, idx):

class WorldCerealInferenceDataset(Dataset):
_NODATAVALUE = 65535
Y = "worldcereal_cropland"
BAND_MAPPING = {
"B02": "B2",
"B03": "B3",
"B04": "B4",
"B05": "B5",
"B06": "B6",
"B07": "B7",
"B08": "B8",
# B8A is missing
"B11": "B11",
"B12": "B12",
"VH": "VH",
"VV": "VV",
"precipitation-flux": "total_precipitation",
"temperature-mean": "temperature_2m",
}
Y = "WORLDCEREAL_TEMPORARYCROPS_2021"

def __init__(self, path_to_files: Path = data_dir / "inference_areas"):
self.path_to_files = path_to_files
Expand All @@ -394,72 +378,233 @@ def __len__(self):
return len(self.all_files)

@classmethod
def nc_to_arrays(
cls, filepath: Path
) -> Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray, np.ndarray]:
ds = cast(xr.Dataset, rioxarray.open_rasterio(filepath, decode_times=False))
epsg_coords = ds.rio.crs.to_epsg()

num_instances = len(ds.x) * len(ds.y)
num_timesteps = len(ds.t)
eo_data = np.zeros((num_instances, num_timesteps, len(BANDS)))
mask = np.zeros((num_instances, num_timesteps, len(BANDS_GROUPS_IDX)))
# for now, B8A is missing
mask[:, :, IDX_TO_BAND_GROUPS["B8A"]] = 1

for org_band, presto_val in cls.BAND_MAPPING.items():
# flatten the values
values = np.swapaxes(ds[org_band].values.reshape((num_timesteps, -1)), 0, 1)
idx_valid = values != cls._NODATAVALUE
def _extract_eo_data(cls, inarr: xr.DataArray) -> Tuple[np.ndarray, np.ndarray]:
"""
Extracts EO data and mask arrays from the input xarray.DataArray.

Args:
inarr (xr.DataArray): Input xarray.DataArray containing EO data.

Returns:
Tuple[np.ndarray, np.ndarray]: Tuple containing EO data array and mask array.
"""
num_pixels = len(inarr.x) * len(inarr.y)
num_timesteps = len(inarr.t)

# Handle NaN values in Presto compatible way
inarr = inarr.astype(np.float32)
inarr = inarr.fillna(65535)

eo_data = np.zeros((num_pixels, num_timesteps, len(BANDS)))
mask = np.zeros((num_pixels, num_timesteps, len(BANDS_GROUPS_IDX)))

for presto_band in NORMED_BANDS:
if presto_band in inarr.coords["bands"]:
values = np.swapaxes(
inarr.sel(bands=presto_band).values.reshape((num_timesteps, -1)),
0,
1,
)
idx_valid = values != cls._NODATAVALUE
values = cls._preprocess_band_values(values, presto_band)
eo_data[:, :, BANDS.index(presto_band)] = values * idx_valid
mask[:, :, IDX_TO_BAND_GROUPS[presto_band]] += ~idx_valid
elif presto_band == "NDVI":
# # NDVI will be computed by the normalize function
continue
else:
logger.warning(f"Band {presto_band} not found in input data.")
eo_data[:, :, BANDS.index(presto_band)] = 0
mask[:, :, IDX_TO_BAND_GROUPS[presto_band]] = 1

if presto_val in ["VV", "VH"]:
# convert to dB
values = 20 * np.log10(values) - 83
elif presto_val == "total_precipitation":
# scaling, and AgERA5 is in mm, Presto expects m
values = values / (100 * 1000.0)
elif presto_val == "temperature_2m":
# remove scaling
values = values / 100
return eo_data, mask

@staticmethod
def _extract_latlons(inarr: xr.DataArray, epsg: int) -> np.ndarray:
"""
Extracts latitudes and longitudes from the input xarray.DataArray.

Args:
inarr (xr.DataArray): Input xarray.DataArray containing spatial coordinates.
epsg (int): EPSG code for coordinate reference system.

Returns:
np.ndarray: Array containing extracted latitudes and longitudes.
"""
# EPSG:4326 is the supported crs for presto
transformer = Transformer.from_crs(f"EPSG:{epsg}", "EPSG:4326", always_xy=True)
x, y = np.meshgrid(inarr.x, inarr.y)
lon, lat = transformer.transform(x, y)

flat_latlons = rearrange(np.stack([lat, lon]), "c x y -> (x y) c")

# 2D array where each row represents a pair of latitude and longitude coordinates.
return flat_latlons

@classmethod
def _preprocess_band_values(cls, values: np.ndarray, presto_band: str) -> np.ndarray:
"""
Preprocesses the band values based on the given presto_val.

Args:
values (np.ndarray): Array of band values to preprocess.
presto_val (str): Name of the band for preprocessing.

Returns:
np.ndarray: Preprocessed array of band values.
"""
if presto_band in ["VV", "VH"]:
# Convert to dB
values = 20 * np.log10(values) - 83
elif presto_band == "total_precipitation":
# Scale precipitation and convert mm to m
values = values / (100 * 1000.0)
elif presto_band == "temperature_2m":
# Remove scaling
values = values / 100
return values

@staticmethod
def _extract_months(inarr: xr.DataArray) -> np.ndarray:
"""
Calculate the start month based on the first timestamp in the input array,
and create an array of the same length filled with that start month value.

Parameters:
- inarr: xarray.DataArray or numpy.ndarray
Input array containing timestamps.

eo_data[:, :, BANDS.index(presto_val)] = values
mask[:, :, IDX_TO_BAND_GROUPS[presto_val]] += ~idx_valid
Returns:
- months: numpy.ndarray
Array of start month values, with the same length as the input array.
"""
num_instances = len(inarr.x) * len(inarr.y)

start_month = (inarr.t.values[0].astype("datetime64[M]").astype(int) % 12 + 1) - 1

y = rearrange(ds[cls.Y].values, "t x y -> (x y) t")
# -1 because we index from 0
start_month = (ds.t.values[0].astype("datetime64[M]").astype(int) % 12 + 1) - 1
months = np.ones((num_instances)) * start_month
return months

@staticmethod
def _subset_array_temporally(inarr: xr.DataArray) -> xr.DataArray:
"""
Subset the input xarray.DataArray temporally based on `valid_time` attribute.

transformer = Transformer.from_crs(f"EPSG:{epsg_coords}", "EPSG:4326", always_xy=True)
lon, lat = transformer.transform(ds.x, ds.y)
Args:
inarr (xr.DataArray): Input xarray.DataArray containing EO data.

latlons = np.stack(
[np.repeat(lat, repeats=len(lon)), repeat(lon, "c -> (h c)", h=len(lat))],
axis=-1,
)
Returns:
xr.DataArray: Temporally subsetted xarray.DataArray.
"""

return eo_data, np.repeat(mask, BAND_EXPANSION, axis=-1), latlons, months, y
# Use valid_time attribute to extract the right part of the time series
valid_time = pd.to_datetime(inarr.attrs["valid_time"]).replace(day=1)
end_time = valid_time + pd.DateOffset(months=5)
start_time = valid_time - pd.DateOffset(months=6)
inarr = inarr.sel(t=slice(start_time, end_time))
num_timesteps = len(inarr.t)
assert num_timesteps == 12, "Expected 12 timesteps, only found {}".format(num_timesteps)

return inarr

@classmethod
def nc_to_arrays(
cls, filepath: Path
) -> Tuple[
np.ndarray,
np.ndarray,
np.ndarray,
np.ndarray,
np.ndarray,
np.ndarray,
np.ndarray,
]:
ds = xr.open_dataset(filepath)
epsg = CRS.from_wkt(xr.open_dataset(filepath).crs.attrs["crs_wkt"]).to_epsg()

if epsg is None:
raise ValueError("EPSG code not found in the input file.")
inarr = ds.drop("crs").to_array(dim="bands")

# Extract coordinates for reconstruction
x_coord, y_coord = inarr.x, inarr.y

# Temporal subsetting to 12 timesteps
inarr = cls._subset_array_temporally(inarr)

eo_data, mask = cls._extract_eo_data(inarr)
flat_latlons = cls._extract_latlons(inarr, epsg)
months = cls._extract_months(inarr)

if cls.Y not in ds:
target = np.ones_like(months) * cls._NODATAVALUE
else:
target = rearrange(inarr.sel(bands=cls.Y).values, "t x y -> (x y) t")

return (
eo_data,
np.repeat(mask, BAND_EXPANSION, axis=-1),
flat_latlons,
months,
target,
x_coord,
y_coord,
)

def __getitem__(self, idx):
filepath = self.all_files[idx]
eo, mask, latlons, months, y = self.nc_to_arrays(filepath)
(
eo,
mask,
flat_latlons,
months,
target,
x_coord,
y_coord,
) = self.nc_to_arrays(filepath)

dynamic_world = np.ones((eo.shape[0], eo.shape[1])) * (DynamicWorld2020_2021.class_amount)

return S1_S2_ERA5_SRTM.normalize(eo), dynamic_world, mask, latlons, months, y
return (
S1_S2_ERA5_SRTM.normalize(eo),
dynamic_world,
mask,
flat_latlons,
months,
target,
x_coord,
y_coord,
)

@staticmethod
def combine_predictions(
latlons: np.ndarray, all_preds: np.ndarray, gt: np.ndarray, ndvi: np.ndarray
) -> pd.DataFrame:
flat_lat, flat_lon = latlons[:, 0], latlons[:, 1]
all_preds: np.ndarray,
gt: np.ndarray,
ndvi: np.ndarray,
x_coord: Union[xr.DataArray, np.ndarray, List[float]],
y_coord: Union[xr.DataArray, np.ndarray, List[float]],
) -> xr.DataArray:

if len(all_preds.shape) == 1:
all_preds = np.expand_dims(all_preds, axis=-1)

data_dict: Dict[str, np.ndarray] = {"lat": flat_lat, "lon": flat_lon}
# Get band names
bands = [f"prediction_{i}" for i in range(all_preds.shape[1])] + [
"ground_truth",
"ndvi",
]

# Initialize gridded data array
data = np.empty((len(bands), len(y_coord), len(x_coord)))

# Fill with gridded predictions
for i in range(all_preds.shape[1]):
prediction_label = f"prediction_{i}"
data_dict[prediction_label] = all_preds[:, i]
data_dict["ground_truth"] = gt[:, 0]
data_dict["ndvi"] = ndvi
return pd.DataFrame(data=data_dict).set_index(["lat", "lon"])
data[i, ...] = rearrange(
all_preds[:, i], "(y x) -> 1 y x", y=len(y_coord), x=len(x_coord)
)

# Fill with ground truth and NDVI
data[-2, ...] = rearrange(gt[:, 0], "(y x) -> 1 y x", y=len(y_coord), x=len(x_coord))
data[-1, ...] = rearrange(ndvi, "(y x) -> 1 y x", y=len(y_coord), x=len(x_coord))

return xr.DataArray(coords=[bands, y_coord, x_coord], dims=["bands", "y", "x"], data=data)
25 changes: 19 additions & 6 deletions presto/eval.py
Original file line number Diff line number Diff line change
Expand Up @@ -125,7 +125,9 @@ def finetune_sklearn_model(
assert model_mode in ["Regression", "Random Forest", "CatBoostClassifier"]
pretrained_model.eval()

def dataloader_to_encodings_and_targets(dl: DataLoader) -> Tuple[np.ndarray, np.ndarray]:
def dataloader_to_encodings_and_targets(
dl: DataLoader,
) -> Tuple[np.ndarray, np.ndarray]:
encoding_list, target_list = [], []
for x, y, dw, latlons, month, variable_mask in dl:
x_f, dw_f, latlons_f, month_f, variable_mask_f = [
Expand Down Expand Up @@ -184,7 +186,9 @@ def dataloader_to_encodings_and_targets(dl: DataLoader) -> Tuple[np.ndarray, np.
if model == "CatBoostClassifier":
fit_models.append(
clone(model_dict[model]).fit(
train_encodings, train_targets, eval_set=Pool(val_encodings, val_targets)
train_encodings,
train_targets,
eval_set=Pool(val_encodings, val_targets),
)
)
else:
Expand Down Expand Up @@ -245,13 +249,22 @@ def spatial_inference(
assert self.spatial_inference_savedir is not None
ds = WorldCerealInferenceDataset()
for i in range(len(ds)):
eo, dynamic_world, mask, latlons, months, y = ds[i]
(
eo,
dynamic_world,
mask,
flat_latlons,
months,
y,
x_coord,
y_coord,
) = ds[i]
dl = DataLoader(
TensorDataset(
torch.from_numpy(eo).float(),
torch.from_numpy(y.astype(np.int16)),
torch.from_numpy(dynamic_world).long(),
torch.from_numpy(latlons).float(),
torch.from_numpy(flat_latlons).float(),
torch.from_numpy(months).long(),
torch.from_numpy(mask).float(),
),
Expand All @@ -263,13 +276,13 @@ def spatial_inference(
# take the middle timestep's ndvi
middle_timestep = eo.shape[1] // 2
ndvi = eo[:, middle_timestep, NORMED_BANDS.index("NDVI")]
df = ds.combine_predictions(latlons, test_preds_np, y, ndvi)
da = ds.combine_predictions(test_preds_np, y, ndvi, x_coord, y_coord)
prefix = f"{self.name}_{ds.all_files[i].stem}"
if pretrained_model is None:
filename = f"{prefix}_finetuning.nc"
else:
filename = f"{prefix}_{finetuned_model.__class__.__name__}.nc"
df.to_xarray().to_netcdf(self.spatial_inference_savedir / filename)
da.to_netcdf(self.spatial_inference_savedir / filename)

@torch.no_grad()
def evaluate(
Expand Down
Loading
Loading